package db import ( "context" "fmt" "strings" "github.com/pkg/errors" "github.com/uptrace/bun" ) type inserter[T any] struct { tx bun.Tx q *bun.InsertQuery model *T models []*T isBulk bool audit *AuditMeta auditInfo *AuditInfo } // Insert creates an inserter for a single model // The model will have all fields populated after Exec() via Returning("*") func Insert[T any](tx bun.Tx, model *T) *inserter[T] { if model == nil { panic("model cannot be nil") } return &inserter[T]{ tx: tx, q: tx.NewInsert().Model(model).Returning("*"), model: model, isBulk: false, } } // InsertMultiple creates an inserter for bulk insert // All models will have fields populated after Exec() via Returning("*") func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] { if len(models) == 0 { panic("models cannot be nil or empty") } return &inserter[T]{ tx: tx, q: tx.NewInsert().Model(&models).Returning("*"), models: models, isBulk: true, } } func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] { fieldstr := strings.Join(conflicts, ", ") i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", fieldstr)) return i } func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] { fieldstr := strings.Join(conflicts, ", ") i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", fieldstr)) for _, column := range columns { i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", column, column)) } return i } // Returning overrides the default Returning("*") clause // Example: .Returning("id", "created_at") func (i *inserter[T]) Returning(columns ...string) *inserter[T] { if len(columns) == 0 { return i } // Build column list as single string columnList := strings.Join(columns, ", ") i.q = i.q.Returning(columnList) return i } // WithAudit enables audit logging for this insert operation // If the provided *AuditInfo is nil, will use reflection to automatically work out the details func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] { i.audit = meta i.auditInfo = info return i } // Exec executes the insert and optionally logs to audit // Returns an error if insert fails or if audit callback fails (triggering rollback) func (i *inserter[T]) Exec(ctx context.Context) error { // Execute insert _, err := i.q.Exec(ctx) if err != nil { return errors.Wrap(err, "bun.InsertQuery.Exec") } // Handle audit logging if enabled if i.audit != nil { if i.auditInfo == nil { tableName := extractTableName[T]() resourceType := extractResourceType(tableName) action := buildAction(resourceType, "create") i.auditInfo = &AuditInfo{ Action: action, ResourceType: resourceType, ResourceID: nil, Details: nil, } if i.isBulk { i.auditInfo.Details = map[string]any{ "count": len(i.models), } } else { i.auditInfo.ResourceID = extractPrimaryKey(i.model) i.auditInfo.Details = i.model } } err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo) if err != nil { return errors.Wrap(err, "LogSuccess") } } return nil }