Files
oslstats/internal/db/insert.go

129 lines
3.3 KiB
Go

package db
import (
"context"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// inserter is a fluent builder around a bun INSERT query executed inside a
// transaction, with optional audit logging after a successful insert.
//
// Create one with Insert or InsertMultiple, chain OnConflict/Set/Returning/
// WithAudit as needed, then call Exec once.
type inserter[T any] struct {
	tx     bun.Tx
	q      *bun.InsertQuery
	model  *T   // target of Insert; nil for bulk inserts
	models []*T // targets of InsertMultiple; nil for single inserts
	isBulk bool

	// returning holds columns requested via Returning. It is applied in Exec
	// rather than on q directly because bun's Returning appends to (and cannot
	// remove from) an existing RETURNING list. When empty, Exec uses "*".
	returning []string

	auditCallback AuditCallback
	auditRequest  *http.Request
}

// Insert creates an inserter for a single model.
//
// Unless Returning is used to request specific columns, the query uses
// RETURNING * so model is populated from the returned row after Exec.
// It panics if model is nil (a programmer error, not a runtime condition).
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
	if model == nil {
		panic("model cannot be nil")
	}
	return &inserter[T]{
		tx:     tx,
		q:      tx.NewInsert().Model(model),
		model:  model,
		isBulk: false,
	}
}

// InsertMultiple creates an inserter for a bulk insert of models.
//
// Unless Returning is used to request specific columns, the query uses
// RETURNING * so the models are populated from the returned rows after Exec.
// It panics if models is nil/empty or contains a nil element.
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
	if len(models) == 0 {
		panic("models cannot be nil or empty")
	}
	// Fail fast with a clear message instead of letting bun trip over a nil
	// element while building the VALUES list.
	for _, m := range models {
		if m == nil {
			panic("models cannot contain nil entries")
		}
	}
	return &inserter[T]{
		tx:     tx,
		q:      tx.NewInsert().Model(&models),
		models: models,
		isBulk: true,
	}
}

// OnConflict adds an ON CONFLICT clause for upserts.
// Example: .OnConflict("(discord_id) DO UPDATE")
func (i *inserter[T]) OnConflict(query string) *inserter[T] {
	i.q = i.q.On(query)
	return i
}

// Set adds a SET expression for upserts; use together with OnConflict.
// Example: .Set("access_token = EXCLUDED.access_token")
func (i *inserter[T]) Set(query string, args ...any) *inserter[T] {
	i.q = i.q.Set(query, args...)
	return i
}

// Returning replaces the default RETURNING * with the given columns.
// Example: .Returning("id", "created_at")
//
// Columns are emitted verbatim (unquoted) and accumulate across calls.
// Calling it with no columns is a no-op, so the default "*" is kept.
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
	i.returning = append(i.returning, columns...)
	return i
}

// WithAudit enables audit logging for this insert. The callback is invoked
// by Exec after a successful insert, with the same transaction and an
// auto-generated AuditInfo. Auditing only happens when both r and callback
// are non-nil.
func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] {
	i.auditRequest = r
	i.auditCallback = callback
	return i
}

// Exec runs the insert and, if auditing is enabled, calls the audit callback
// inside the same transaction.
//
// It returns an error if the insert fails or if the audit callback fails, so
// the caller can roll back the transaction. Call Exec at most once per
// inserter: each call adds the RETURNING clause to the underlying query.
func (i *inserter[T]) Exec(ctx context.Context) error {
	// Apply RETURNING exactly once here; see the note on inserter.returning.
	returning := "*"
	if len(i.returning) > 0 {
		returning = strings.Join(i.returning, ", ")
	}
	if _, err := i.q.Returning(returning).Exec(ctx); err != nil {
		return errors.Wrap(err, "bun.InsertQuery.Exec")
	}

	if i.auditCallback == nil || i.auditRequest == nil {
		return nil
	}
	// An error here is returned so the surrounding transaction can be rolled back.
	if err := i.auditCallback(ctx, i.tx, i.auditInfo(), i.auditRequest); err != nil {
		return errors.Wrap(err, "audit.callback")
	}
	return nil
}

// auditInfo builds the audit record for a completed insert: a single entry
// with a "count" detail for bulk inserts, or a single entry carrying the
// model's primary key for single inserts.
func (i *inserter[T]) auditInfo() *AuditInfo {
	resourceType := extractResourceType(extractTableName[T]())
	info := &AuditInfo{
		Action:       buildAction(resourceType, "create"),
		ResourceType: resourceType,
	}
	if i.isBulk {
		info.Details = map[string]any{"count": len(i.models)}
	} else {
		info.ResourceID = extractPrimaryKey(i.model)
	}
	return info
}