Files
oslstats/internal/db/insert.go

124 lines
3.1 KiB
Go

package db
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// inserter is a fluent builder around a bun INSERT query that runs inside a
// transaction. Create one with Insert or InsertMultiple, optionally configure
// it with ConflictNothing, ConflictUpdate, Returning and WithAudit, then run it
// with Exec.
type inserter[T any] struct {
	tx bun.Tx
	q  *bun.InsertQuery
	// model is the row being inserted by Insert (nil for bulk inserts).
	model *T
	// models are the rows being inserted by InsertMultiple (nil for single inserts).
	models []*T
	isBulk bool

	// returning is the RETURNING column list. It is attached to q only once,
	// in Exec, because bun's InsertQuery.Returning appends to any clause added
	// earlier. Deferring it lets the Returning method replace the "*" default
	// instead of producing "RETURNING *, col, ...".
	returning        string
	returningApplied bool

	// audit enables audit logging when non-nil.
	audit *AuditMeta
	// auditInfo is the caller-supplied audit entry; nil means "derive it from T".
	auditInfo *AuditInfo
}

// Insert creates an inserter for a single model.
// Unless Returning is used to narrow it, the query uses RETURNING * so the
// model's fields are populated from the inserted row after Exec.
// It panics if model is nil.
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
	if model == nil {
		panic("model cannot be nil")
	}
	return &inserter[T]{
		tx:        tx,
		q:         tx.NewInsert().Model(model),
		model:     model,
		returning: "*",
	}
}

// InsertMultiple creates an inserter that inserts all models in one statement.
// Unless Returning is used to narrow it, the query uses RETURNING * so every
// model's fields are populated from the inserted rows after Exec.
// It panics if models is nil or empty.
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
	if len(models) == 0 {
		panic("models cannot be nil or empty")
	}
	return &inserter[T]{
		tx:        tx,
		q:         tx.NewInsert().Model(&models),
		models:    models,
		isBulk:    true,
		returning: "*",
	}
}

// ConflictNothing adds ON CONFLICT (conflicts...) DO NOTHING, so rows that
// conflict on the given columns are skipped instead of failing the insert.
// With no columns it emits a bare ON CONFLICT DO NOTHING; an empty target
// list "()" is not valid SQL.
func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] {
	if len(conflicts) == 0 {
		i.q = i.q.On("CONFLICT DO NOTHING")
		return i
	}
	i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", strings.Join(conflicts, ", ")))
	return i
}

// ConflictUpdate turns the insert into an upsert: it adds
// ON CONFLICT (conflicts...) DO UPDATE and, for every listed column,
// SET column = EXCLUDED.column, so those columns take the incoming values.
// conflicts should name at least one column (PostgreSQL requires a conflict
// target for DO UPDATE). When columns is empty no SET is added here;
// presumably bun then fills in the SET list itself — verify against the bun
// version in use.
func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] {
	i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", strings.Join(conflicts, ", ")))
	for _, col := range columns {
		i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", col, col))
	}
	return i
}

// Returning replaces the default RETURNING * clause with the given columns.
// Example: .Returning("id", "created_at")
// Calling it with no columns leaves the current list unchanged. The clause is
// attached when Exec first runs, so the last call before Exec wins; calls made
// after Exec have no effect.
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
	if len(columns) == 0 {
		return i
	}
	i.returning = strings.Join(columns, ", ")
	return i
}

// WithAudit enables audit logging for this insert.
// meta == nil leaves auditing disabled. If info is nil, the audit entry is
// derived automatically from T and the inserted model(s) when Exec runs.
func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] {
	i.audit = meta
	i.auditInfo = info
	return i
}

// Exec runs the insert and, if auditing is enabled, records an audit entry via
// LogSuccess using the same transaction as the insert.
// It returns a wrapped error if either the insert or the audit write fails, so
// the caller can roll back the transaction.
func (i *inserter[T]) Exec(ctx context.Context) error {
	// Attach RETURNING exactly once so a repeated Exec does not stack clauses.
	if !i.returningApplied {
		i.q = i.q.Returning(i.returning)
		i.returningApplied = true
	}

	if _, err := i.q.Exec(ctx); err != nil {
		return errors.Wrap(err, "bun.InsertQuery.Exec")
	}

	if i.audit == nil {
		return nil
	}

	info := i.auditInfo
	if info == nil {
		info = i.defaultAuditInfo()
	}
	if err := LogSuccess(ctx, i.tx, i.audit, info); err != nil {
		return errors.Wrap(err, "LogSuccess")
	}
	return nil
}

// defaultAuditInfo builds a "create" audit entry for T's resource type. Single
// inserts record the model's primary key and the model itself as details; bulk
// inserts record only the number of rows inserted.
func (i *inserter[T]) defaultAuditInfo() *AuditInfo {
	resourceType := extractResourceType(extractTableName[T]())
	info := &AuditInfo{
		Action:       buildAction(resourceType, "create"),
		ResourceType: resourceType,
	}
	if i.isBulk {
		info.Details = map[string]any{
			"count": len(i.models),
		}
		return info
	}
	info.ResourceID = extractPrimaryKey(i.model)
	info.Details = i.model
	return info
}