diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go deleted file mode 100644 index 0981b7d..0000000 --- a/cmd/oslstats/db.go +++ /dev/null @@ -1,48 +0,0 @@ -package main - -import ( - "database/sql" - "fmt" - "time" - - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/db" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect/pgdialect" - "github.com/uptrace/bun/driver/pgdriver" -) - -func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) { - dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", - cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL) - sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) - - sqldb.SetMaxOpenConns(25) - sqldb.SetMaxIdleConns(10) - sqldb.SetConnMaxLifetime(5 * time.Minute) - sqldb.SetConnMaxIdleTime(5 * time.Minute) - - conn = bun.NewDB(sqldb, pgdialect.New()) - registerDBModels(conn) - close = sqldb.Close - return conn, close -} - -func registerDBModels(conn *bun.DB) []any { - models := []any{ - (*db.RolePermission)(nil), - (*db.UserRole)(nil), - (*db.SeasonLeague)(nil), - (*db.TeamParticipation)(nil), - (*db.User)(nil), - (*db.DiscordToken)(nil), - (*db.Season)(nil), - (*db.League)(nil), - (*db.Team)(nil), - (*db.Role)(nil), - (*db.Permission)(nil), - (*db.AuditLog)(nil), - } - conn.RegisterModel(models...) - return models -} diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index 342ad08..4bdeec8 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -7,6 +7,7 @@ import ( "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db/migrate" "github.com/pkg/errors" ) @@ -48,7 +49,7 @@ func main() { // Handle migration file creation (doesn't need DB connection) if flags.MigrateCreate != "" { - if err := createMigration(flags.MigrateCreate); err != nil { + if err := migrate.CreateMigration(flags.MigrateCreate); err != nil { logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration") } return @@ -59,17 +60,21 @@ func main() { flags.MigrateStatus || flags.MigrateDryRun || flags.ResetDB { + var command, countStr string // Route to appropriate command if flags.MigrateUp != "" { - err = runMigrations(ctx, cfg, "up", flags.MigrateUp) + command = "up" + countStr = flags.MigrateUp } else if flags.MigrateRollback != "" { - err = runMigrations(ctx, cfg, "rollback", flags.MigrateRollback) + command = "rollback" + countStr = flags.MigrateRollback } else if flags.MigrateStatus { - err = runMigrations(ctx, cfg, "status", "") - } else if flags.MigrateDryRun { - err = runMigrations(ctx, cfg, "dry-run", "") - } else if flags.ResetDB { - err = resetDatabase(ctx, cfg) + command = "status" + } + if flags.ResetDB { + err = migrate.ResetDatabase(ctx, cfg) + } else { + err = migrate.RunMigrations(ctx, cfg, command, countStr) } if err != nil { diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 030652e..a400a80 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -12,8 +12,10 @@ import ( "github.com/pkg/errors" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/embedfs" + "git.haelnorr.com/h/oslstats/internal/server" "git.haelnorr.com/h/oslstats/internal/store" ) @@ -25,8 +27,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error { // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - bun, closedb := setupBun(cfg) - // registerDBModels(bun) + conn := db.NewDB(cfg.DB) // Setup embedded files logger.Debug().Msg("Getting embedded files") @@ -47,7 +48,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error { } logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHTTPServer(&staticFS, cfg, logger, bun, store, discordAPI) + httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI) if err != nil { return errors.Wrap(err, "setupHttpServer") } @@ -71,7 +72,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error { if err != nil { logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown") } - err = closedb() + err = conn.Close() if err != nil { logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close") } diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go deleted file mode 100644 index 5b6abca..0000000 --- a/internal/auditlog/logger.go +++ /dev/null @@ -1,187 +0,0 @@ -// Package auditlog provides a system for logging events that require permissions to the audit log -package auditlog - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "git.haelnorr.com/h/oslstats/internal/db" - "github.com/pkg/errors" - "github.com/uptrace/bun" -) - -type Logger struct { - conn *bun.DB -} - -func NewLogger(conn *bun.DB) *Logger { - return &Logger{conn: conn} -} - -// LogSuccess logs a successful permission-protected action -func (l *Logger) LogSuccess( - ctx context.Context, - tx bun.Tx, - user *db.User, - action string, - resourceType string, - resourceID any, // Can be int, string, or nil - details map[string]any, - r *http.Request, -) error { - return l.log(ctx, tx, user, action, resourceType, resourceID, details, "success", nil, r) -} - -// LogError logs a failed action due to an error -func (l *Logger) LogError( - ctx context.Context, - tx bun.Tx, - user *db.User, - action string, - resourceType string, - resourceID any, - err error, - r *http.Request, -) error { - errMsg := err.Error() - return l.log(ctx, tx, user, action, resourceType, resourceID, nil, "error", &errMsg, r) -} - -func (l *Logger) log( - ctx context.Context, - tx bun.Tx, - user *db.User, - action string, - resourceType string, - resourceID any, - details map[string]any, - result string, - errorMessage *string, - r *http.Request, -) error { - if user == nil { - return errors.New("user cannot be nil for audit logging") - } - - // Convert resourceID to string - var resourceIDStr *string - if resourceID != nil { - idStr := fmt.Sprintf("%v", resourceID) - resourceIDStr = &idStr - } - - // Marshal details to JSON - var detailsJSON json.RawMessage - if details != nil { - jsonBytes, err := json.Marshal(details) - if err != nil { - return errors.Wrap(err, "json.Marshal details") - } - detailsJSON = jsonBytes - } - - // Extract IP and User-Agent from request - ipAddress := r.RemoteAddr - userAgent := r.UserAgent() - - log := &db.AuditLog{ - UserID: user.ID, - Action: action, - ResourceType: resourceType, - ResourceID: resourceIDStr, - Details: detailsJSON, - IPAddress: ipAddress, - UserAgent: userAgent, - Result: result, - ErrorMessage: errorMessage, - CreatedAt: time.Now().Unix(), - } - - return db.CreateAuditLog(ctx, tx, log) -} - -// GetRecentLogs retrieves recent audit logs with pagination -func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) { - var logs *db.List[db.AuditLog] - if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { - var err error - logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil) - if err != nil { - return errors.Wrap(err, "db.GetAuditLogs") - } - return nil - }); err != nil { - return nil, errors.Wrap(err, "db.WithTxFailSilently") - } - return logs, nil -} - -// GetLogsByUser retrieves audit logs for a specific user -func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) { - var logs *db.List[db.AuditLog] - if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { - var err error - logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) - if err != nil { - return errors.Wrap(err, "db.GetAuditLogsByUser") - } - return nil - }); err != nil { - return nil, errors.Wrap(err, "db.WithTxFailSilently") - } - return logs, nil -} - -// CleanupOldLogs deletes audit logs older than the specified number of days -func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error) { - if daysToKeep <= 0 { - return 0, errors.New("daysToKeep must be positive") - } - - cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix() - - var count int - if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { - var err error - count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime) - if err != nil { - return errors.Wrap(err, "db.CleanupOldAuditLogs") - } - return nil - }); err != nil { - return 0, errors.Wrap(err, "db.WithTxFailSilently") - } - return count, nil -} - -// Callback returns a db.AuditCallback that logs to this Logger -// This is used with the generic database helpers (Insert, Update, Delete) -// -// Usage: -// -// audit := auditlog.NewLogger(conn) -// err := db.Insert(tx, season). -// WithAudit(r, audit.Callback()). -// Exec(ctx) -func (l *Logger) Callback() db.AuditCallback { - return func(ctx context.Context, tx bun.Tx, info *db.AuditInfo, r *http.Request) error { - user := db.CurrentUser(ctx) - if user == nil { - return errors.New("no user in context for audit logging") - } - - return l.LogSuccess( - ctx, - tx, - user, - info.Action, - info.ResourceType, - info.ResourceID, - info.Details, - r, - ) - } -} diff --git a/internal/db/audit.go b/internal/db/audit.go index bb89e8b..d7013e7 100644 --- a/internal/db/audit.go +++ b/internal/db/audit.go @@ -1,14 +1,23 @@ package db import ( - "context" "net/http" "reflect" "strings" - - "github.com/uptrace/bun" ) +type AuditMeta struct { + r *http.Request + u *User +} + +func NewAudit(r *http.Request, u *User) *AuditMeta { + if u == nil { + u = CurrentUser(r.Context()) + } + return &AuditMeta{r, u} +} + // AuditInfo contains metadata for audit logging type AuditInfo struct { Action string // e.g., "seasons.create", "users.update" @@ -17,9 +26,6 @@ type AuditInfo struct { Details map[string]any // Changed fields or additional metadata } -// AuditCallback is called after successful database operations to log changes -type AuditCallback func(ctx context.Context, tx bun.Tx, info *AuditInfo, r *http.Request) error - // extractTableName gets the bun table name from a model type using reflection // Example: Season with `bun:"table:seasons,alias:s"` returns "seasons" func extractTableName[T any]() string { @@ -27,7 +33,7 @@ func extractTableName[T any]() string { t := reflect.TypeOf(model) // Handle pointer types - if t.Kind() == reflect.Ptr { + if t.Kind() == reflect.Pointer { t = t.Elem() } @@ -38,11 +44,9 @@ func extractTableName[T any]() string { bunTag := field.Tag.Get("bun") if bunTag != "" { // Parse tag: "table:seasons,alias:s" -> "seasons" - parts := strings.Split(bunTag, ",") - for _, part := range parts { - if strings.HasPrefix(part, "table:") { - return strings.TrimPrefix(part, "table:") - } + for part := range strings.SplitSeq(bunTag, ",") { + part, _ := strings.CutPrefix(part, "table:") + return part } } } @@ -81,7 +85,7 @@ func extractPrimaryKey[T any](model *T) any { } v := reflect.ValueOf(model) - if v.Kind() == reflect.Ptr { + if v.Kind() == reflect.Pointer { v = v.Elem() } @@ -110,7 +114,7 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any { result := make(map[string]any) v := reflect.ValueOf(model) - if v.Kind() == reflect.Ptr { + if v.Kind() == reflect.Pointer { v = v.Elem() } @@ -142,5 +146,3 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any { return result } - -// Note: We don't need getTxFromQuery since we store the tx directly in our helper structs diff --git a/internal/db/auditlogger.go b/internal/db/auditlogger.go new file mode 100644 index 0000000..988d4d6 --- /dev/null +++ b/internal/db/auditlogger.go @@ -0,0 +1,91 @@ +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +// LogSuccess logs a successful permission-protected action +func LogSuccess( + ctx context.Context, + tx bun.Tx, + meta *AuditMeta, + info *AuditInfo, +) error { + return log(ctx, tx, meta, info, "success", nil) +} + +// LogError logs a failed action due to an error +func LogError( + ctx context.Context, + tx bun.Tx, + meta *AuditMeta, + info *AuditInfo, + err error, +) error { + errMsg := err.Error() + return log(ctx, tx, meta, info, "error", &errMsg) +} + +func log( + ctx context.Context, + tx bun.Tx, + meta *AuditMeta, + info *AuditInfo, + result string, + errorMessage *string, +) error { + if meta == nil { + return errors.New("audit meta cannot be nil for audit logging") + } + if info == nil { + return errors.New("audit info cannot be nil for audit logging") + } + if meta.u == nil { + return errors.New("user cannot be nil for audit logging") + } + if meta.r == nil { + return errors.New("request cannot be nil for audit logging") + } + + // Convert resourceID to string + var resourceIDStr *string + if info.ResourceID != nil { + idStr := fmt.Sprintf("%v", info.ResourceID) + resourceIDStr = &idStr + } + + // Marshal details to JSON + var detailsJSON json.RawMessage + if info.Details != nil { + jsonBytes, err := json.Marshal(info.Details) + if err != nil { + return errors.Wrap(err, "json.Marshal details") + } + detailsJSON = jsonBytes + } + + // Extract IP and User-Agent from request + ipAddress := meta.r.RemoteAddr + userAgent := meta.r.UserAgent() + + log := &AuditLog{ + UserID: meta.u.ID, + Action: info.Action, + ResourceType: info.ResourceType, + ResourceID: resourceIDStr, + Details: detailsJSON, + IPAddress: ipAddress, + UserAgent: userAgent, + Result: result, + ErrorMessage: errorMessage, + CreatedAt: time.Now().Unix(), + } + + return CreateAuditLog(ctx, tx, log) +} diff --git a/internal/backup/backup.go b/internal/db/backup.go similarity index 80% rename from internal/backup/backup.go rename to internal/db/backup.go index 00b8278..2bfe847 100644 --- a/internal/backup/backup.go +++ b/internal/db/backup.go @@ -1,4 +1,4 @@ -package backup +package db import ( "context" @@ -9,14 +9,13 @@ import ( "sort" "time" - "git.haelnorr.com/h/oslstats/internal/config" "github.com/pkg/errors" ) // CreateBackup creates a compressed PostgreSQL dump before migrations // Returns backup filename and error // If pg_dump is not available, returns nil error with warning -func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (string, error) { +func CreateBackup(ctx context.Context, cfg *Config, operation string) (string, error) { // Check if pg_dump is available if _, err := exec.LookPath("pg_dump"); err != nil { fmt.Println("[WARN] pg_dump not found - skipping backup") @@ -28,13 +27,13 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st } // Ensure backup directory exists - if err := os.MkdirAll(cfg.DB.BackupDir, 0755); err != nil { + if err := os.MkdirAll(cfg.BackupDir, 0o755); err != nil { return "", errors.Wrap(err, "failed to create backup directory") } // Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz timestamp := time.Now().Format("20060102_150405") - filename := filepath.Join(cfg.DB.BackupDir, + filename := filepath.Join(cfg.BackupDir, fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation)) // Check if gzip is available @@ -42,7 +41,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st if _, err := exec.LookPath("gzip"); err != nil { fmt.Println("[WARN] gzip not found - using uncompressed backup") useGzip = false - filename = filepath.Join(cfg.DB.BackupDir, + filename = filepath.Join(cfg.BackupDir, fmt.Sprintf("%s_pre_%s.sql", timestamp, operation)) } @@ -52,19 +51,19 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st // Use shell to pipe pg_dump through gzip pgDumpCmd := fmt.Sprintf( "pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s", - cfg.DB.Host, - cfg.DB.Port, - cfg.DB.User, - cfg.DB.DB, + cfg.Host, + cfg.Port, + cfg.User, + cfg.DB, filename, ) cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd) } else { cmd = exec.CommandContext(ctx, "pg_dump", - "-h", cfg.DB.Host, - "-p", fmt.Sprint(cfg.DB.Port), - "-U", cfg.DB.User, - "-d", cfg.DB.DB, + "-h", cfg.Host, + "-p", fmt.Sprint(cfg.Port), + "-U", cfg.User, + "-d", cfg.DB, "-f", filename, "--no-owner", "--no-acl", @@ -75,7 +74,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st // Set password via environment variable cmd.Env = append(os.Environ(), - fmt.Sprintf("PGPASSWORD=%s", cfg.DB.Password)) + fmt.Sprintf("PGPASSWORD=%s", cfg.Password)) // Run backup if err := cmd.Run(); err != nil { @@ -95,14 +94,14 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st } // CleanOldBackups keeps only the N most recent backups -func CleanOldBackups(cfg *config.Config, keepCount int) error { +func CleanOldBackups(cfg *Config, keepCount int) error { // Get all backup files (both .sql and .sql.gz) - sqlFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql")) + sqlFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql")) if err != nil { return errors.Wrap(err, "failed to list .sql backups") } - gzFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql.gz")) + gzFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql.gz")) if err != nil { return errors.Wrap(err, "failed to list .sql.gz backups") } diff --git a/internal/db/delete.go b/internal/db/delete.go index ac2607f..07753b5 100644 --- a/internal/db/delete.go +++ b/internal/db/delete.go @@ -3,18 +3,17 @@ package db import ( "context" "database/sql" - "net/http" "github.com/pkg/errors" "github.com/uptrace/bun" ) type deleter[T any] struct { - tx bun.Tx - q *bun.DeleteQuery - resourceID any // Store ID before deletion for audit - auditCallback AuditCallback - auditRequest *http.Request + tx bun.Tx + q *bun.DeleteQuery + resourceID any // Store ID before deletion for audit + audit *AuditMeta + auditInfo *AuditInfo } type systemType interface { @@ -39,11 +38,10 @@ func (d *deleter[T]) Where(query string, args ...any) *deleter[T] { } // WithAudit enables audit logging for this delete operation -// The callback will be invoked after successful deletion with auto-generated audit info -// If the callback returns an error, the transaction will be rolled back -func (d *deleter[T]) WithAudit(r *http.Request, callback AuditCallback) *deleter[T] { - d.auditRequest = r - d.auditCallback = callback +// If the provided *AuditInfo is nil, will use reflection to automatically work out the details +func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] { + d.audit = meta + d.auditInfo = info return d } @@ -57,21 +55,23 @@ func (d *deleter[T]) Delete(ctx context.Context) error { } // Handle audit logging if enabled - if d.auditCallback != nil && d.auditRequest != nil { - tableName := extractTableName[T]() - resourceType := extractResourceType(tableName) - action := buildAction(resourceType, "delete") + if d.audit != nil { + if d.auditInfo == nil { + tableName := extractTableName[T]() + resourceType := extractResourceType(tableName) + action := buildAction(resourceType, "delete") - info := &AuditInfo{ - Action: action, - ResourceType: resourceType, - ResourceID: d.resourceID, - Details: nil, // Delete doesn't need details + d.auditInfo = &AuditInfo{ + Action: action, + ResourceType: resourceType, + ResourceID: d.resourceID, + Details: nil, // Delete doesn't need details + } } - // Call audit callback - if it fails, return error to trigger rollback - if err := d.auditCallback(ctx, d.tx, info, d.auditRequest); err != nil { - return errors.Wrap(err, "audit.callback") + err = LogSuccess(ctx, d.tx, d.audit, d.auditInfo) + if err != nil { + return errors.Wrap(err, "LogSuccess") } } @@ -82,7 +82,7 @@ func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] { return DeleteItem[T](tx).Where("id = ?", id) } -func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error { +func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error { deleter := DeleteByID[T](tx, id) item, err := GetByID[T](tx, id).Get(ctx) if err != nil { @@ -94,5 +94,8 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) if (*item).isSystem() { return errors.New("record is system protected") } + if audit != nil { + deleter = deleter.WithAudit(audit, nil) + } return deleter.Delete(ctx) } diff --git a/internal/db/insert.go b/internal/db/insert.go index a8082dd..30b4c45 100644 --- a/internal/db/insert.go +++ b/internal/db/insert.go @@ -3,7 +3,6 @@ package db import ( "context" "fmt" - "net/http" "strings" "github.com/pkg/errors" @@ -11,13 +10,13 @@ import ( ) type inserter[T any] struct { - tx bun.Tx - q *bun.InsertQuery - model *T - models []*T - isBulk bool - auditCallback AuditCallback - auditRequest *http.Request + tx bun.Tx + q *bun.InsertQuery + model *T + models []*T + isBulk bool + audit *AuditMeta + auditInfo *AuditInfo } // Insert creates an inserter for a single model @@ -76,11 +75,10 @@ func (i *inserter[T]) Returning(columns ...string) *inserter[T] { } // WithAudit enables audit logging for this insert operation -// The callback will be invoked after successful insert with auto-generated audit info -// If the callback returns an error, the transaction will be rolled back -func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] { - i.auditRequest = r - i.auditCallback = callback +// If the provided *AuditInfo is nil, will use reflection to automatically work out the details +func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] { + i.audit = meta + i.auditInfo = info return i } @@ -94,35 +92,29 @@ func (i *inserter[T]) Exec(ctx context.Context) error { } // Handle audit logging if enabled - if i.auditCallback != nil && i.auditRequest != nil { - tableName := extractTableName[T]() - resourceType := extractResourceType(tableName) - action := buildAction(resourceType, "create") - - var info *AuditInfo - if i.isBulk { - // For bulk inserts, log once with count in details - info = &AuditInfo{ + if i.audit != nil { + if i.auditInfo == nil { + tableName := extractTableName[T]() + resourceType := extractResourceType(tableName) + action := buildAction(resourceType, "create") + i.auditInfo = &AuditInfo{ Action: action, ResourceType: resourceType, ResourceID: nil, - Details: map[string]any{ - "count": len(i.models), - }, - } - } else { - // For single insert, log with resource ID - info = &AuditInfo{ - Action: action, - ResourceType: resourceType, - ResourceID: extractPrimaryKey(i.model), Details: nil, } + if i.isBulk { + i.auditInfo.Details = map[string]any{ + "count": len(i.models), + } + } else { + i.auditInfo.ResourceID = extractPrimaryKey(i.model) + } } - // Call audit callback - if it fails, return error to trigger rollback - if err := i.auditCallback(ctx, i.tx, info, i.auditRequest); err != nil { - return errors.Wrap(err, "audit.callback") + err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo) + if err != nil { + return errors.Wrap(err, "LogSuccess") } } diff --git a/internal/db/league.go b/internal/db/league.go index 9a96a36..1814086 100644 --- a/internal/db/league.go +++ b/internal/db/league.go @@ -19,13 +19,6 @@ type League struct { Teams []Team `bun:"m2m:team_participations,join:League=Team"` } -type SeasonLeague struct { - SeasonID int `bun:",pk"` - Season *Season `bun:"rel:belongs-to,join:season_id=id"` - LeagueID int `bun:",pk"` - League *League `bun:"rel:belongs-to,join:league_id=id"` -} - func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) { return GetList[League](tx).Relation("Seasons").GetAll(ctx) } @@ -37,41 +30,16 @@ func GetLeague(ctx context.Context, tx bun.Tx, shortname string) (*League, error return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx) } -// GetSeasonLeague retrieves a specific season-league combination with teams -func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) { - if seasonShortName == "" { - return nil, nil, nil, errors.New("season short_name cannot be empty") +func NewLeague(ctx context.Context, tx bun.Tx, name, shortname, description string, audit *AuditMeta) (*League, error) { + league := &League{ + Name: name, + ShortName: shortname, + Description: description, } - if leagueShortName == "" { - return nil, nil, nil, errors.New("league short_name cannot be empty") - } - - // Get the season - season, err := GetSeason(ctx, tx, seasonShortName) + err := Insert(tx, league). + WithAudit(audit, nil).Exec(ctx) if err != nil { - return nil, nil, nil, errors.Wrap(err, "GetSeason") + return nil, errors.Wrap(err, "db.Insert") } - - // Get the league - league, err := GetLeague(ctx, tx, leagueShortName) - if err != nil { - return nil, nil, nil, errors.Wrap(err, "GetLeague") - } - if season == nil || league == nil || !season.HasLeague(league.ID) { - return nil, nil, nil, nil - } - - // Get all teams participating in this season+league - var teams []*Team - err = tx.NewSelect(). - Model(&teams). - Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id"). - Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID). - Order("t.name ASC"). - Scan(ctx) - if err != nil { - return nil, nil, nil, errors.Wrap(err, "tx.Select teams") - } - - return season, league, teams, nil + return league, nil } diff --git a/cmd/oslstats/migrate.go b/internal/db/migrate/migrate.go similarity index 87% rename from cmd/oslstats/migrate.go rename to internal/db/migrate/migrate.go index 85e1fbd..1fa523b 100644 --- a/cmd/oslstats/migrate.go +++ b/internal/db/migrate/migrate.go @@ -1,4 +1,5 @@ -package main +// Package migrate provides functions for managing database migrations +package migrate import ( "bufio" @@ -11,20 +12,19 @@ import ( "text/tabwriter" "time" - "git.haelnorr.com/h/oslstats/cmd/oslstats/migrations" - "git.haelnorr.com/h/oslstats/internal/backup" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/db/migrations" "github.com/pkg/errors" - "github.com/uptrace/bun" "github.com/uptrace/bun/migrate" ) -// runMigrations executes database migrations -func runMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error { - conn, close := setupBun(cfg) - defer func() { _ = close() }() +// RunMigrations executes database migrations +func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error { + conn := db.NewDB(cfg.DB) + defer func() { _ = conn.Close() }() - migrator := migrate.NewMigrator(conn, migrations.Migrations) + migrator := migrate.NewMigrator(conn.DB, migrations.Migrations) // Initialize migration tables if err := migrator.Init(ctx); err != nil { @@ -47,15 +47,13 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string, coun return migrateRollback(ctx, migrator, conn, cfg, countStr) case "status": return migrateStatus(ctx, migrator) - case "dry-run": - return migrateDryRun(ctx, migrator) default: return fmt.Errorf("unknown migration command: %s", command) } } // migrateUp runs pending migrations -func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error { +func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error { // Parse count parameter count, all, err := parseMigrationCount(countStr) if err != nil { @@ -101,13 +99,13 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf // Create backup unless --no-backup flag is set if !cfg.Flags.MigrateNoBackup { fmt.Println("[INFO] Step 3/5: Creating backup...") - _, err := backup.CreateBackup(ctx, cfg, "migration") + _, err := db.CreateBackup(ctx, cfg.DB, "migration") if err != nil { return errors.Wrap(err, "create backup") } // Clean old backups - if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil { + if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil { fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) } } else { @@ -143,7 +141,7 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf } // migrateRollback rolls back migrations -func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error { +func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error { // Parse count parameter count, all, err := parseMigrationCount(countStr) if err != nil { @@ -182,13 +180,13 @@ func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun. // Create backup unless --no-backup flag is set if !cfg.Flags.MigrateNoBackup { fmt.Println("[INFO] Creating backup before rollback...") - _, err := backup.CreateBackup(ctx, cfg, "rollback") + _, err := db.CreateBackup(ctx, cfg.DB, "rollback") if err != nil { return errors.Wrap(err, "create backup") } // Clean old backups - if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil { + if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil { fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) } } else { @@ -255,27 +253,6 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error { return nil } -// migrateDryRun shows what migrations would run without applying them -func migrateDryRun(ctx context.Context, migrator *migrate.Migrator) error { - group, err := migrator.Migrate(ctx, migrate.WithNopMigration()) - if err != nil { - return errors.Wrap(err, "dry-run") - } - - if group.IsZero() { - fmt.Println("[INFO] No pending migrations") - return nil - } - - fmt.Println("[INFO] Pending migrations (dry-run):") - for _, migration := range group.Migrations { - fmt.Printf(" 📋 %s\n", migration.Name) - } - fmt.Printf("[INFO] Would migrate to group %d\n", group.ID) - - return nil -} - // validateMigrations ensures migrations compile before running func validateMigrations(ctx context.Context) error { cmd := exec.CommandContext(ctx, "go", "build", @@ -292,7 +269,7 @@ func validateMigrations(ctx context.Context) error { } // acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock -func acquireMigrationLock(ctx context.Context, conn *bun.DB) error { +func acquireMigrationLock(ctx context.Context, conn *db.DB) error { const lockID = 1234567890 // Arbitrary unique ID for migration lock const timeoutSeconds = 300 // 5 minutes @@ -318,7 +295,7 @@ func acquireMigrationLock(ctx context.Context, conn *bun.DB) error { } // releaseMigrationLock releases the migration lock -func releaseMigrationLock(ctx context.Context, conn *bun.DB) { +func releaseMigrationLock(ctx context.Context, conn *db.DB) { const lockID = 1234567890 _, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx) @@ -329,8 +306,8 @@ func releaseMigrationLock(ctx context.Context, conn *bun.DB) { } } -// createMigration generates a new migration file -func createMigration(name string) error { +// CreateMigration generates a new migration file +func CreateMigration(name string) error { if name == "" { return errors.New("migration name cannot be empty") } @@ -340,7 +317,7 @@ func createMigration(name string) error { // Generate timestamp timestamp := time.Now().Format("20060102150405") - filename := fmt.Sprintf("cmd/oslstats/migrations/%s_%s.go", timestamp, name) + filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name) // Template template := `package migrations @@ -502,8 +479,8 @@ func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migr return rolledBack, nil } -// resetDatabase drops and recreates all tables (destructive) -func resetDatabase(ctx context.Context, cfg *config.Config) error { +// ResetDatabase drops and recreates all tables (destructive) +func ResetDatabase(ctx context.Context, cfg *config.Config) error { fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!") fmt.Print("Type 'yes' to continue: ") @@ -518,10 +495,10 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error { fmt.Println("❌ Reset cancelled") return nil } - conn, close := setupBun(cfg) - defer func() { _ = close() }() + conn := db.NewDB(cfg.DB) + defer func() { _ = conn.Close() }() - models := registerDBModels(conn) + models := conn.RegisterModels() for _, model := range models { if err := conn.ResetModel(ctx, model); err != nil { diff --git a/cmd/oslstats/migrations/20250124000001_initial_schema.go b/internal/db/migrations/20250124000001_initial_schema.go similarity index 74% rename from cmd/oslstats/migrations/20250124000001_initial_schema.go rename to internal/db/migrations/20250124000001_initial_schema.go index b42286f..f0ba125 100644 --- a/cmd/oslstats/migrations/20250124000001_initial_schema.go +++ b/internal/db/migrations/20250124000001_initial_schema.go @@ -10,9 +10,9 @@ import ( func init() { Migrations.MustRegister( // UP: Create initial tables (users, discord_tokens) - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Create users table - _, err := dbConn.NewCreateTable(). + _, err := conn.NewCreateTable(). Model((*db.User)(nil)). Exec(ctx) if err != nil { @@ -20,15 +20,15 @@ func init() { } // Create discord_tokens table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.DiscordToken)(nil)). Exec(ctx) return err }, // DOWN: Drop tables in reverse order - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Drop discord_tokens first (has foreign key to users) - _, err := dbConn.NewDropTable(). + _, err := conn.NewDropTable(). Model((*db.DiscordToken)(nil)). IfExists(). Exec(ctx) @@ -37,7 +37,7 @@ func init() { } // Drop users table - _, err = dbConn.NewDropTable(). + _, err = conn.NewDropTable(). Model((*db.User)(nil)). IfExists(). Exec(ctx) diff --git a/cmd/oslstats/migrations/20260127194815_seasons.go b/internal/db/migrations/20260127194815_seasons.go similarity index 70% rename from cmd/oslstats/migrations/20260127194815_seasons.go rename to internal/db/migrations/20260127194815_seasons.go index 81f3d18..49ac8ca 100644 --- a/cmd/oslstats/migrations/20260127194815_seasons.go +++ b/internal/db/migrations/20260127194815_seasons.go @@ -10,8 +10,8 @@ import ( func init() { Migrations.MustRegister( // UP migration - func(ctx context.Context, dbConn *bun.DB) error { - _, err := dbConn.NewCreateTable(). + func(ctx context.Context, conn *bun.DB) error { + _, err := conn.NewCreateTable(). Model((*db.Season)(nil)). Exec(ctx) if err != nil { @@ -20,8 +20,8 @@ func init() { return nil }, // DOWN migration - func(ctx context.Context, dbConn *bun.DB) error { - _, err := dbConn.NewDropTable(). + func(ctx context.Context, conn *bun.DB) error { + _, err := conn.NewDropTable(). Model((*db.Season)(nil)). IfExists(). Exec(ctx) diff --git a/cmd/oslstats/migrations/20260202231414_add_rbac_system.go b/internal/db/migrations/20260202231414_add_rbac_system.go similarity index 85% rename from cmd/oslstats/migrations/20260202231414_add_rbac_system.go rename to internal/db/migrations/20260202231414_add_rbac_system.go index 4f986b8..92fb60d 100644 --- a/cmd/oslstats/migrations/20260202231414_add_rbac_system.go +++ b/internal/db/migrations/20260202231414_add_rbac_system.go @@ -12,10 +12,10 @@ import ( func init() { Migrations.MustRegister( // UP migration - func(ctx context.Context, dbConn *bun.DB) error { - dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil)) + func(ctx context.Context, conn *bun.DB) error { + conn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil)) // Create permissions table - _, err := dbConn.NewCreateTable(). + _, err := conn.NewCreateTable(). Model((*db.Role)(nil)). Exec(ctx) if err != nil { @@ -23,7 +23,7 @@ func init() { } // Create permissions table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.Permission)(nil)). Exec(ctx) if err != nil { @@ -31,7 +31,7 @@ func init() { } // Create indexes for permissions - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.Permission)(nil)). Index("idx_permissions_resource"). Column("resource"). @@ -40,7 +40,7 @@ func init() { return err } - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.Permission)(nil)). Index("idx_permissions_action"). Column("action"). @@ -49,21 +49,21 @@ func init() { return err } - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.RolePermission)(nil)). Exec(ctx) if err != nil { return err } - _, err = dbConn.ExecContext(ctx, ` + _, err = conn.ExecContext(ctx, ` CREATE INDEX idx_role_permissions_role ON role_permissions(role_id) `) if err != nil { return err } - _, err = dbConn.ExecContext(ctx, ` + _, err = conn.ExecContext(ctx, ` CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id) `) if err != nil { @@ -71,7 +71,7 @@ func init() { } // Create user_roles table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.UserRole)(nil)). Exec(ctx) if err != nil { @@ -79,7 +79,7 @@ func init() { } // Create indexes for user_roles - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.UserRole)(nil)). Index("idx_user_roles_user"). Column("user_id"). @@ -88,7 +88,7 @@ func init() { return err } - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.UserRole)(nil)). Index("idx_user_roles_role"). Column("role_id"). @@ -98,7 +98,7 @@ func init() { } // Create audit_log table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.AuditLog)(nil)). Exec(ctx) if err != nil { @@ -106,7 +106,7 @@ func init() { } // Create indexes for audit_log - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_user"). Column("user_id"). @@ -115,7 +115,7 @@ func init() { return err } - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_action"). Column("action"). @@ -124,7 +124,7 @@ func init() { return err } - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_resource"). Column("resource_type", "resource_id"). @@ -133,7 +133,7 @@ func init() { return err } - _, err = dbConn.NewCreateIndex(). + _, err = conn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_created"). Column("created_at"). @@ -142,7 +142,7 @@ func init() { return err } - err = seedSystemRBAC(ctx, dbConn) + err = seedSystemRBAC(ctx, conn) if err != nil { return err } @@ -173,7 +173,7 @@ func init() { ) } -func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { +func seedSystemRBAC(ctx context.Context, conn *bun.DB) error { // Seed system roles now := time.Now().Unix() @@ -185,7 +185,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { CreatedAt: now, } - _, err := dbConn.NewInsert(). + _, err := conn.NewInsert(). Model(adminRole). Returning("id"). Exec(ctx) @@ -201,7 +201,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { CreatedAt: now, } - _, err = dbConn.NewInsert(). + _, err = conn.NewInsert(). Model(userRole). Exec(ctx) if err != nil { @@ -219,7 +219,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { {Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now}, } - _, err = dbConn.NewInsert(). + _, err = conn.NewInsert(). Model(&permissionsData). Exec(ctx) if err != nil { @@ -229,7 +229,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { // Grant wildcard permission to admin role using Bun // First, get the IDs var wildcardPerm db.Permission - err = dbConn.NewSelect(). + err = conn.NewSelect(). Model(&wildcardPerm). Where("name = ?", "*"). Scan(ctx) @@ -242,7 +242,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { RoleID: adminRole.ID, PermissionID: wildcardPerm.ID, } - _, err = dbConn.NewInsert(). + _, err = conn.NewInsert(). Model(adminRolePerms). On("CONFLICT (role_id, permission_id) DO NOTHING"). Exec(ctx) diff --git a/cmd/oslstats/migrations/20260210182212_add_leagues.go b/internal/db/migrations/20260210182212_add_leagues.go similarity index 76% rename from cmd/oslstats/migrations/20260210182212_add_leagues.go rename to internal/db/migrations/20260210182212_add_leagues.go index ec21071..53599d8 100644 --- a/cmd/oslstats/migrations/20260210182212_add_leagues.go +++ b/internal/db/migrations/20260210182212_add_leagues.go @@ -11,9 +11,9 @@ import ( func init() { Migrations.MustRegister( // UP migration - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Add slap_version column to seasons table - _, err := dbConn.NewAddColumn(). + _, err := conn.NewAddColumn(). Model((*db.Season)(nil)). ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'"). IfNotExists(). @@ -23,7 +23,7 @@ func init() { } // Create leagues table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.League)(nil)). Exec(ctx) if err != nil { @@ -31,15 +31,15 @@ func init() { } // Create season_leagues join table - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.SeasonLeague)(nil)). Exec(ctx) return err }, // DOWN migration - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Drop season_leagues join table first - _, err := dbConn.NewDropTable(). + _, err := conn.NewDropTable(). Model((*db.SeasonLeague)(nil)). IfExists(). Exec(ctx) @@ -48,7 +48,7 @@ func init() { } // Drop leagues table - _, err = dbConn.NewDropTable(). + _, err = conn.NewDropTable(). Model((*db.League)(nil)). IfExists(). Exec(ctx) @@ -57,7 +57,7 @@ func init() { } // Remove slap_version column from seasons table - _, err = dbConn.NewDropColumn(). + _, err = conn.NewDropColumn(). Model((*db.Season)(nil)). ColumnExpr("slap_version"). Exec(ctx) diff --git a/cmd/oslstats/migrations/20260211225253_teams.go b/internal/db/migrations/20260211225253_teams.go similarity index 73% rename from cmd/oslstats/migrations/20260211225253_teams.go rename to internal/db/migrations/20260211225253_teams.go index 2b9d908..1a9b99b 100644 --- a/cmd/oslstats/migrations/20260211225253_teams.go +++ b/internal/db/migrations/20260211225253_teams.go @@ -10,15 +10,15 @@ import ( func init() { Migrations.MustRegister( // UP migration - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Add your migration code here - _, err := dbConn.NewCreateTable(). + _, err := conn.NewCreateTable(). Model((*db.Team)(nil)). Exec(ctx) if err != nil { return err } - _, err = dbConn.NewCreateTable(). + _, err = conn.NewCreateTable(). Model((*db.TeamParticipation)(nil)). Exec(ctx) if err != nil { @@ -27,16 +27,16 @@ func init() { return nil }, // DOWN migration - func(ctx context.Context, dbConn *bun.DB) error { + func(ctx context.Context, conn *bun.DB) error { // Add your rollback code here - _, err := dbConn.NewDropTable(). + _, err := conn.NewDropTable(). Model((*db.TeamParticipation)(nil)). IfExists(). Exec(ctx) if err != nil { return err } - _, err = dbConn.NewDropTable(). + _, err = conn.NewDropTable(). Model((*db.Team)(nil)). IfExists(). Exec(ctx) diff --git a/cmd/oslstats/migrations/20260213162216_missing_permissions.go b/internal/db/migrations/20260213162216_missing_permissions.go similarity index 100% rename from cmd/oslstats/migrations/20260213162216_missing_permissions.go rename to internal/db/migrations/20260213162216_missing_permissions.go diff --git a/cmd/oslstats/migrations/migrations.go b/internal/db/migrations/migrations.go similarity index 100% rename from cmd/oslstats/migrations/migrations.go rename to internal/db/migrations/migrations.go diff --git a/internal/db/paginate.go b/internal/db/paginate.go index bae2952..b9a7dae 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -1,8 +1,11 @@ package db import ( + "net/http" "strings" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/validation" "github.com/uptrace/bun" ) @@ -19,6 +22,41 @@ type OrderOpts struct { Label string } +func GetPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request) (*PageOpts, bool) { + var getter validation.Getter + switch r.Method { + case "GET": + getter = validation.NewQueryGetter(r) + case "POST": + var ok bool + getter, ok = validation.ParseFormOrError(s, w, r) + if !ok { + return nil, false + } + default: + return nil, false + } + return getPageOpts(s, w, r, getter), true +} + +func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *PageOpts { + page := g.Int("page").Optional().Min(1).Value + perPage := g.Int("per_page").Optional().Min(1).Max(100).Value + order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value + orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value + valid := g.ValidateAndError(s, w, r) + if !valid { + return nil + } + pageOpts := &PageOpts{ + Page: page, + PerPage: perPage, + Order: bun.Order(order), + OrderBy: orderBy, + } + return pageOpts +} + func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) { if p == nil { p = new(PageOpts) diff --git a/internal/db/permission.go b/internal/db/permission.go index 10194ce..6321e0a 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -92,5 +92,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error { if id <= 0 { return errors.New("id must be positive") } - return DeleteWithProtection[Permission](ctx, tx, id) + return DeleteWithProtection[Permission](ctx, tx, id, nil) } diff --git a/internal/db/role.go b/internal/db/role.go index 9ed2957..0962c8d 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -25,13 +25,6 @@ type Role struct { Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"` } -type RolePermission struct { - RoleID int `bun:",pk"` - Role *Role `bun:"rel:belongs-to,join:role_id=id"` - PermissionID int `bun:",pk"` - Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"` -} - func (r Role) isSystem() bool { return r.IsSystem } @@ -42,17 +35,12 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro if name == "" { return nil, errors.New("name cannot be empty") } - return GetByField[Role](tx, "name", name).Get(ctx) + return GetByField[Role](tx, "name", name).Relation("Permissions").Get(ctx) } // GetRoleByID queries the database for a role matching the given ID // Returns nil, nil if no role is found func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { - return GetByID[Role](tx, id).Get(ctx) -} - -// GetRoleWithPermissions loads a role and all its permissions -func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { return GetByID[Role](tx, id).Relation("Permissions").Get(ctx) } @@ -73,7 +61,7 @@ func GetRoles(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Role], } // CreateRole creates a new role -func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { +func CreateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error { if role == nil { return errors.New("role cannot be nil") } @@ -81,6 +69,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { err := Insert(tx, role). Returning("id"). + WithAudit(audit, nil). Exec(ctx) if err != nil { return errors.Wrap(err, "db.Insert") @@ -90,7 +79,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { } // UpdateRole updates an existing role -func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { +func UpdateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error { if role == nil { return errors.New("role cannot be nil") } @@ -100,6 +89,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { err := Update(tx, role). WherePK(). + WithAudit(audit, nil). Exec(ctx) if err != nil { return errors.Wrap(err, "db.Update") @@ -110,7 +100,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { // DeleteRole deletes a role (checks IsSystem protection) // Also cleans up join table entries in role_permissions and user_roles -func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { +func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error { if id <= 0 { return errors.New("id must be positive") } @@ -146,47 +136,5 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { } // Finally delete the role - return DeleteWithProtection[Role](ctx, tx, id) -} - -// AddPermissionToRole grants a permission to a role -func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error { - if roleID <= 0 { - return errors.New("roleID must be positive") - } - if permissionID <= 0 { - return errors.New("permissionID must be positive") - } - rolePerm := &RolePermission{ - RoleID: roleID, - PermissionID: permissionID, - } - err := Insert(tx, rolePerm). - ConflictNothing("role_id", "permission_id"). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "db.Insert") - } - - return nil -} - -// RemovePermissionFromRole revokes a permission from a role -func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error { - if roleID <= 0 { - return errors.New("roleID must be positive") - } - if permissionID <= 0 { - return errors.New("permissionID must be positive") - } - - err := DeleteItem[RolePermission](tx). - Where("role_id = ?", roleID). - Where("permission_id = ?", permissionID). - Delete(ctx) - if err != nil { - return errors.Wrap(err, "DeleteItem") - } - - return nil + return DeleteWithProtection[Role](ctx, tx, id, audit) } diff --git a/internal/db/rolepermission.go b/internal/db/rolepermission.go new file mode 100644 index 0000000..55b28d1 --- /dev/null +++ b/internal/db/rolepermission.go @@ -0,0 +1,99 @@ +package db + +import ( + "context" + "slices" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type RolePermission struct { + RoleID int `bun:",pk"` + Role *Role `bun:"rel:belongs-to,join:role_id=id"` + PermissionID int `bun:",pk"` + Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"` +} + +func (r *Role) UpdatePermissions(ctx context.Context, tx bun.Tx, newPermissionsIDs []int, audit *AuditMeta) error { + addPerms, removePerms, err := detectChangedPermissions(ctx, tx, r, newPermissionsIDs) + if err != nil { + return errors.Wrap(err, "detectChangedPermissions") + } + addedPerms := []string{} + removedPerms := []string{} + for _, perm := range addPerms { + rolePerm := &RolePermission{ + RoleID: r.ID, + PermissionID: perm.ID, + } + err := Insert(tx, rolePerm). + ConflictNothing("role_id", "permission_id"). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "db.Insert") + } + addedPerms = append(addedPerms, perm.Name.String()) + } + for _, perm := range removePerms { + err := DeleteItem[RolePermission](tx). + Where("role_id = ?", r.ID). + Where("permission_id = ?", perm.ID). + Delete(ctx) + if err != nil { + return errors.Wrap(err, "DeleteItem") + } + removedPerms = append(removedPerms, perm.Name.String()) + } + // Log the permission changes + if len(addedPerms) > 0 || len(removedPerms) > 0 { + details := map[string]any{ + "role_name": string(r.Name), + } + if len(addedPerms) > 0 { + details["added_permissions"] = addedPerms + } + if len(removedPerms) > 0 { + details["removed_permissions"] = removedPerms + } + info := &AuditInfo{ + "roles.update_permissions", + "role", + r.ID, + details, + } + err = LogSuccess(ctx, tx, audit, info) + if err != nil { + return errors.Wrap(err, "LogSuccess") + } + } + return nil +} + +func detectChangedPermissions(ctx context.Context, tx bun.Tx, role *Role, permissionIDs []int) ([]*Permission, []*Permission, error) { + allPermissions, err := ListAllPermissions(ctx, tx) + if err != nil { + return nil, nil, errors.Wrap(err, "ListAllPermissions") + } + // Build map of current permissions + currentPermIDs := make(map[int]bool) + for _, perm := range role.Permissions { + currentPermIDs[perm.ID] = true + } + + var addedPerms []*Permission + var removedPerms []*Permission + + // Determine what to add and remove + for _, perm := range allPermissions { + hasNow := currentPermIDs[perm.ID] + shouldHave := slices.Contains(permissionIDs, perm.ID) + + if shouldHave && !hasNow { + addedPerms = append(addedPerms, perm) + } else if !shouldHave && hasNow { + removedPerms = append(removedPerms, perm) + } + } + return addedPerms, removedPerms, nil +} diff --git a/internal/db/season.go b/internal/db/season.go index 3df536f..3567bfe 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -25,15 +25,22 @@ type Season struct { Teams []Team `bun:"m2m:team_participations,join:Season=Team"` } -// NewSeason returns a new season. It does not add it to the database -func NewSeason(name, version, shortname string, start time.Time) *Season { +// NewSeason creats a new season +func NewSeason(ctx context.Context, tx bun.Tx, name, version, shortname string, + start time.Time, audit *AuditMeta, +) (*Season, error) { season := &Season{ Name: name, ShortName: strings.ToUpper(shortname), StartDate: start.Truncate(time.Hour * 24), SlapVersion: version, } - return season + err := Insert(tx, season). + WithAudit(audit, nil).Exec(ctx) + if err != nil { + return nil, errors.WithMessage(err, "db.Insert") + } + return season, nil } func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) { @@ -54,7 +61,9 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error } // Update updates the season struct. It does not insert to the database -func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time.Time) { +func (s *Season) Update(ctx context.Context, tx bun.Tx, version string, + start, end, finalsStart, finalsEnd time.Time, audit *AuditMeta, +) error { s.SlapVersion = version s.StartDate = start.Truncate(time.Hour * 24) if !end.IsZero() { @@ -66,6 +75,9 @@ func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time. if !finalsEnd.IsZero() { s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24) } + return Update(tx, s).WherePK(). + Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date"). + WithAudit(audit, nil).Exec(ctx) } func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) { diff --git a/internal/db/seasonleague.go b/internal/db/seasonleague.go new file mode 100644 index 0000000..df9cdea --- /dev/null +++ b/internal/db/seasonleague.go @@ -0,0 +1,118 @@ +package db + +import ( + "context" + + "git.haelnorr.com/h/oslstats/internal/permissions" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type SeasonLeague struct { + SeasonID int `bun:",pk"` + Season *Season `bun:"rel:belongs-to,join:season_id=id"` + LeagueID int `bun:",pk"` + League *League `bun:"rel:belongs-to,join:league_id=id"` +} + +// GetSeasonLeague retrieves a specific season-league combination with teams +func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) { + if seasonShortName == "" { + return nil, nil, nil, errors.New("season short_name cannot be empty") + } + if leagueShortName == "" { + return nil, nil, nil, errors.New("league short_name cannot be empty") + } + + // Get the season + season, err := GetSeason(ctx, tx, seasonShortName) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "GetSeason") + } + + // Get the league + league, err := GetLeague(ctx, tx, leagueShortName) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "GetLeague") + } + if season == nil || league == nil || !season.HasLeague(league.ID) { + return nil, nil, nil, nil + } + + // Get all teams participating in this season+league + var teams []*Team + err = tx.NewSelect(). + Model(&teams). + Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id"). + Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID). + Order("t.name ASC"). + Scan(ctx) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "tx.Select teams") + } + + return season, league, teams, nil +} + +func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string, audit *AuditMeta) error { + season, err := GetSeason(ctx, tx, seasonShortName) + if err != nil { + return errors.Wrap(err, "GetSeason") + } + if season == nil { + return errors.New("season not found") + } + league, err := GetLeague(ctx, tx, leagueShortName) + if err != nil { + return errors.Wrap(err, "GetLeague") + } + if league == nil { + return errors.New("league not found") + } + if season.HasLeague(league.ID) { + return errors.New("league already added to season") + } + seasonLeague := &SeasonLeague{ + SeasonID: season.ID, + LeagueID: league.ID, + } + info := &AuditInfo{ + string(permissions.SeasonsAddLeague), + "season", + season.ID, + map[string]any{"league_id": league.ID}, + } + err = Insert(tx, seasonLeague).WithAudit(audit, info).Exec(ctx) + if err != nil { + return errors.Wrap(err, "db.Insert") + } + return nil +} + +func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName string, audit *AuditMeta) error { + league, err := GetLeague(ctx, tx, leagueShortName) + if err != nil { + return errors.Wrap(err, "GetLeague") + } + if league == nil { + return errors.New("league not found") + } + if !s.HasLeague(league.ID) { + return errors.New("league not in season") + } + info := &AuditInfo{ + string(permissions.SeasonsRemoveLeague), + "season", + s.ID, + map[string]any{"league_id": league.ID}, + } + err = DeleteItem[SeasonLeague](tx). + Where("season_id = ?", s.ID). + Where("league_id = ?", league.ID). + WithAudit(audit, info). + Delete(ctx) + if err != nil { + return errors.Wrap(err, "db.DeleteItem") + } + return nil +} diff --git a/internal/db/setup.go b/internal/db/setup.go new file mode 100644 index 0000000..dd2b4b4 --- /dev/null +++ b/internal/db/setup.go @@ -0,0 +1,55 @@ +package db + +import ( + "database/sql" + "fmt" + "time" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" + "github.com/uptrace/bun/driver/pgdriver" +) + +type DB struct { + *bun.DB +} + +func (db *DB) Close() error { + return db.DB.Close() +} + +func (db *DB) RegisterModels() []any { + models := []any{ + (*RolePermission)(nil), + (*UserRole)(nil), + (*SeasonLeague)(nil), + (*TeamParticipation)(nil), + (*User)(nil), + (*DiscordToken)(nil), + (*Season)(nil), + (*League)(nil), + (*Team)(nil), + (*Role)(nil), + (*Permission)(nil), + (*AuditLog)(nil), + } + db.RegisterModel(models...) + return models +} + +func NewDB(cfg *Config) *DB { + dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", + cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL) + sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + + sqldb.SetMaxOpenConns(25) + sqldb.SetMaxIdleConns(10) + sqldb.SetConnMaxLifetime(5 * time.Minute) + sqldb.SetConnMaxIdleTime(5 * time.Minute) + + db := &DB{ + bun.NewDB(sqldb, pgdialect.New()), + } + db.RegisterModels() + return db +} diff --git a/internal/db/team.go b/internal/db/team.go index b0cd882..2bd4aa9 100644 --- a/internal/db/team.go +++ b/internal/db/team.go @@ -19,13 +19,19 @@ type Team struct { Leagues []League `bun:"m2m:team_participations,join:Team=League"` } -type TeamParticipation struct { - SeasonID int `bun:",pk,unique:season_team"` - Season *Season `bun:"rel:belongs-to,join:season_id=id"` - LeagueID int `bun:",pk"` - League *League `bun:"rel:belongs-to,join:league_id=id"` - TeamID int `bun:",pk,unique:season_team"` - Team *Team `bun:"rel:belongs-to,join:team_id=id"` +func NewTeam(ctx context.Context, tx bun.Tx, name, shortName, altShortName, color string, audit *AuditMeta) (*Team, error) { + team := &Team{ + Name: name, + ShortName: shortName, + AltShortName: altShortName, + Color: color, + } + err := Insert(tx, team). + WithAudit(audit, nil).Exec(ctx) + if err != nil { + return nil, errors.Wrap(err, "db.Insert") + } + return team, nil } func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) { @@ -45,14 +51,23 @@ func GetTeam(ctx context.Context, tx bun.Tx, id int) (*Team, error) { return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx) } -func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortname, altshortname string) (bool, error) { +func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortName, altShortName string) (bool, error) { // Check if this combination of short_name and alt_short_name exists count, err := tx.NewSelect(). Model((*Team)(nil)). - Where("short_name = ? AND alt_short_name = ?", shortname, altshortname). + Where("short_name = ? AND alt_short_name = ?", shortName, altShortName). Count(ctx) if err != nil { return false, errors.Wrap(err, "tx.Select") } return count == 0, nil } + +func (t *Team) InSeason(seasonID int) bool { + for _, season := range t.Seasons { + if season.ID == seasonID { + return true + } + } + return false +} diff --git a/internal/db/teamparticipation.go b/internal/db/teamparticipation.go new file mode 100644 index 0000000..51eb6dc --- /dev/null +++ b/internal/db/teamparticipation.go @@ -0,0 +1,67 @@ +package db + +import ( + "context" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type TeamParticipation struct { + SeasonID int `bun:",pk,unique:season_team"` + Season *Season `bun:"rel:belongs-to,join:season_id=id"` + LeagueID int `bun:",pk"` + League *League `bun:"rel:belongs-to,join:league_id=id"` + TeamID int `bun:",pk,unique:season_team"` + Team *Team `bun:"rel:belongs-to,join:team_id=id"` +} + +func NewTeamParticipation(ctx context.Context, tx bun.Tx, + seasonShortName, leagueShortName string, teamID int, audit *AuditMeta, +) (*Team, *Season, *League, error) { + season, err := GetSeason(ctx, tx, seasonShortName) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "GetSeason") + } + if season == nil { + return nil, nil, nil, errors.New("season not found") + } + league, err := GetLeague(ctx, tx, leagueShortName) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "GetLeague") + } + if league == nil { + return nil, nil, nil, errors.New("league not found") + } + if !season.HasLeague(league.ID) { + return nil, nil, nil, errors.New("league is not assigned to the season") + } + team, err := GetTeam(ctx, tx, teamID) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "GetTeam") + } + if team == nil { + return nil, nil, nil, errors.New("team not found") + } + if team.InSeason(season.ID) { + return nil, nil, nil, errors.New("team already in season") + } + participation := &TeamParticipation{ + SeasonID: season.ID, + LeagueID: league.ID, + TeamID: team.ID, + } + + info := &AuditInfo{ + "teams.join_season", + "team", + teamID, + map[string]any{"season_id": season.ID, "league_id": league.ID}, + } + err = Insert(tx, participation). + WithAudit(audit, info).Exec(ctx) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "db.Insert") + } + return team, season, league, nil +} diff --git a/internal/db/txhelpers.go b/internal/db/txhelpers.go index ab3d3ab..ad81787 100644 --- a/internal/db/txhelpers.go +++ b/internal/db/txhelpers.go @@ -22,16 +22,15 @@ var timeout = 15 * time.Second // WithReadTx executes a read-only transaction with automatic rollback // Returns true if successful, false if error was thrown to client -func WithReadTx( +func (db *DB) WithReadTx( s *hws.Server, w http.ResponseWriter, r *http.Request, - conn *bun.DB, fn TxFunc, ) bool { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - ok, err := withTx(ctx, conn, fn, false) + ok, err := db.withTx(ctx, fn, false) if err != nil { throw.InternalServiceError(s, w, r, "Database error", err) } @@ -41,31 +40,29 @@ func WithReadTx( // WithTxFailSilently executes a transaction with automatic rollback // Returns true if successful, false if error occured. // Does not throw any errors to the client. -func WithTxFailSilently( +func (db *DB) WithTxFailSilently( ctx context.Context, - conn *bun.DB, fn TxFuncSilent, ) error { fnc := func(ctx context.Context, tx bun.Tx) (bool, error) { err := fn(ctx, tx) return err == nil, err } - _, err := withTx(ctx, conn, fnc, true) + _, err := db.withTx(ctx, fnc, true) return err } // WithWriteTx executes a write transaction with automatic rollback on error // Commits only if fn returns nil. Returns true if successful. -func WithWriteTx( +func (db *DB) WithWriteTx( s *hws.Server, w http.ResponseWriter, r *http.Request, - conn *bun.DB, fn TxFunc, ) bool { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - ok, err := withTx(ctx, conn, fn, true) + ok, err := db.withTx(ctx, fn, true) if err != nil { throw.InternalServiceError(s, w, r, "Database error", err) } @@ -74,16 +71,15 @@ func WithWriteTx( // WithNotifyTx executes a transaction with notification-based error handling // Uses notifyInternalServiceError instead of throwInternalServiceError -func WithNotifyTx( +func (db *DB) WithNotifyTx( s *hws.Server, w http.ResponseWriter, r *http.Request, - conn *bun.DB, fn TxFunc, ) bool { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - ok, err := withTx(ctx, conn, fn, true) + ok, err := db.withTx(ctx, fn, true) if err != nil { notify.InternalServiceError(s, w, r, "Database error", err) } @@ -91,13 +87,12 @@ func WithNotifyTx( } // withTx executes a transaction with automatic rollback on error -func withTx( +func (db *DB) withTx( ctx context.Context, - conn *bun.DB, fn TxFunc, write bool, ) (bool, error) { - tx, err := conn.BeginTx(ctx, nil) + tx, err := db.BeginTx(ctx, nil) if err != nil { return false, errors.Wrap(err, "conn.BeginTx") } diff --git a/internal/db/update.go b/internal/db/update.go index 2f26cb5..900ef06 100644 --- a/internal/db/update.go +++ b/internal/db/update.go @@ -2,19 +2,18 @@ package db import ( "context" - "net/http" "github.com/pkg/errors" "github.com/uptrace/bun" ) type updater[T any] struct { - tx bun.Tx - q *bun.UpdateQuery - model *T - columns []string - auditCallback AuditCallback - auditRequest *http.Request + tx bun.Tx + q *bun.UpdateQuery + model *T + columns []string + audit *AuditMeta + auditInfo *AuditInfo } // Update creates an updater for a model @@ -69,11 +68,10 @@ func (u *updater[T]) Set(query string, args ...any) *updater[T] { } // WithAudit enables audit logging for this update operation -// The callback will be invoked after successful update with auto-generated audit info -// If the callback returns an error, the transaction will be rolled back -func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater[T] { - u.auditRequest = r - u.auditCallback = callback +// If the provided *AuditInfo is nil, will use reflection to automatically work out the details +func (u *updater[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *updater[T] { + u.audit = meta + u.auditInfo = info return u } @@ -82,7 +80,7 @@ func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater func (u *updater[T]) Exec(ctx context.Context) error { // Build audit details BEFORE update (captures changed fields) var details map[string]any - if u.auditCallback != nil && len(u.columns) > 0 { + if u.audit != nil && len(u.columns) > 0 { details = extractChangedFields(u.model, u.columns) } @@ -93,21 +91,22 @@ func (u *updater[T]) Exec(ctx context.Context) error { } // Handle audit logging if enabled - if u.auditCallback != nil && u.auditRequest != nil { - tableName := extractTableName[T]() - resourceType := extractResourceType(tableName) - action := buildAction(resourceType, "update") + if u.audit != nil { + if u.auditInfo == nil { + tableName := extractTableName[T]() + resourceType := extractResourceType(tableName) + action := buildAction(resourceType, "update") - info := &AuditInfo{ - Action: action, - ResourceType: resourceType, - ResourceID: extractPrimaryKey(u.model), - Details: details, // Changed fields only + u.auditInfo = &AuditInfo{ + Action: action, + ResourceType: resourceType, + ResourceID: extractPrimaryKey(u.model), + Details: details, // Changed fields only + } } - - // Call audit callback - if it fails, return error to trigger rollback - if err := u.auditCallback(ctx, u.tx, info, u.auditRequest); err != nil { - return errors.Wrap(err, "audit.callback") + err = LogSuccess(ctx, u.tx, u.audit, u.auditInfo) + if err != nil { + return errors.Wrap(err, "LogSuccess") } } diff --git a/internal/db/user.go b/internal/db/user.go index a46c78c..4826f31 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -12,8 +12,6 @@ import ( "github.com/uptrace/bun" ) -var CurrentUser hwsauth.ContextLoader[*User] - type User struct { bun.BaseModel `bun:"table:users,alias:u"` @@ -29,8 +27,10 @@ func (u *User) GetID() int { return u.ID } +var CurrentUser hwsauth.ContextLoader[*User] + // CreateUser creates a new user with the given username and password -func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { +func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User, audit *AuditMeta) (*User, error) { if discorduser == nil { return nil, errors.New("user cannot be nil") } @@ -39,8 +39,10 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di CreatedAt: time.Now().Unix(), DiscordID: discorduser.ID, } + audit.u = user err := Insert(tx, user). + WithAudit(audit, nil). Returning("id"). Exec(ctx) if err != nil { diff --git a/internal/db/userrole.go b/internal/db/userrole.go index ccd70bb..20c1a47 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -3,6 +3,7 @@ package db import ( "context" + "git.haelnorr.com/h/oslstats/internal/permissions" "git.haelnorr.com/h/oslstats/internal/roles" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -16,7 +17,7 @@ type UserRole struct { } // AssignRole grants a role to a user -func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { +func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error { if userID <= 0 { return errors.New("userID must be positive") } @@ -28,8 +29,20 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { UserID: userID, RoleID: roleID, } + details := map[string]any{ + "action": "grant", + "role_id": roleID, + } + info := &AuditInfo{ + string(permissions.UsersManageRoles), + "user", + userID, + details, + } err := Insert(tx, userRole). - ConflictNothing("user_id", "role_id").Exec(ctx) + ConflictNothing("user_id", "role_id"). + WithAudit(audit, info). + Exec(ctx) if err != nil { return errors.Wrap(err, "db.Insert") } @@ -38,7 +51,7 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { } // RevokeRole removes a role from a user -func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { +func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error { if userID <= 0 { return errors.New("userID must be positive") } @@ -46,9 +59,20 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { return errors.New("roleID must be positive") } + details := map[string]any{ + "action": "revoke", + "role_id": roleID, + } + info := &AuditInfo{ + string(permissions.UsersManageRoles), + "user", + userID, + details, + } err := DeleteItem[UserRole](tx). Where("user_id = ?", userID). Where("role_id = ?", roleID). + WithAudit(audit, info). Delete(ctx) if err != nil { return errors.Wrap(err, "DeleteItem") diff --git a/internal/embedfs/embedfs.go b/internal/embedfs/embedfs.go index 4e9720d..3172721 100644 --- a/internal/embedfs/embedfs.go +++ b/internal/embedfs/embedfs.go @@ -12,10 +12,10 @@ import ( var embeddedFiles embed.FS // GetEmbeddedFS gets the embedded files -func GetEmbeddedFS() (fs.FS, error) { +func GetEmbeddedFS() (*fs.FS, error) { subFS, err := fs.Sub(embeddedFiles, "web") if err != nil { return nil, errors.Wrap(err, "fs.Sub") } - return subFS, nil + return &subFS, nil } diff --git a/internal/handlers/admin_audit.go b/internal/handlers/admin_audit.go index c00a2b9..c5a2c88 100644 --- a/internal/handlers/admin_audit.go +++ b/internal/handlers/admin_audit.go @@ -17,10 +17,10 @@ import ( ) // AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request) -func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler { +func AdminAuditLogsPage(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromQuery(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } @@ -29,7 +29,7 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler { var actions []string var resourceTypes []string - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error // Get filters from query @@ -73,10 +73,10 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler { } // AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX) -func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler { +func AdminAuditLogsList(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } @@ -85,7 +85,7 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler { var actions []string var resourceTypes []string - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error // Get filters from form @@ -129,16 +129,16 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler { } // AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates -func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler { +func AdminAuditLogsFilter(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var logs *db.List[db.AuditLog] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error // Get filters from form @@ -164,7 +164,7 @@ func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler { } // AdminAuditLogDetail shows details for a single audit log entry -func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler { +func AdminAuditLogDetail(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get ID from path idStr := r.PathValue("id") @@ -181,7 +181,7 @@ func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler { var log *db.AuditLog - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error log, err = db.GetAuditLogByID(ctx, tx, id) if err != nil { diff --git a/internal/handlers/admin_dashboard.go b/internal/handlers/admin_dashboard.go index 5639199..66dc7fe 100644 --- a/internal/handlers/admin_dashboard.go +++ b/internal/handlers/admin_dashboard.go @@ -12,10 +12,10 @@ import ( ) // AdminDashboard renders the full admin dashboard page (defaults to users section) -func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { +func AdminDashboard(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var users *db.List[db.User] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error users, err = db.GetUsersWithRoles(ctx, tx, nil) if err != nil { diff --git a/internal/handlers/admin_permissions.go b/internal/handlers/admin_permissions.go deleted file mode 100644 index 40714c8..0000000 --- a/internal/handlers/admin_permissions.go +++ /dev/null @@ -1,25 +0,0 @@ -package handlers - -import ( - "net/http" - - "git.haelnorr.com/h/golib/hws" - adminview "git.haelnorr.com/h/oslstats/internal/view/adminview" - "github.com/uptrace/bun" -) - -// AdminPermissionsPage renders the full admin dashboard page with permissions section -func AdminPermissionsPage(s *hws.Server, conn *bun.DB) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO: Load permissions from database - renderSafely(adminview.PermissionsPage(), s, r, w) - }) -} - -// AdminPermissionsList shows all permissions (HTMX content replacement) -func AdminPermissionsList(s *hws.Server, conn *bun.DB) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO: Load permissions from database - renderSafely(adminview.PermissionsList(), s, r, w) - }) -} diff --git a/internal/handlers/admin_preview_role.go b/internal/handlers/admin_preview_role.go index 4f07199..3723972 100644 --- a/internal/handlers/admin_preview_role.go +++ b/internal/handlers/admin_preview_role.go @@ -6,7 +6,6 @@ import ( "strconv" "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/roles" @@ -16,7 +15,7 @@ import ( ) // AdminPreviewRoleStart starts preview mode for a specific role -func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http.Handler { +func AdminPreviewRoleStart(s *hws.Server, conn *db.DB, ssl bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get role ID from URL roleIDStr := r.PathValue("id") @@ -28,7 +27,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http // Verify role exists and is not admin var role *db.Role - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error role, err = db.GetRoleByID(ctx, tx, roleID) if err != nil { @@ -49,7 +48,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http } // Set preview role cookie - rbac.SetPreviewRoleCookie(w, roleID, cfg.HWSAuth.SSL) + rbac.SetPreviewRoleCookie(w, roleID, ssl) // Redirect to home page http.Redirect(w, r, "/", http.StatusSeeOther) diff --git a/internal/handlers/admin_roles.go b/internal/handlers/admin_roles.go index 7d68014..53e6cbd 100644 --- a/internal/handlers/admin_roles.go +++ b/internal/handlers/admin_roles.go @@ -8,10 +8,8 @@ import ( "time" "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/roles" - "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/validation" adminview "git.haelnorr.com/h/oslstats/internal/view/adminview" "github.com/pkg/errors" @@ -19,20 +17,15 @@ import ( ) // AdminRoles renders the full admin dashboard page with roles section -func AdminRoles(s *hws.Server, conn *bun.DB) http.Handler { +func AdminRoles(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var pageOpts *db.PageOpts - if r.Method == "GET" { - pageOpts = pageOptsFromQuery(s, w, r) - } else { - pageOpts = pageOptsFromForm(s, w, r) - } - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var rolesList *db.List[db.Role] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error rolesList, err = db.GetRoles(ctx, tx, pageOpts) if err != nil { @@ -59,7 +52,7 @@ func AdminRoleCreateForm(s *hws.Server) http.Handler { } // AdminRoleCreate creates a new role -func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { +func AdminRoleCreate(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { getter, ok := validation.ParseFormOrNotify(s, w, r) if !ok { @@ -74,14 +67,14 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H return } - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var rolesList *db.List[db.Role] var newRole *db.Role - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { newRole = &db.Role{ Name: roles.Role(name), DisplayName: displayName, @@ -90,9 +83,9 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H CreatedAt: time.Now().Unix(), } - err := db.Insert(tx, newRole).WithAudit(r, audit.Callback()).Exec(ctx) + err := db.CreateRole(ctx, tx, newRole, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.Insert") + return false, errors.Wrap(err, "db.CreateRole") } rolesList, err = db.GetRoles(ctx, tx, pageOpts) @@ -110,7 +103,7 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H } // AdminRoleManage shows the role management modal with details and actions -func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler { +func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { roleIDStr := r.PathValue("id") roleID, err := strconv.Atoi(roleIDStr) @@ -120,7 +113,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler { } var role *db.Role - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error role, err = db.GetRoleByID(ctx, tx, roleID) if err != nil { @@ -139,7 +132,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler { } // AdminRoleDeleteConfirm shows the delete confirmation dialog -func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler { +func AdminRoleDeleteConfirm(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { roleIDStr := r.PathValue("id") roleID, err := strconv.Atoi(roleIDStr) @@ -149,7 +142,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler { } var role *db.Role - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error role, err = db.GetRoleByID(ctx, tx, roleID) if err != nil { @@ -168,7 +161,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler { } // AdminRoleDelete deletes a role -func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { +func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { roleIDStr := r.PathValue("id") roleID, err := strconv.Atoi(roleIDStr) @@ -177,13 +170,13 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H return } - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var rolesList *db.List[db.Role] - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { // First check if role exists and get its details role, err := db.GetRoleByID(ctx, tx, roleID) if err != nil { @@ -199,9 +192,9 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H } // Delete the role with audit logging - err = db.DeleteByID[db.Role](tx, roleID).WithAudit(r, audit.Callback()).Delete(ctx) + err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.DeleteByID") + return false, errors.Wrap(err, "db.DeleteRole") } // Reload roles @@ -220,7 +213,7 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H } // AdminRolePermissionsModal shows the permissions management modal for a role -func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler { +func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { roleIDStr := r.PathValue("id") roleID, err := strconv.Atoi(roleIDStr) @@ -234,12 +227,12 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler { var groupedPerms []adminview.PermissionsByResource var rolePermIDs map[int]bool - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { // Load role with permissions var err error - role, err = db.GetRoleWithPermissions(ctx, tx, roleID) + role, err = db.GetRoleByID(ctx, tx, roleID) if err != nil { - return false, errors.Wrap(err, "db.GetRoleWithPermissions") + return false, errors.Wrap(err, "db.GetRoleByID") } if role == nil { return false, errors.New("role not found") @@ -283,7 +276,7 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler { } // AdminRolePermissionsUpdate updates the permissions for a role -func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { +func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { roleIDStr := r.PathValue("id") roleID, err := strconv.Atoi(roleIDStr) @@ -291,7 +284,6 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log w.WriteHeader(http.StatusBadRequest) return } - user := db.CurrentUser(r.Context()) getter, ok := validation.ParseFormOrNotify(s, w, r) if !ok { @@ -304,80 +296,24 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log return } - selectedPermIDs := make(map[int]bool) - for _, id := range permissionIDs { - selectedPermIDs[id] = true - } - - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var rolesList *db.List[db.Role] - if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { - // Get role with current permissions - role, err := db.GetRoleWithPermissions(ctx, tx, roleID) + if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { + role, err := db.GetRoleByID(ctx, tx, roleID) if err != nil { - return false, errors.Wrap(err, "db.GetRoleWithPermissions") + return false, errors.Wrap(err, "db.GetRoleByID") } if role == nil { - throw.NotFound(s, w, r, "Role not found") + w.WriteHeader(http.StatusBadRequest) return false, nil } - - // Get all permissions to know what exists - allPermissions, err := db.ListAllPermissions(ctx, tx) + err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.ListAllPermissions") - } - - // Build map of current permissions - currentPermIDs := make(map[int]bool) - for _, perm := range role.Permissions { - currentPermIDs[perm.ID] = true - } - - var addedPerms []string - var removedPerms []string - - // Determine what to add and remove - for _, perm := range allPermissions { - hasNow := currentPermIDs[perm.ID] - shouldHave := selectedPermIDs[perm.ID] - - if shouldHave && !hasNow { - // Add permission - err := db.AddPermissionToRole(ctx, tx, roleID, perm.ID) - if err != nil { - return false, errors.Wrap(err, "db.AddPermissionToRole") - } - addedPerms = append(addedPerms, string(perm.Name)) - } else if !shouldHave && hasNow { - // Remove permission - err := db.RemovePermissionFromRole(ctx, tx, roleID, perm.ID) - if err != nil { - return false, errors.Wrap(err, "db.RemovePermissionFromRole") - } - removedPerms = append(removedPerms, string(perm.Name)) - } - } - - // Log the permission changes - if len(addedPerms) > 0 || len(removedPerms) > 0 { - details := map[string]any{ - "role_name": string(role.Name), - } - if len(addedPerms) > 0 { - details["added_permissions"] = addedPerms - } - if len(removedPerms) > 0 { - details["removed_permissions"] = removedPerms - } - err = audit.LogSuccess(ctx, tx, user, "update", "role_permissions", roleID, details, r) - if err != nil { - return false, errors.Wrap(err, "audit.LogSuccess") - } + return false, errors.Wrap(err, "role.UpdatePermissions") } // Reload roles diff --git a/internal/handlers/admin_users.go b/internal/handlers/admin_users.go index e1827de..c8af087 100644 --- a/internal/handlers/admin_users.go +++ b/internal/handlers/admin_users.go @@ -12,20 +12,15 @@ import ( ) // AdminUsersPage renders the full admin dashboard page with users section -func AdminUsersPage(s *hws.Server, conn *bun.DB) http.Handler { +func AdminUsersPage(s *hws.Server, conn *db.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var pageOpts *db.PageOpts - if r.Method == "GET" { - pageOpts = pageOptsFromQuery(s, w, r) - } else { - pageOpts = pageOptsFromForm(s, w, r) - } - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var users *db.List[db.User] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error users, err = db.GetUsersWithRoles(ctx, tx, pageOpts) if err != nil { diff --git a/internal/handlers/auth_helpers.go b/internal/handlers/auth_helpers.go index 7511bc1..5aac97b 100644 --- a/internal/handlers/auth_helpers.go +++ b/internal/handlers/auth_helpers.go @@ -47,7 +47,7 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error } // Grant admin role - err = db.AssignRole(ctx, tx, user.ID, adminRole.ID) + err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil) if err != nil { return errors.Wrap(err, "db.AssignRole") } diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index ff310ee..99e016b 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -22,7 +22,7 @@ import ( func Callback( s *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], - conn *bun.DB, + conn *db.DB, cfg *config.Config, store *store.Store, discordAPI *discord.APIClient, @@ -70,7 +70,7 @@ func Callback( switch data { case "login": var redirect func() - if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) if err != nil { throw.InternalServiceError(s, w, r, "OAuth login failed", err) diff --git a/internal/handlers/isunique.go b/internal/handlers/isunique.go index c75daa6..4f4d4d7 100644 --- a/internal/handlers/isunique.go +++ b/internal/handlers/isunique.go @@ -15,7 +15,7 @@ import ( // Returns 200 OK if unique, 409 Conflict if not unique func IsUnique( s *hws.Server, - conn *bun.DB, + conn *db.DB, model any, field string, ) http.Handler { @@ -31,7 +31,7 @@ func IsUnique( return } unique := false - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { unique, err = db.IsUnique(ctx, tx, model, field, value) if err != nil { return false, errors.Wrap(err, "db.IsUnique") diff --git a/internal/handlers/leagues_list.go b/internal/handlers/leagues_list.go index 489a238..ddccaec 100644 --- a/internal/handlers/leagues_list.go +++ b/internal/handlers/leagues_list.go @@ -14,11 +14,11 @@ import ( func LeaguesList( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var leagues []*db.League - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error leagues, err = db.GetLeagues(ctx, tx) if err != nil { diff --git a/internal/handlers/leagues_new.go b/internal/handlers/leagues_new.go index 967b515..07fe0fa 100644 --- a/internal/handlers/leagues_new.go +++ b/internal/handlers/leagues_new.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/validation" @@ -18,20 +17,15 @@ import ( func NewLeague( s *hws.Server, - conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - renderSafely(leaguesview.NewPage(), s, r, w) - return - } + renderSafely(leaguesview.NewPage(), s, r, w) }) } func NewLeagueSubmit( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { getter, ok := validation.ParseFormOrNotify(s, w, r) @@ -53,7 +47,7 @@ func NewLeagueSubmit( nameUnique := false shortNameUnique := false var league *db.League - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name) if err != nil { @@ -66,14 +60,9 @@ func NewLeagueSubmit( if !nameUnique || !shortNameUnique { return true, nil } - league = &db.League{ - Name: name, - ShortName: shortname, - Description: description, - } - err = db.Insert(tx, league).WithAudit(r, audit.Callback()).Exec(ctx) + league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.Insert") + return false, errors.Wrap(err, "db.NewLeague") } return true, nil }); !ok { diff --git a/internal/handlers/login.go b/internal/handlers/login.go index 7944df0..5097651 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -7,9 +7,9 @@ import ( "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" - "github.com/uptrace/bun" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/store" @@ -19,7 +19,7 @@ import ( func Login( s *hws.Server, - conn *bun.DB, + conn *db.DB, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient, diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go index 93ae282..40f2cc9 100644 --- a/internal/handlers/logout.go +++ b/internal/handlers/logout.go @@ -16,7 +16,7 @@ import ( func Logout( s *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], - conn *bun.DB, + conn *db.DB, discordAPI *discord.APIClient, ) http.Handler { return http.HandlerFunc( @@ -27,7 +27,7 @@ func Logout( w.Header().Set("HX-Redirect", "/") return } - if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { token, err := user.DeleteDiscordTokens(ctx, tx) if err != nil { return false, errors.Wrap(err, "user.DeleteDiscordTokens") diff --git a/internal/handlers/pageopt_helpers.go b/internal/handlers/pageopt_helpers.go deleted file mode 100644 index f73c1be..0000000 --- a/internal/handlers/pageopt_helpers.go +++ /dev/null @@ -1,45 +0,0 @@ -package handlers - -import ( - "net/http" - - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/validation" - "github.com/uptrace/bun" -) - -// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata. -// It renders a Bad Request error page on fail -// PageOpts will be nil on fail -func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts { - getter, ok := validation.ParseFormOrError(s, w, r) - if !ok { - return nil - } - return getPageOpts(s, w, r, getter) -} - -// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail -// PageOpts will be nil on fail -func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts { - return getPageOpts(s, w, r, validation.NewQueryGetter(r)) -} - -func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts { - page := g.Int("page").Optional().Min(1).Value - perPage := g.Int("per_page").Optional().Min(1).Max(100).Value - order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value - orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value - valid := g.ValidateAndError(s, w, r) - if !valid { - return nil - } - pageOpts := &db.PageOpts{ - Page: page, - PerPage: perPage, - Order: bun.Order(order), - OrderBy: orderBy, - } - return pageOpts -} diff --git a/internal/handlers/register.go b/internal/handlers/register.go index 0ef308c..de76f67 100644 --- a/internal/handlers/register.go +++ b/internal/handlers/register.go @@ -20,7 +20,7 @@ import ( func Register( s *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], - conn *bun.DB, + conn *db.DB, cfg *config.Config, store *store.Store, ) http.Handler { @@ -55,7 +55,7 @@ func Register( username := r.FormValue("username") unique := false var user *db.User - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username) if err != nil { return false, errors.Wrap(err, "db.IsUsernameUnique") @@ -63,7 +63,7 @@ func Register( if !unique { return true, nil } - user, err = db.CreateUser(ctx, tx, username, details.DiscordUser) + user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAudit(r, nil)) if err != nil { return false, errors.Wrap(err, "db.CreateUser") } diff --git a/internal/handlers/season_detail.go b/internal/handlers/season_detail.go index 99bc049..6be206b 100644 --- a/internal/handlers/season_detail.go +++ b/internal/handlers/season_detail.go @@ -14,14 +14,14 @@ import ( func SeasonPage( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") var season *db.Season var leaguesWithTeams []db.LeagueWithTeams - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error season, err = db.GetSeason(ctx, tx, seasonStr) if err != nil { diff --git a/internal/handlers/season_edit.go b/internal/handlers/season_edit.go index 38c66ee..3a0a898 100644 --- a/internal/handlers/season_edit.go +++ b/internal/handlers/season_edit.go @@ -6,7 +6,6 @@ import ( "net/http" "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/throw" @@ -19,13 +18,13 @@ import ( func SeasonEditPage( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") var season *db.Season var allLeagues []*db.League - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error season, err = db.GetSeason(ctx, tx, seasonStr) if err != nil { @@ -49,8 +48,7 @@ func SeasonEditPage( func SeasonEditSubmit( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") @@ -77,7 +75,7 @@ func SeasonEditSubmit( } var season *db.Season - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error season, err = db.GetSeason(ctx, tx, seasonStr) if err != nil { @@ -86,12 +84,9 @@ func SeasonEditSubmit( if season == nil { return false, errors.New("season does not exist") } - season.Update(version, start, end, finalsStart, finalsEnd) - err = db.Update(tx, season).WherePK(). - Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date"). - WithAudit(r, audit.Callback()).Exec(ctx) + err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.Update") + return false, errors.Wrap(err, "season.Update") } return true, nil }); !ok { diff --git a/internal/handlers/season_league_add_team.go b/internal/handlers/season_league_add_team.go index 2f57082..c6146c5 100644 --- a/internal/handlers/season_league_add_team.go +++ b/internal/handlers/season_league_add_team.go @@ -6,7 +6,6 @@ import ( "net/http" "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/validation" @@ -16,8 +15,7 @@ import ( func SeasonLeagueAddTeam( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") @@ -36,73 +34,12 @@ func SeasonLeagueAddTeam( var league *db.League var team *db.Team - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error - - // Get season - season, err = db.GetSeason(ctx, tx, seasonStr) + team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.GetSeason") + return false, errors.Wrap(err, "db.NewTeamParticipation") } - if season == nil { - notify.Warn(s, w, r, "Not Found", "Season not found.", nil) - return false, nil - } - - // Get league - league, err = db.GetLeague(ctx, tx, leagueStr) - if err != nil { - return false, errors.Wrap(err, "db.GetLeague") - } - if league == nil { - notify.Warn(s, w, r, "Not Found", "League not found.", nil) - return false, nil - } - - if !season.HasLeague(league.ID) { - notify.Warn(s, w, r, "Invalid League", "This league is not associated with this season.", nil) - return false, nil - } - - // Get team - team, err = db.GetTeam(ctx, tx, teamID) - if err != nil { - return false, errors.Wrap(err, "db.GetTeam") - } - if team == nil { - notify.Warn(s, w, r, "Not Found", "Team not found.", nil) - return false, nil - } - - // Check if team is already in this season (in any league) - var tpCount int - tpCount, err = tx.NewSelect(). - Model((*db.TeamParticipation)(nil)). - Where("season_id = ? AND team_id = ?", season.ID, team.ID). - Count(ctx) - if err != nil { - return false, errors.Wrap(err, "tx.NewSelect") - } - if tpCount > 0 { - notify.Warn(s, w, r, "Already In Season", fmt.Sprintf( - "Team '%s' is already participating in this season.", - team.Name, - ), nil) - return false, nil - } - - // Add team to league - participation := &db.TeamParticipation{ - SeasonID: season.ID, - LeagueID: league.ID, - TeamID: team.ID, - } - - err = db.Insert(tx, participation).WithAudit(r, audit.Callback()).Exec(ctx) - if err != nil { - return false, errors.Wrap(err, "db.Insert") - } - return true, nil }); !ok { return diff --git a/internal/handlers/season_league_detail.go b/internal/handlers/season_league_detail.go index 0aa89cd..2725a34 100644 --- a/internal/handlers/season_league_detail.go +++ b/internal/handlers/season_league_detail.go @@ -14,7 +14,7 @@ import ( func SeasonLeaguePage( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") @@ -25,7 +25,7 @@ func SeasonLeaguePage( var teams []*db.Team var allTeams []*db.Team - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) if err != nil { diff --git a/internal/handlers/season_leagues.go b/internal/handlers/season_leagues.go index 0079e45..5411e3f 100644 --- a/internal/handlers/season_leagues.go +++ b/internal/handlers/season_leagues.go @@ -8,7 +8,6 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/view/seasonsview" @@ -16,8 +15,7 @@ import ( func SeasonAddLeague( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") @@ -25,32 +23,10 @@ func SeasonAddLeague( var season *db.Season var allLeagues []*db.League - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { - var err error - season, err = db.GetSeason(ctx, tx, seasonStr) + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { + err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.GetSeason") - } - if season == nil { - return false, errors.New("season not found") - } - - league, err := db.GetLeague(ctx, tx, leagueStr) - if err != nil { - return false, errors.Wrap(err, "db.GetLeague") - } - if league == nil { - return false, errors.New("league not found") - } - - // Create the many-to-many relationship - seasonLeague := &db.SeasonLeague{ - SeasonID: season.ID, - LeagueID: league.ID, - } - err = db.Insert(tx, seasonLeague).WithAudit(r, audit.Callback()).Exec(ctx) - if err != nil { - return false, errors.Wrap(err, "db.Insert") + return false, errors.Wrap(err, "db.NewSeasonLeague") } // Reload season with updated leagues @@ -76,8 +52,7 @@ func SeasonAddLeague( func SeasonRemoveLeague( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seasonStr := r.PathValue("season_short_name") @@ -85,7 +60,7 @@ func SeasonRemoveLeague( var season *db.Season var allLeagues []*db.League - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error season, err = db.GetSeason(ctx, tx, seasonStr) if err != nil { @@ -94,22 +69,9 @@ func SeasonRemoveLeague( if season == nil { return false, errors.New("season not found") } - - league, err := db.GetLeague(ctx, tx, leagueStr) + err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.GetLeague") - } - if league == nil { - return false, errors.New("league not found") - } - - // Delete the many-to-many relationship - err = db.DeleteItem[db.SeasonLeague](tx). - Where("season_id = ? AND league_id = ?", season.ID, league.ID). - WithAudit(r, audit.Callback()). - Delete(ctx) - if err != nil { - return false, errors.Wrap(err, "db.DeleteItem") + return false, errors.Wrap(err, "season.RemoveLeague") } // Reload season with updated leagues diff --git a/internal/handlers/seasons_list.go b/internal/handlers/seasons_list.go index fbf826d..f9b2757 100644 --- a/internal/handlers/seasons_list.go +++ b/internal/handlers/seasons_list.go @@ -11,18 +11,18 @@ import ( "github.com/uptrace/bun" ) -// SeasonsPage renders the full page with the seasons list, for use with GET requests +// SeasonsPage renders the season list. On GET it returns the full page, otherwise it just returns the list func SeasonsPage( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromQuery(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var seasons *db.List[db.Season] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error seasons, err = db.ListSeasons(ctx, tx, pageOpts) if err != nil { @@ -32,31 +32,10 @@ func SeasonsPage( }); !ok { return } - renderSafely(seasonsview.ListPage(seasons), s, r, w) - }) -} - -// SeasonsList renders just the seasons list, for use with POST requests and HTMX -func SeasonsList( - s *hws.Server, - conn *bun.DB, -) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { - return - } - var seasons *db.List[db.Season] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { - var err error - seasons, err = db.ListSeasons(ctx, tx, pageOpts) - if err != nil { - return false, errors.Wrap(err, "db.ListSeasons") - } - return true, nil - }); !ok { - return - } - renderSafely(seasonsview.SeasonsList(seasons), s, r, w) + if r.Method == "GET" { + renderSafely(seasonsview.ListPage(seasons), s, r, w) + } else { + renderSafely(seasonsview.SeasonsList(seasons), s, r, w) + } }) } diff --git a/internal/handlers/seasons_new.go b/internal/handlers/seasons_new.go index e534ef1..b2fd21b 100644 --- a/internal/handlers/seasons_new.go +++ b/internal/handlers/seasons_new.go @@ -6,7 +6,6 @@ import ( "net/http" "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/validation" @@ -18,20 +17,15 @@ import ( func NewSeason( s *hws.Server, - conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - renderSafely(seasonsview.NewPage(), s, r, w) - return - } + renderSafely(seasonsview.NewPage(), s, r, w) }) } func NewSeasonSubmit( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { getter, ok := validation.ParseFormOrNotify(s, w, r) @@ -58,7 +52,7 @@ func NewSeasonSubmit( nameUnique := false shortNameUnique := false var season *db.Season - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name) if err != nil { @@ -71,10 +65,9 @@ func NewSeasonSubmit( if !nameUnique || !shortNameUnique { return true, nil } - season = db.NewSeason(name, version, shortname, start) - err = db.Insert(tx, season).WithAudit(r, audit.Callback()).Exec(ctx) + season, err = db.NewSeason(ctx, tx, name, version, shortname, start, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.Insert") + return false, errors.Wrap(err, "db.NewSeason") } return true, nil }); !ok { diff --git a/internal/handlers/team_shortnames_unique.go b/internal/handlers/team_shortnames_unique.go index 56f71b2..7ee2b57 100644 --- a/internal/handlers/team_shortnames_unique.go +++ b/internal/handlers/team_shortnames_unique.go @@ -15,7 +15,7 @@ import ( // and also validates that they are different from each other func IsTeamShortNamesUnique( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { getter, err := validation.ParseForm(r) @@ -38,7 +38,7 @@ func IsTeamShortNamesUnique( } var isUnique bool - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { isUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName) if err != nil { return false, errors.Wrap(err, "db.TeamShortNamesUnique") diff --git a/internal/handlers/teams_list.go b/internal/handlers/teams_list.go index 9fee626..5eaec71 100644 --- a/internal/handlers/teams_list.go +++ b/internal/handlers/teams_list.go @@ -14,15 +14,15 @@ import ( // TeamsPage renders the full page with the teams list, for use with GET requests func TeamsPage( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromQuery(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var teams *db.List[db.Team] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error teams, err = db.ListTeams(ctx, tx, pageOpts) if err != nil { @@ -39,15 +39,15 @@ func TeamsPage( // TeamsList renders just the teams list, for use with POST requests and HTMX func TeamsList( s *hws.Server, - conn *bun.DB, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - pageOpts := pageOptsFromForm(s, w, r) - if pageOpts == nil { + pageOpts, ok := db.GetPageOpts(s, w, r) + if !ok { return } var teams *db.List[db.Team] - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error teams, err = db.ListTeams(ctx, tx, pageOpts) if err != nil { diff --git a/internal/handlers/teams_new.go b/internal/handlers/teams_new.go index 2580ca0..66c79a5 100644 --- a/internal/handlers/teams_new.go +++ b/internal/handlers/teams_new.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/validation" @@ -18,7 +17,6 @@ import ( func NewTeamPage( s *hws.Server, - conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { renderSafely(teamsview.NewPage(), s, r, w) @@ -27,8 +25,7 @@ func NewTeamPage( func NewTeamSubmit( s *hws.Server, - conn *bun.DB, - audit *auditlog.Logger, + conn *db.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { getter, ok := validation.ParseFormOrNotify(s, w, r) @@ -38,10 +35,10 @@ func NewTeamSubmit( name := getter.String("name"). TrimSpace().Required(). MaxLength(25).MinLength(3).Value - shortname := getter.String("short_name"). + shortName := getter.String("short_name"). TrimSpace().Required(). MaxLength(3).MinLength(3).Value - altShortname := getter.String("alt_short_name"). + altShortName := getter.String("alt_short_name"). TrimSpace().Required(). MaxLength(3).MinLength(3).Value color := getter.String("color"). @@ -51,22 +48,21 @@ func NewTeamSubmit( } // Check that short names are different - if shortname == altShortname { + if shortName == altShortName { notify.Warn(s, w, r, "Invalid Short Names", "Short name and alternative short name must be different.", nil) return } nameUnique := false shortNameComboUnique := false - var team *db.Team - if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error nameUnique, err = db.IsUnique(ctx, tx, (*db.Team)(nil), "name", name) if err != nil { return false, errors.Wrap(err, "db.IsTeamNameUnique") } - shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortname, altShortname) + shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName) if err != nil { return false, errors.Wrap(err, "db.TeamShortNamesUnique") } @@ -74,15 +70,9 @@ func NewTeamSubmit( if !nameUnique || !shortNameComboUnique { return true, nil } - team = &db.Team{ - Name: name, - ShortName: shortname, - AltShortName: altShortname, - Color: color, - } - err = db.Insert(tx, team).WithAudit(r, audit.Callback()).Exec(ctx) + _, err = db.NewTeam(ctx, tx, name, shortName, altShortName, color, db.NewAudit(r, nil)) if err != nil { - return false, errors.Wrap(err, "db.Insert") + return false, errors.Wrap(err, "db.NewTeam") } return true, nil }); !ok { diff --git a/internal/rbac/cache_middleware.go b/internal/rbac/cache_middleware.go index 39bb661..27a0d9a 100644 --- a/internal/rbac/cache_middleware.go +++ b/internal/rbac/cache_middleware.go @@ -39,12 +39,12 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware { var roles_ []*db.Role var perms []*db.Permission - if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error { + if err := c.conn.WithTxFailSilently(r.Context(), func(ctx context.Context, tx bun.Tx) error { var err error if previewRole != nil { // In preview mode: use the preview role instead of user's roles - role, err := db.GetRoleWithPermissions(ctx, tx, previewRole.ID) + role, err := db.GetRoleByID(ctx, tx, previewRole.ID) if err != nil { return errors.Wrap(err, "db.GetRoleWithPermissions") } diff --git a/internal/rbac/checker.go b/internal/rbac/checker.go index 3b81876..4729a8a 100644 --- a/internal/rbac/checker.go +++ b/internal/rbac/checker.go @@ -13,11 +13,11 @@ import ( ) type Checker struct { - conn *bun.DB + conn *db.DB s *hws.Server } -func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) { +func NewChecker(conn *db.DB, s *hws.Server) (*Checker, error) { if conn == nil { return nil, errors.New("conn cannot be nil") } @@ -56,7 +56,7 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi // Not in preview mode: fallback to database for actual user permissions var has bool - if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { + if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error { var err error has, err = user.HasPermission(ctx, tx, permission) if err != nil { @@ -94,7 +94,7 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol // Not in preview mode: fallback to database for actual user roles var has bool - if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { + if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error { var err error has, err = user.HasRole(ctx, tx, role) if err != nil { diff --git a/internal/rbac/preview_middleware.go b/internal/rbac/preview_middleware.go index 87b58a1..9944763 100644 --- a/internal/rbac/preview_middleware.go +++ b/internal/rbac/preview_middleware.go @@ -15,7 +15,7 @@ import ( // LoadPreviewRoleMiddleware loads the preview role from the session cookie if present // and adds it to the request context. This must run after authentication but before // the RBAC cache middleware. -func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) http.Handler { +func LoadPreviewRoleMiddleware(s *hws.Server, conn *db.DB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check if there's a preview role in the cookie @@ -26,10 +26,25 @@ func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) h return } + user := db.CurrentUser(r.Context()) + if user == nil { + // User not logged in, + ClearPreviewRoleCookie(w) + next.ServeHTTP(w, r) + return + } + // Load the preview role from the database var previewRole *db.Role - if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { - var err error + if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { + isAdmin, err := user.IsAdmin(ctx, tx) + if err != nil { + return false, errors.Wrap(err, "user.IsAdmin") + } + if !isAdmin { + ClearPreviewRoleCookie(w) + return true, nil + } previewRole, err = db.GetRoleByID(ctx, tx, roleID) if err != nil { return false, errors.Wrap(err, "db.GetRoleByID") diff --git a/internal/rbac/protection_middleware.go b/internal/rbac/protection_middleware.go index ab695c4..05a5511 100644 --- a/internal/rbac/protection_middleware.go +++ b/internal/rbac/protection_middleware.go @@ -90,7 +90,7 @@ func (c *Checker) RequireActualAdmin(s *hws.Server) func(http.Handler) http.Hand // Check user's ACTUAL role in database, bypassing preview mode var hasAdmin bool - if ok := db.WithReadTx(s, w, r, c.conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + if ok := c.conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error hasAdmin, err = user.HasRole(ctx, tx, roles.Admin) if err != nil { diff --git a/cmd/oslstats/auth.go b/internal/server/auth.go similarity index 95% rename from cmd/oslstats/auth.go rename to internal/server/auth.go index 781d40a..45587f5 100644 --- a/cmd/oslstats/auth.go +++ b/internal/server/auth.go @@ -1,4 +1,4 @@ -package main +package server import ( "context" @@ -15,7 +15,7 @@ import ( func setupAuth( cfg *hwsauth.Config, logger *hlog.Logger, - conn *bun.DB, + conn *db.DB, server *hws.Server, ignoredPaths []string, ) (*hwsauth.Authenticator[*db.User, bun.Tx], error) { @@ -30,7 +30,7 @@ func setupAuth( beginTx, logger, handlers.ErrorPage, - conn.DB, + conn.DB.DB, ) if err != nil { return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") diff --git a/cmd/oslstats/middleware.go b/internal/server/middleware.go similarity index 99% rename from cmd/oslstats/middleware.go rename to internal/server/middleware.go index c4b701b..21f3f28 100644 --- a/cmd/oslstats/middleware.go +++ b/internal/server/middleware.go @@ -1,4 +1,4 @@ -package main +package server import ( "context" @@ -27,7 +27,7 @@ func addMiddleware( perms *rbac.Checker, discordAPI *discord.APIClient, store *store.Store, - conn *bun.DB, + conn *db.DB, ) error { err := server.AddMiddleware( auth.Authenticate(tokenRefresh(auth, discordAPI, store)), diff --git a/cmd/oslstats/routes.go b/internal/server/routes.go similarity index 89% rename from cmd/oslstats/routes.go rename to internal/server/routes.go index 877f847..4dbf13c 100644 --- a/cmd/oslstats/routes.go +++ b/internal/server/routes.go @@ -1,4 +1,4 @@ -package main +package server import ( "net/http" @@ -8,7 +8,6 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" @@ -22,12 +21,11 @@ func addRoutes( s *hws.Server, staticFS *http.FileSystem, cfg *config.Config, - conn *bun.DB, + conn *db.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], store *store.Store, discordAPI *discord.APIClient, perms *rbac.Checker, - audit *auditlog.Logger, ) error { // Create the routes baseRoutes := []hws.Route{ @@ -69,23 +67,18 @@ func addRoutes( seasonRoutes := []hws.Route{ { Path: "/seasons", - Method: hws.MethodGET, + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Handler: handlers.SeasonsPage(s, conn), }, - { - Path: "/seasons", - Method: hws.MethodPOST, - Handler: handlers.SeasonsList(s, conn), - }, { Path: "/seasons/new", Method: hws.MethodGET, - Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s, conn)), + Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s)), }, { Path: "/seasons/new", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn)), }, { Path: "/seasons/{season_short_name}", @@ -100,7 +93,7 @@ func addRoutes( { Path: "/seasons/{season_short_name}/edit", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn)), }, { Path: "/seasons/{season_short_name}/leagues/{league_short_name}", @@ -110,17 +103,17 @@ func addRoutes( { Path: "/seasons/{season_short_name}/leagues/add/{league_short_name}", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn)), }, { Path: "/seasons/{season_short_name}/leagues/{league_short_name}", Method: hws.MethodDELETE, - Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn)), }, { Path: "/seasons/{season_short_name}/leagues/{league_short_name}/teams/add", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn)), }, } @@ -133,12 +126,12 @@ func addRoutes( { Path: "/leagues/new", Method: hws.MethodGET, - Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s, conn)), + Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s)), }, { Path: "/leagues/new", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn)), }, } @@ -156,12 +149,12 @@ func addRoutes( { Path: "/teams/new", Method: hws.MethodGET, - Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s, conn)), + Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s)), }, { Path: "/teams/new", Method: hws.MethodPOST, - Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn, audit)), + Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn)), }, } @@ -234,21 +227,11 @@ func addRoutes( Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Handler: perms.RequireAdmin(s)(handlers.AdminRoles(s, conn)), }, - { - Path: "/admin/permissions", - Method: hws.MethodGET, - Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsPage(s, conn)), - }, { Path: "/admin/audit", Method: hws.MethodGET, Handler: perms.RequireAdmin(s)(handlers.AdminAuditLogsPage(s, conn)), }, - { - Path: "/admin/permissions", - Method: hws.MethodPOST, - Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsList(s, conn)), - }, { Path: "/admin/audit", Method: hws.MethodPOST, @@ -263,7 +246,7 @@ func addRoutes( { Path: "/admin/roles/create", Method: hws.MethodPOST, - Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn, audit)), + Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn)), }, { Path: "/admin/roles/{id}/manage", @@ -273,7 +256,7 @@ func addRoutes( { Path: "/admin/roles/{id}", Method: hws.MethodDELETE, - Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn, audit)), + Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn)), }, { Path: "/admin/roles/{id}/delete-confirm", @@ -288,12 +271,12 @@ func addRoutes( { Path: "/admin/roles/{id}/permissions", Method: hws.MethodPOST, - Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn, audit)), + Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn)), }, { Path: "/admin/roles/{id}/preview-start", Method: hws.MethodPOST, - Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg)), + Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg.HWSAuth.SSL)), }, { Path: "/admin/roles/preview-stop", diff --git a/cmd/oslstats/httpserver.go b/internal/server/setup.go similarity index 80% rename from cmd/oslstats/httpserver.go rename to internal/server/setup.go index a6b7b44..1fe44dc 100644 --- a/cmd/oslstats/httpserver.go +++ b/internal/server/setup.go @@ -1,4 +1,5 @@ -package main +// Package server provides setup utilities for the HTTP server +package server import ( "io/fs" @@ -7,21 +8,20 @@ import ( "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" - "github.com/uptrace/bun" - "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/handlers" "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/store" ) -func setupHTTPServer( +func Setup( staticFS *fs.FS, cfg *config.Config, logger *hlog.Logger, - bun *bun.DB, + conn *db.DB, store *store.Store, discordAPI *discord.APIClient, ) (server *hws.Server, err error) { @@ -41,7 +41,7 @@ func setupHTTPServer( } auth, err := setupAuth( - cfg.HWSAuth, logger, bun, httpServer, ignoredPaths) + cfg.HWSAuth, logger, conn, httpServer, ignoredPaths) if err != nil { return nil, errors.Wrap(err, "setupAuth") } @@ -62,20 +62,17 @@ func setupHTTPServer( } // Initialize permissions checker - perms, err := rbac.NewChecker(bun, httpServer) + perms, err := rbac.NewChecker(conn, httpServer) if err != nil { return nil, errors.Wrap(err, "rbac.NewChecker") } - // Initialize audit logger - audit := auditlog.NewLogger(bun) - - err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI, perms, audit) + err = addRoutes(httpServer, &fs, cfg, conn, auth, store, discordAPI, perms) if err != nil { return nil, errors.Wrap(err, "addRoutes") } - err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, bun) + err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, conn) if err != nil { return nil, errors.Wrap(err, "addMiddleware") } diff --git a/internal/view/baseview/navbar.templ b/internal/view/baseview/navbar.templ index a0cca8a..1ac4c42 100644 --- a/internal/view/baseview/navbar.templ +++ b/internal/view/baseview/navbar.templ @@ -136,38 +136,38 @@ templ profileDropdown(user *db.User, items []ProfileItem) { x-on:click.away="isActive = false" x-on:keydown.escape.window="isActive = false" > - - if previewRole != nil { -
-

- Viewing as: { previewRole.DisplayName } -

-
-
- -
-
- +
+
+ -
+ role="menuitem" + @click="isActive=false" + > + Return to Admin + + +
- - } + }
for _, item := range items {