fixed some migration issues and added generics for update and insert

This commit is contained in:
2026-02-09 21:58:50 +11:00
parent bf3e526f1e
commit b89ee75ca7
19 changed files with 591 additions and 261 deletions

146
internal/db/audit.go Normal file
View File

@@ -0,0 +1,146 @@
package db
import (
"context"
"net/http"
"reflect"
"strings"
"github.com/uptrace/bun"
)
// AuditInfo contains metadata for audit logging
type AuditInfo struct {
Action string // e.g., "seasons.create", "users.update"
ResourceType string // e.g., "season", "user"
ResourceID any // Primary key value (int, string, etc.)
Details map[string]any // Changed fields or additional metadata
}
// AuditCallback is called after successful database operations to log changes
type AuditCallback func(ctx context.Context, tx bun.Tx, info *AuditInfo, r *http.Request) error
// extractTableName gets the bun table name from a model type using reflection
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
func extractTableName[T any]() string {
var model T
t := reflect.TypeOf(model)
// Handle pointer types
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
// Look for bun.BaseModel field with table tag
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Type.Name() == "BaseModel" {
bunTag := field.Tag.Get("bun")
if bunTag != "" {
// Parse tag: "table:seasons,alias:s" -> "seasons"
parts := strings.Split(bunTag, ",")
for _, part := range parts {
if strings.HasPrefix(part, "table:") {
return strings.TrimPrefix(part, "table:")
}
}
}
}
}
// Fallback: use struct name in lowercase + "s"
return strings.ToLower(t.Name()) + "s"
}
// extractResourceType converts a table name to singular resource type
// Example: "seasons" -> "season", "users" -> "user"
func extractResourceType(tableName string) string {
// Simple singularization: remove trailing 's'
if strings.HasSuffix(tableName, "s") && len(tableName) > 1 {
return tableName[:len(tableName)-1]
}
return tableName
}
// buildAction creates a permission-style action string
// Example: ("season", "create") -> "seasons.create"
func buildAction(resourceType, operation string) string {
// Pluralize resource type (simple: add 's')
plural := resourceType
if !strings.HasSuffix(plural, "s") {
plural = plural + "s"
}
return plural + "." + operation
}
// extractPrimaryKey uses reflection to find and return the primary key value from a model
// Returns nil if no primary key is found
func extractPrimaryKey[T any](model *T) any {
if model == nil {
return nil
}
v := reflect.ValueOf(model)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
bunTag := field.Tag.Get("bun")
if bunTag != "" && strings.Contains(bunTag, "pk") {
// Found primary key field
fieldValue := v.Field(i)
if fieldValue.IsValid() && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
}
return nil
}
// extractChangedFields builds a map of field names to their new values
// Only includes fields specified in the columns list
func extractChangedFields[T any](model *T, columns []string) map[string]any {
if model == nil || len(columns) == 0 {
return nil
}
result := make(map[string]any)
v := reflect.ValueOf(model)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
t := v.Type()
// Build map of bun column names to field names
columnToField := make(map[string]int)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
bunTag := field.Tag.Get("bun")
if bunTag != "" {
// Parse bun tag to get column name (first part before comma)
parts := strings.Split(bunTag, ",")
if len(parts) > 0 && parts[0] != "" {
columnToField[parts[0]] = i
}
}
}
// Extract values for requested columns
for _, col := range columns {
if fieldIdx, ok := columnToField[col]; ok {
fieldValue := v.Field(fieldIdx)
if fieldValue.IsValid() && fieldValue.CanInterface() {
result[col] = fieldValue.Interface()
}
}
}
return result
}
// Note: We don't need getTxFromQuery since we store the tx directly in our helper structs

View File

@@ -3,13 +3,18 @@ package db
import (
"context"
"database/sql"
"net/http"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type deleter[T any] struct {
q *bun.DeleteQuery
tx bun.Tx
q *bun.DeleteQuery
resourceID any // Store ID before deletion for audit
auditCallback AuditCallback
auditRequest *http.Request
}
type systemType interface {
@@ -18,13 +23,27 @@ type systemType interface {
func DeleteItem[T any](tx bun.Tx) *deleter[T] {
return &deleter[T]{
tx.NewDelete().
tx: tx,
q: tx.NewDelete().
Model((*T)(nil)),
}
}
func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
d.q = d.q.Where(query, args...)
// Try to capture resource ID from WHERE clause if it's a simple "id = ?" pattern
if query == "id = ?" && len(args) > 0 {
d.resourceID = args[0]
}
return d
}
// WithAudit enables audit logging for this delete operation
// The callback will be invoked after successful deletion with auto-generated audit info
// If the callback returns an error, the transaction will be rolled back
func (d *deleter[T]) WithAudit(r *http.Request, callback AuditCallback) *deleter[T] {
d.auditRequest = r
d.auditCallback = callback
return d
}
@@ -34,8 +53,29 @@ func (d *deleter[T]) Delete(ctx context.Context) error {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return errors.Wrap(err, "bun.DeleteQuery.Exec")
}
return errors.Wrap(err, "bun.DeleteQuery.Exec")
// Handle audit logging if enabled
if d.auditCallback != nil && d.auditRequest != nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "delete")
info := &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: d.resourceID,
Details: nil, // Delete doesn't need details
}
// Call audit callback - if it fails, return error to trigger rollback
if err := d.auditCallback(ctx, d.tx, info, d.auditRequest); err != nil {
return errors.Wrap(err, "audit.callback")
}
}
return nil
}
func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {

128
internal/db/insert.go Normal file
View File

@@ -0,0 +1,128 @@
package db
import (
"context"
"net/http"
"strings"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type inserter[T any] struct {
tx bun.Tx
q *bun.InsertQuery
model *T
models []*T
isBulk bool
auditCallback AuditCallback
auditRequest *http.Request
}
// Insert creates an inserter for a single model
// The model will have all fields populated after Exec() via Returning("*")
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
if model == nil {
panic("model cannot be nil")
}
return &inserter[T]{
tx: tx,
q: tx.NewInsert().Model(model).Returning("*"),
model: model,
isBulk: false,
}
}
// InsertMultiple creates an inserter for bulk insert
// All models will have fields populated after Exec() via Returning("*")
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
if len(models) == 0 {
panic("models cannot be nil or empty")
}
return &inserter[T]{
tx: tx,
q: tx.NewInsert().Model(&models).Returning("*"),
models: models,
isBulk: true,
}
}
// OnConflict adds conflict handling for upserts
// Example: .OnConflict("(discord_id) DO UPDATE")
func (i *inserter[T]) OnConflict(query string) *inserter[T] {
i.q = i.q.On(query)
return i
}
// Set adds a SET clause for upserts (use with OnConflict)
// Example: .Set("access_token = EXCLUDED.access_token")
func (i *inserter[T]) Set(query string, args ...any) *inserter[T] {
i.q = i.q.Set(query, args...)
return i
}
// Returning overrides the default Returning("*") clause
// Example: .Returning("id", "created_at")
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
if len(columns) == 0 {
return i
}
// Build column list as single string
columnList := strings.Join(columns, ", ")
i.q = i.q.Returning(columnList)
return i
}
// WithAudit enables audit logging for this insert operation
// The callback will be invoked after successful insert with auto-generated audit info
// If the callback returns an error, the transaction will be rolled back
func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] {
i.auditRequest = r
i.auditCallback = callback
return i
}
// Exec executes the insert and optionally logs to audit
// Returns an error if insert fails or if audit callback fails (triggering rollback)
func (i *inserter[T]) Exec(ctx context.Context) error {
// Execute insert
_, err := i.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.InsertQuery.Exec")
}
// Handle audit logging if enabled
if i.auditCallback != nil && i.auditRequest != nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "create")
var info *AuditInfo
if i.isBulk {
// For bulk inserts, log once with count in details
info = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: nil,
Details: map[string]any{
"count": len(i.models),
},
}
} else {
// For single insert, log with resource ID
info = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: extractPrimaryKey(i.model),
Details: nil,
}
}
// Call audit callback - if it fails, return error to trigger rollback
if err := i.auditCallback(ctx, i.tx, info, i.auditRequest); err != nil {
return errors.Wrap(err, "audit.callback")
}
}
return nil
}

View File

@@ -13,34 +13,22 @@ type Season struct {
bun.BaseModel `bun:"table:seasons,alias:s"`
ID int `bun:"id,pk,autoincrement"`
Name string `bun:"name,unique"`
ShortName string `bun:"short_name,unique"`
Name string `bun:"name,unique,notnull"`
ShortName string `bun:"short_name,unique,notnull"`
StartDate time.Time `bun:"start_date,notnull"`
EndDate bun.NullTime `bun:"end_date"`
FinalsStartDate bun.NullTime `bun:"finals_start_date"`
FinalsEndDate bun.NullTime `bun:"finals_end_date"`
}
func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
if shortname == "" {
return nil, errors.New("shortname cannot be empty")
}
// NewSeason returns a new season. It does not add it to the database
func NewSeason(name, shortname string, start time.Time) *Season {
season := &Season{
Name: name,
ShortName: strings.ToUpper(shortname),
StartDate: start.Truncate(time.Hour * 24),
}
_, err := tx.NewInsert().
Model(season).
Returning("id").
Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewInsert")
}
return season, nil
return season
}
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
@@ -60,31 +48,16 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx)
}
func UpdateSeason(ctx context.Context, tx bun.Tx, season *Season) error {
if season == nil {
return errors.New("season cannot be nil")
// Update updates the season struct. It does not insert to the database
func (s *Season) Update(start, end, finalsStart, finalsEnd time.Time) {
s.StartDate = start.Truncate(time.Hour * 24)
if !end.IsZero() {
s.EndDate.Time = end.Truncate(time.Hour * 24)
}
if season.ID == 0 {
return errors.New("season ID cannot be 0")
if !finalsStart.IsZero() {
s.FinalsStartDate.Time = finalsStart.Truncate(time.Hour * 24)
}
// Truncate dates to day precision
season.StartDate = season.StartDate.Truncate(time.Hour * 24)
if !season.EndDate.IsZero() {
season.EndDate.Time = season.EndDate.Time.Truncate(time.Hour * 24)
if !finalsEnd.IsZero() {
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
}
if !season.FinalsStartDate.IsZero() {
season.FinalsStartDate.Time = season.FinalsStartDate.Time.Truncate(time.Hour * 24)
}
if !season.FinalsEndDate.IsZero() {
season.FinalsEndDate.Time = season.FinalsEndDate.Time.Truncate(time.Hour * 24)
}
_, err := tx.NewUpdate().
Model(season).
Column("start_date", "end_date", "finals_start_date", "finals_end_date").
Where("id = ?", season.ID).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewUpdate")
}
return nil
}

115
internal/db/update.go Normal file
View File

@@ -0,0 +1,115 @@
package db
import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type updater[T any] struct {
tx bun.Tx
q *bun.UpdateQuery
model *T
columns []string
auditCallback AuditCallback
auditRequest *http.Request
}
// Update creates an updater for a model
// You must specify which columns to update via .Column() or use .WherePK()
func Update[T any](tx bun.Tx, model *T) *updater[T] {
if model == nil {
panic("model cannot be nil")
}
return &updater[T]{
tx: tx,
q: tx.NewUpdate().Model(model),
model: model,
}
}
// UpdateByID creates an updater with an ID where clause
// You must still specify which columns to update via .Column()
func UpdateByID[T any](tx bun.Tx, id int, model *T) *updater[T] {
if id <= 0 {
panic("id must be positive")
}
return Update(tx, model).Where("id = ?", id)
}
// Column specifies which columns to update
// Example: .Column("start_date", "end_date")
func (u *updater[T]) Column(columns ...string) *updater[T] {
u.columns = append(u.columns, columns...)
u.q = u.q.Column(columns...)
return u
}
// Where adds a WHERE clause
// Example: .Where("id = ?", 123)
func (u *updater[T]) Where(query string, args ...any) *updater[T] {
u.q = u.q.Where(query, args...)
return u
}
// WherePK adds a WHERE clause on the primary key
// The model must have its primary key field populated
func (u *updater[T]) WherePK() *updater[T] {
u.q = u.q.WherePK()
return u
}
// Set adds a raw SET clause for complex updates
// Example: .Set("updated_at = NOW()")
func (u *updater[T]) Set(query string, args ...any) *updater[T] {
u.q = u.q.Set(query, args...)
return u
}
// WithAudit enables audit logging for this update operation
// The callback will be invoked after successful update with auto-generated audit info
// If the callback returns an error, the transaction will be rolled back
func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater[T] {
u.auditRequest = r
u.auditCallback = callback
return u
}
// Exec executes the update and optionally logs to audit
// Returns an error if update fails or if audit callback fails (triggering rollback)
func (u *updater[T]) Exec(ctx context.Context) error {
// Build audit details BEFORE update (captures changed fields)
var details map[string]any
if u.auditCallback != nil && len(u.columns) > 0 {
details = extractChangedFields(u.model, u.columns)
}
// Execute update
_, err := u.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.UpdateQuery.Exec")
}
// Handle audit logging if enabled
if u.auditCallback != nil && u.auditRequest != nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "update")
info := &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: extractPrimaryKey(u.model),
Details: details, // Changed fields only
}
// Call audit callback - if it fails, return error to trigger rollback
if err := u.auditCallback(ctx, u.tx, info, u.auditRequest); err != nil {
return errors.Wrap(err, "audit.callback")
}
}
return nil
}