fixed some migration issues and added generics for update and insert
This commit is contained in:
146
internal/db/audit.go
Normal file
146
internal/db/audit.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// AuditInfo contains metadata for audit logging
|
||||
type AuditInfo struct {
|
||||
Action string // e.g., "seasons.create", "users.update"
|
||||
ResourceType string // e.g., "season", "user"
|
||||
ResourceID any // Primary key value (int, string, etc.)
|
||||
Details map[string]any // Changed fields or additional metadata
|
||||
}
|
||||
|
||||
// AuditCallback is called after successful database operations to log changes
|
||||
type AuditCallback func(ctx context.Context, tx bun.Tx, info *AuditInfo, r *http.Request) error
|
||||
|
||||
// extractTableName gets the bun table name from a model type using reflection
|
||||
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
|
||||
func extractTableName[T any]() string {
|
||||
var model T
|
||||
t := reflect.TypeOf(model)
|
||||
|
||||
// Handle pointer types
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
// Look for bun.BaseModel field with table tag
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if field.Type.Name() == "BaseModel" {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
// Parse tag: "table:seasons,alias:s" -> "seasons"
|
||||
parts := strings.Split(bunTag, ",")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(part, "table:") {
|
||||
return strings.TrimPrefix(part, "table:")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use struct name in lowercase + "s"
|
||||
return strings.ToLower(t.Name()) + "s"
|
||||
}
|
||||
|
||||
// extractResourceType converts a table name to singular resource type
|
||||
// Example: "seasons" -> "season", "users" -> "user"
|
||||
func extractResourceType(tableName string) string {
|
||||
// Simple singularization: remove trailing 's'
|
||||
if strings.HasSuffix(tableName, "s") && len(tableName) > 1 {
|
||||
return tableName[:len(tableName)-1]
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
// buildAction creates a permission-style action string
|
||||
// Example: ("season", "create") -> "seasons.create"
|
||||
func buildAction(resourceType, operation string) string {
|
||||
// Pluralize resource type (simple: add 's')
|
||||
plural := resourceType
|
||||
if !strings.HasSuffix(plural, "s") {
|
||||
plural = plural + "s"
|
||||
}
|
||||
return plural + "." + operation
|
||||
}
|
||||
|
||||
// extractPrimaryKey uses reflection to find and return the primary key value from a model
|
||||
// Returns nil if no primary key is found
|
||||
func extractPrimaryKey[T any](model *T) any {
|
||||
if model == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(model)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && strings.Contains(bunTag, "pk") {
|
||||
// Found primary key field
|
||||
fieldValue := v.Field(i)
|
||||
if fieldValue.IsValid() && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractChangedFields builds a map of field names to their new values
|
||||
// Only includes fields specified in the columns list
|
||||
func extractChangedFields[T any](model *T, columns []string) map[string]any {
|
||||
if model == nil || len(columns) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
v := reflect.ValueOf(model)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
|
||||
// Build map of bun column names to field names
|
||||
columnToField := make(map[string]int)
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
// Parse bun tag to get column name (first part before comma)
|
||||
parts := strings.Split(bunTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
columnToField[parts[0]] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract values for requested columns
|
||||
for _, col := range columns {
|
||||
if fieldIdx, ok := columnToField[col]; ok {
|
||||
fieldValue := v.Field(fieldIdx)
|
||||
if fieldValue.IsValid() && fieldValue.CanInterface() {
|
||||
result[col] = fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Note: We don't need getTxFromQuery since we store the tx directly in our helper structs
|
||||
@@ -3,13 +3,18 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type deleter[T any] struct {
|
||||
q *bun.DeleteQuery
|
||||
tx bun.Tx
|
||||
q *bun.DeleteQuery
|
||||
resourceID any // Store ID before deletion for audit
|
||||
auditCallback AuditCallback
|
||||
auditRequest *http.Request
|
||||
}
|
||||
|
||||
type systemType interface {
|
||||
@@ -18,13 +23,27 @@ type systemType interface {
|
||||
|
||||
func DeleteItem[T any](tx bun.Tx) *deleter[T] {
|
||||
return &deleter[T]{
|
||||
tx.NewDelete().
|
||||
tx: tx,
|
||||
q: tx.NewDelete().
|
||||
Model((*T)(nil)),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
|
||||
d.q = d.q.Where(query, args...)
|
||||
// Try to capture resource ID from WHERE clause if it's a simple "id = ?" pattern
|
||||
if query == "id = ?" && len(args) > 0 {
|
||||
d.resourceID = args[0]
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// WithAudit enables audit logging for this delete operation
|
||||
// The callback will be invoked after successful deletion with auto-generated audit info
|
||||
// If the callback returns an error, the transaction will be rolled back
|
||||
func (d *deleter[T]) WithAudit(r *http.Request, callback AuditCallback) *deleter[T] {
|
||||
d.auditRequest = r
|
||||
d.auditCallback = callback
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -34,8 +53,29 @@ func (d *deleter[T]) Delete(ctx context.Context) error {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err, "bun.DeleteQuery.Exec")
|
||||
}
|
||||
return errors.Wrap(err, "bun.DeleteQuery.Exec")
|
||||
|
||||
// Handle audit logging if enabled
|
||||
if d.auditCallback != nil && d.auditRequest != nil {
|
||||
tableName := extractTableName[T]()
|
||||
resourceType := extractResourceType(tableName)
|
||||
action := buildAction(resourceType, "delete")
|
||||
|
||||
info := &AuditInfo{
|
||||
Action: action,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: d.resourceID,
|
||||
Details: nil, // Delete doesn't need details
|
||||
}
|
||||
|
||||
// Call audit callback - if it fails, return error to trigger rollback
|
||||
if err := d.auditCallback(ctx, d.tx, info, d.auditRequest); err != nil {
|
||||
return errors.Wrap(err, "audit.callback")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
|
||||
|
||||
128
internal/db/insert.go
Normal file
128
internal/db/insert.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type inserter[T any] struct {
|
||||
tx bun.Tx
|
||||
q *bun.InsertQuery
|
||||
model *T
|
||||
models []*T
|
||||
isBulk bool
|
||||
auditCallback AuditCallback
|
||||
auditRequest *http.Request
|
||||
}
|
||||
|
||||
// Insert creates an inserter for a single model
|
||||
// The model will have all fields populated after Exec() via Returning("*")
|
||||
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
|
||||
if model == nil {
|
||||
panic("model cannot be nil")
|
||||
}
|
||||
return &inserter[T]{
|
||||
tx: tx,
|
||||
q: tx.NewInsert().Model(model).Returning("*"),
|
||||
model: model,
|
||||
isBulk: false,
|
||||
}
|
||||
}
|
||||
|
||||
// InsertMultiple creates an inserter for bulk insert
|
||||
// All models will have fields populated after Exec() via Returning("*")
|
||||
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
|
||||
if len(models) == 0 {
|
||||
panic("models cannot be nil or empty")
|
||||
}
|
||||
return &inserter[T]{
|
||||
tx: tx,
|
||||
q: tx.NewInsert().Model(&models).Returning("*"),
|
||||
models: models,
|
||||
isBulk: true,
|
||||
}
|
||||
}
|
||||
|
||||
// OnConflict adds conflict handling for upserts
|
||||
// Example: .OnConflict("(discord_id) DO UPDATE")
|
||||
func (i *inserter[T]) OnConflict(query string) *inserter[T] {
|
||||
i.q = i.q.On(query)
|
||||
return i
|
||||
}
|
||||
|
||||
// Set adds a SET clause for upserts (use with OnConflict)
|
||||
// Example: .Set("access_token = EXCLUDED.access_token")
|
||||
func (i *inserter[T]) Set(query string, args ...any) *inserter[T] {
|
||||
i.q = i.q.Set(query, args...)
|
||||
return i
|
||||
}
|
||||
|
||||
// Returning overrides the default Returning("*") clause
|
||||
// Example: .Returning("id", "created_at")
|
||||
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
|
||||
if len(columns) == 0 {
|
||||
return i
|
||||
}
|
||||
// Build column list as single string
|
||||
columnList := strings.Join(columns, ", ")
|
||||
i.q = i.q.Returning(columnList)
|
||||
return i
|
||||
}
|
||||
|
||||
// WithAudit enables audit logging for this insert operation
|
||||
// The callback will be invoked after successful insert with auto-generated audit info
|
||||
// If the callback returns an error, the transaction will be rolled back
|
||||
func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] {
|
||||
i.auditRequest = r
|
||||
i.auditCallback = callback
|
||||
return i
|
||||
}
|
||||
|
||||
// Exec executes the insert and optionally logs to audit
|
||||
// Returns an error if insert fails or if audit callback fails (triggering rollback)
|
||||
func (i *inserter[T]) Exec(ctx context.Context) error {
|
||||
// Execute insert
|
||||
_, err := i.q.Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bun.InsertQuery.Exec")
|
||||
}
|
||||
|
||||
// Handle audit logging if enabled
|
||||
if i.auditCallback != nil && i.auditRequest != nil {
|
||||
tableName := extractTableName[T]()
|
||||
resourceType := extractResourceType(tableName)
|
||||
action := buildAction(resourceType, "create")
|
||||
|
||||
var info *AuditInfo
|
||||
if i.isBulk {
|
||||
// For bulk inserts, log once with count in details
|
||||
info = &AuditInfo{
|
||||
Action: action,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: nil,
|
||||
Details: map[string]any{
|
||||
"count": len(i.models),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// For single insert, log with resource ID
|
||||
info = &AuditInfo{
|
||||
Action: action,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: extractPrimaryKey(i.model),
|
||||
Details: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Call audit callback - if it fails, return error to trigger rollback
|
||||
if err := i.auditCallback(ctx, i.tx, info, i.auditRequest); err != nil {
|
||||
return errors.Wrap(err, "audit.callback")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -13,34 +13,22 @@ type Season struct {
|
||||
bun.BaseModel `bun:"table:seasons,alias:s"`
|
||||
|
||||
ID int `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name,unique"`
|
||||
ShortName string `bun:"short_name,unique"`
|
||||
Name string `bun:"name,unique,notnull"`
|
||||
ShortName string `bun:"short_name,unique,notnull"`
|
||||
StartDate time.Time `bun:"start_date,notnull"`
|
||||
EndDate bun.NullTime `bun:"end_date"`
|
||||
FinalsStartDate bun.NullTime `bun:"finals_start_date"`
|
||||
FinalsEndDate bun.NullTime `bun:"finals_end_date"`
|
||||
}
|
||||
|
||||
func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name cannot be empty")
|
||||
}
|
||||
if shortname == "" {
|
||||
return nil, errors.New("shortname cannot be empty")
|
||||
}
|
||||
// NewSeason returns a new season. It does not add it to the database
|
||||
func NewSeason(name, shortname string, start time.Time) *Season {
|
||||
season := &Season{
|
||||
Name: name,
|
||||
ShortName: strings.ToUpper(shortname),
|
||||
StartDate: start.Truncate(time.Hour * 24),
|
||||
}
|
||||
_, err := tx.NewInsert().
|
||||
Model(season).
|
||||
Returning("id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewInsert")
|
||||
}
|
||||
return season, nil
|
||||
return season
|
||||
}
|
||||
|
||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
|
||||
@@ -60,31 +48,16 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
|
||||
return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx)
|
||||
}
|
||||
|
||||
func UpdateSeason(ctx context.Context, tx bun.Tx, season *Season) error {
|
||||
if season == nil {
|
||||
return errors.New("season cannot be nil")
|
||||
// Update updates the season struct. It does not insert to the database
|
||||
func (s *Season) Update(start, end, finalsStart, finalsEnd time.Time) {
|
||||
s.StartDate = start.Truncate(time.Hour * 24)
|
||||
if !end.IsZero() {
|
||||
s.EndDate.Time = end.Truncate(time.Hour * 24)
|
||||
}
|
||||
if season.ID == 0 {
|
||||
return errors.New("season ID cannot be 0")
|
||||
if !finalsStart.IsZero() {
|
||||
s.FinalsStartDate.Time = finalsStart.Truncate(time.Hour * 24)
|
||||
}
|
||||
// Truncate dates to day precision
|
||||
season.StartDate = season.StartDate.Truncate(time.Hour * 24)
|
||||
if !season.EndDate.IsZero() {
|
||||
season.EndDate.Time = season.EndDate.Time.Truncate(time.Hour * 24)
|
||||
if !finalsEnd.IsZero() {
|
||||
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
|
||||
}
|
||||
if !season.FinalsStartDate.IsZero() {
|
||||
season.FinalsStartDate.Time = season.FinalsStartDate.Time.Truncate(time.Hour * 24)
|
||||
}
|
||||
if !season.FinalsEndDate.IsZero() {
|
||||
season.FinalsEndDate.Time = season.FinalsEndDate.Time.Truncate(time.Hour * 24)
|
||||
}
|
||||
_, err := tx.NewUpdate().
|
||||
Model(season).
|
||||
Column("start_date", "end_date", "finals_start_date", "finals_end_date").
|
||||
Where("id = ?", season.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.NewUpdate")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
115
internal/db/update.go
Normal file
115
internal/db/update.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type updater[T any] struct {
|
||||
tx bun.Tx
|
||||
q *bun.UpdateQuery
|
||||
model *T
|
||||
columns []string
|
||||
auditCallback AuditCallback
|
||||
auditRequest *http.Request
|
||||
}
|
||||
|
||||
// Update creates an updater for a model
|
||||
// You must specify which columns to update via .Column() or use .WherePK()
|
||||
func Update[T any](tx bun.Tx, model *T) *updater[T] {
|
||||
if model == nil {
|
||||
panic("model cannot be nil")
|
||||
}
|
||||
return &updater[T]{
|
||||
tx: tx,
|
||||
q: tx.NewUpdate().Model(model),
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateByID creates an updater with an ID where clause
|
||||
// You must still specify which columns to update via .Column()
|
||||
func UpdateByID[T any](tx bun.Tx, id int, model *T) *updater[T] {
|
||||
if id <= 0 {
|
||||
panic("id must be positive")
|
||||
}
|
||||
return Update(tx, model).Where("id = ?", id)
|
||||
}
|
||||
|
||||
// Column specifies which columns to update
|
||||
// Example: .Column("start_date", "end_date")
|
||||
func (u *updater[T]) Column(columns ...string) *updater[T] {
|
||||
u.columns = append(u.columns, columns...)
|
||||
u.q = u.q.Column(columns...)
|
||||
return u
|
||||
}
|
||||
|
||||
// Where adds a WHERE clause
|
||||
// Example: .Where("id = ?", 123)
|
||||
func (u *updater[T]) Where(query string, args ...any) *updater[T] {
|
||||
u.q = u.q.Where(query, args...)
|
||||
return u
|
||||
}
|
||||
|
||||
// WherePK adds a WHERE clause on the primary key
|
||||
// The model must have its primary key field populated
|
||||
func (u *updater[T]) WherePK() *updater[T] {
|
||||
u.q = u.q.WherePK()
|
||||
return u
|
||||
}
|
||||
|
||||
// Set adds a raw SET clause for complex updates
|
||||
// Example: .Set("updated_at = NOW()")
|
||||
func (u *updater[T]) Set(query string, args ...any) *updater[T] {
|
||||
u.q = u.q.Set(query, args...)
|
||||
return u
|
||||
}
|
||||
|
||||
// WithAudit enables audit logging for this update operation
|
||||
// The callback will be invoked after successful update with auto-generated audit info
|
||||
// If the callback returns an error, the transaction will be rolled back
|
||||
func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater[T] {
|
||||
u.auditRequest = r
|
||||
u.auditCallback = callback
|
||||
return u
|
||||
}
|
||||
|
||||
// Exec executes the update and optionally logs to audit
|
||||
// Returns an error if update fails or if audit callback fails (triggering rollback)
|
||||
func (u *updater[T]) Exec(ctx context.Context) error {
|
||||
// Build audit details BEFORE update (captures changed fields)
|
||||
var details map[string]any
|
||||
if u.auditCallback != nil && len(u.columns) > 0 {
|
||||
details = extractChangedFields(u.model, u.columns)
|
||||
}
|
||||
|
||||
// Execute update
|
||||
_, err := u.q.Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bun.UpdateQuery.Exec")
|
||||
}
|
||||
|
||||
// Handle audit logging if enabled
|
||||
if u.auditCallback != nil && u.auditRequest != nil {
|
||||
tableName := extractTableName[T]()
|
||||
resourceType := extractResourceType(tableName)
|
||||
action := buildAction(resourceType, "update")
|
||||
|
||||
info := &AuditInfo{
|
||||
Action: action,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: extractPrimaryKey(u.model),
|
||||
Details: details, // Changed fields only
|
||||
}
|
||||
|
||||
// Call audit callback - if it fails, return error to trigger rollback
|
||||
if err := u.auditCallback(ctx, u.tx, info, u.auditRequest); err != nil {
|
||||
return errors.Wrap(err, "audit.callback")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user