124 lines
3.1 KiB
Go
124 lines
3.1 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
type inserter[T any] struct {
|
|
tx bun.Tx
|
|
q *bun.InsertQuery
|
|
model *T
|
|
models []*T
|
|
isBulk bool
|
|
audit *AuditMeta
|
|
auditInfo *AuditInfo
|
|
}
|
|
|
|
// Insert creates an inserter for a single model
|
|
// The model will have all fields populated after Exec() via Returning("*")
|
|
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
|
|
if model == nil {
|
|
panic("model cannot be nil")
|
|
}
|
|
return &inserter[T]{
|
|
tx: tx,
|
|
q: tx.NewInsert().Model(model).Returning("*"),
|
|
model: model,
|
|
isBulk: false,
|
|
}
|
|
}
|
|
|
|
// InsertMultiple creates an inserter for bulk insert
|
|
// All models will have fields populated after Exec() via Returning("*")
|
|
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
|
|
if len(models) == 0 {
|
|
panic("models cannot be nil or empty")
|
|
}
|
|
return &inserter[T]{
|
|
tx: tx,
|
|
q: tx.NewInsert().Model(&models).Returning("*"),
|
|
models: models,
|
|
isBulk: true,
|
|
}
|
|
}
|
|
|
|
func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] {
|
|
fieldstr := strings.Join(conflicts, ", ")
|
|
i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", fieldstr))
|
|
return i
|
|
}
|
|
|
|
func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] {
|
|
fieldstr := strings.Join(conflicts, ", ")
|
|
i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", fieldstr))
|
|
for _, column := range columns {
|
|
i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", column, column))
|
|
}
|
|
return i
|
|
}
|
|
|
|
// Returning overrides the default Returning("*") clause
|
|
// Example: .Returning("id", "created_at")
|
|
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
|
|
if len(columns) == 0 {
|
|
return i
|
|
}
|
|
// Build column list as single string
|
|
columnList := strings.Join(columns, ", ")
|
|
i.q = i.q.Returning(columnList)
|
|
return i
|
|
}
|
|
|
|
// WithAudit enables audit logging for this insert operation
|
|
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
|
|
func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] {
|
|
i.audit = meta
|
|
i.auditInfo = info
|
|
return i
|
|
}
|
|
|
|
// Exec executes the insert and optionally logs to audit
|
|
// Returns an error if insert fails or if audit callback fails (triggering rollback)
|
|
func (i *inserter[T]) Exec(ctx context.Context) error {
|
|
// Execute insert
|
|
_, err := i.q.Exec(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "bun.InsertQuery.Exec")
|
|
}
|
|
|
|
// Handle audit logging if enabled
|
|
if i.audit != nil {
|
|
if i.auditInfo == nil {
|
|
tableName := extractTableName[T]()
|
|
resourceType := extractResourceType(tableName)
|
|
action := buildAction(resourceType, "create")
|
|
i.auditInfo = &AuditInfo{
|
|
Action: action,
|
|
ResourceType: resourceType,
|
|
ResourceID: nil,
|
|
Details: nil,
|
|
}
|
|
if i.isBulk {
|
|
i.auditInfo.Details = map[string]any{
|
|
"count": len(i.models),
|
|
}
|
|
} else {
|
|
i.auditInfo.ResourceID = extractPrimaryKey(i.model)
|
|
i.auditInfo.Details = i.model
|
|
}
|
|
}
|
|
|
|
err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo)
|
|
if err != nil {
|
|
return errors.Wrap(err, "LogSuccess")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|