big ole refactor

This commit is contained in:
2026-02-14 19:48:59 +11:00
parent 0fc3bb0c94
commit 4a2396bca8
66 changed files with 989 additions and 1114 deletions

View File

@@ -1,48 +0,0 @@
package main
import (
"database/sql"
"fmt"
"time"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
sqldb.SetMaxOpenConns(25)
sqldb.SetMaxIdleConns(10)
sqldb.SetConnMaxLifetime(5 * time.Minute)
sqldb.SetConnMaxIdleTime(5 * time.Minute)
conn = bun.NewDB(sqldb, pgdialect.New())
registerDBModels(conn)
close = sqldb.Close
return conn, close
}
func registerDBModels(conn *bun.DB) []any {
models := []any{
(*db.RolePermission)(nil),
(*db.UserRole)(nil),
(*db.SeasonLeague)(nil),
(*db.TeamParticipation)(nil),
(*db.User)(nil),
(*db.DiscordToken)(nil),
(*db.Season)(nil),
(*db.League)(nil),
(*db.Team)(nil),
(*db.Role)(nil),
(*db.Permission)(nil),
(*db.AuditLog)(nil),
}
conn.RegisterModel(models...)
return models
}

View File

@@ -7,6 +7,7 @@ import (
"git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db/migrate"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -48,7 +49,7 @@ func main() {
// Handle migration file creation (doesn't need DB connection) // Handle migration file creation (doesn't need DB connection)
if flags.MigrateCreate != "" { if flags.MigrateCreate != "" {
if err := createMigration(flags.MigrateCreate); err != nil { if err := migrate.CreateMigration(flags.MigrateCreate); err != nil {
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration") logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration")
} }
return return
@@ -59,17 +60,21 @@ func main() {
flags.MigrateStatus || flags.MigrateDryRun || flags.MigrateStatus || flags.MigrateDryRun ||
flags.ResetDB { flags.ResetDB {
var command, countStr string
// Route to appropriate command // Route to appropriate command
if flags.MigrateUp != "" { if flags.MigrateUp != "" {
err = runMigrations(ctx, cfg, "up", flags.MigrateUp) command = "up"
countStr = flags.MigrateUp
} else if flags.MigrateRollback != "" { } else if flags.MigrateRollback != "" {
err = runMigrations(ctx, cfg, "rollback", flags.MigrateRollback) command = "rollback"
countStr = flags.MigrateRollback
} else if flags.MigrateStatus { } else if flags.MigrateStatus {
err = runMigrations(ctx, cfg, "status", "") command = "status"
} else if flags.MigrateDryRun { }
err = runMigrations(ctx, cfg, "dry-run", "") if flags.ResetDB {
} else if flags.ResetDB { err = migrate.ResetDatabase(ctx, cfg)
err = resetDatabase(ctx, cfg) } else {
err = migrate.RunMigrations(ctx, cfg, command, countStr)
} }
if err != nil { if err != nil {

View File

@@ -12,8 +12,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/embedfs" "git.haelnorr.com/h/oslstats/internal/embedfs"
"git.haelnorr.com/h/oslstats/internal/server"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
) )
@@ -25,8 +27,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
// Setup the database connection // Setup the database connection
logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database") logger.Debug().Msg("Connecting to database")
bun, closedb := setupBun(cfg) conn := db.NewDB(cfg.DB)
// registerDBModels(bun)
// Setup embedded files // Setup embedded files
logger.Debug().Msg("Getting embedded files") logger.Debug().Msg("Getting embedded files")
@@ -47,7 +48,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
} }
logger.Debug().Msg("Setting up HTTP server") logger.Debug().Msg("Setting up HTTP server")
httpServer, err := setupHTTPServer(&staticFS, cfg, logger, bun, store, discordAPI) httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI)
if err != nil { if err != nil {
return errors.Wrap(err, "setupHttpServer") return errors.Wrap(err, "setupHttpServer")
} }
@@ -71,7 +72,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
if err != nil { if err != nil {
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown") logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown")
} }
err = closedb() err = conn.Close()
if err != nil { if err != nil {
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close") logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close")
} }

View File

@@ -1,187 +0,0 @@
// Package auditlog provides a system for logging events that require permissions to the audit log
package auditlog
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Logger struct {
conn *bun.DB
}
func NewLogger(conn *bun.DB) *Logger {
return &Logger{conn: conn}
}
// LogSuccess logs a successful permission-protected action
func (l *Logger) LogSuccess(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any, // Can be int, string, or nil
details map[string]any,
r *http.Request,
) error {
return l.log(ctx, tx, user, action, resourceType, resourceID, details, "success", nil, r)
}
// LogError logs a failed action due to an error
func (l *Logger) LogError(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any,
err error,
r *http.Request,
) error {
errMsg := err.Error()
return l.log(ctx, tx, user, action, resourceType, resourceID, nil, "error", &errMsg, r)
}
func (l *Logger) log(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any,
details map[string]any,
result string,
errorMessage *string,
r *http.Request,
) error {
if user == nil {
return errors.New("user cannot be nil for audit logging")
}
// Convert resourceID to string
var resourceIDStr *string
if resourceID != nil {
idStr := fmt.Sprintf("%v", resourceID)
resourceIDStr = &idStr
}
// Marshal details to JSON
var detailsJSON json.RawMessage
if details != nil {
jsonBytes, err := json.Marshal(details)
if err != nil {
return errors.Wrap(err, "json.Marshal details")
}
detailsJSON = jsonBytes
}
// Extract IP and User-Agent from request
ipAddress := r.RemoteAddr
userAgent := r.UserAgent()
log := &db.AuditLog{
UserID: user.ID,
Action: action,
ResourceType: resourceType,
ResourceID: resourceIDStr,
Details: detailsJSON,
IPAddress: ipAddress,
UserAgent: userAgent,
Result: result,
ErrorMessage: errorMessage,
CreatedAt: time.Now().Unix(),
}
return db.CreateAuditLog(ctx, tx, log)
}
// GetRecentLogs retrieves recent audit logs with pagination
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
var logs *db.List[db.AuditLog]
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
if err != nil {
return errors.Wrap(err, "db.GetAuditLogs")
}
return nil
}); err != nil {
return nil, errors.Wrap(err, "db.WithTxFailSilently")
}
return logs, nil
}
// GetLogsByUser retrieves audit logs for a specific user
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
var logs *db.List[db.AuditLog]
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
if err != nil {
return errors.Wrap(err, "db.GetAuditLogsByUser")
}
return nil
}); err != nil {
return nil, errors.Wrap(err, "db.WithTxFailSilently")
}
return logs, nil
}
// CleanupOldLogs deletes audit logs older than the specified number of days
func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error) {
if daysToKeep <= 0 {
return 0, errors.New("daysToKeep must be positive")
}
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
var count int
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
if err != nil {
return errors.Wrap(err, "db.CleanupOldAuditLogs")
}
return nil
}); err != nil {
return 0, errors.Wrap(err, "db.WithTxFailSilently")
}
return count, nil
}
// Callback returns a db.AuditCallback that logs to this Logger
// This is used with the generic database helpers (Insert, Update, Delete)
//
// Usage:
//
// audit := auditlog.NewLogger(conn)
// err := db.Insert(tx, season).
// WithAudit(r, audit.Callback()).
// Exec(ctx)
func (l *Logger) Callback() db.AuditCallback {
return func(ctx context.Context, tx bun.Tx, info *db.AuditInfo, r *http.Request) error {
user := db.CurrentUser(ctx)
if user == nil {
return errors.New("no user in context for audit logging")
}
return l.LogSuccess(
ctx,
tx,
user,
info.Action,
info.ResourceType,
info.ResourceID,
info.Details,
r,
)
}
}

View File

@@ -1,14 +1,23 @@
package db package db
import ( import (
"context"
"net/http" "net/http"
"reflect" "reflect"
"strings" "strings"
"github.com/uptrace/bun"
) )
type AuditMeta struct {
r *http.Request
u *User
}
func NewAudit(r *http.Request, u *User) *AuditMeta {
if u == nil {
u = CurrentUser(r.Context())
}
return &AuditMeta{r, u}
}
// AuditInfo contains metadata for audit logging // AuditInfo contains metadata for audit logging
type AuditInfo struct { type AuditInfo struct {
Action string // e.g., "seasons.create", "users.update" Action string // e.g., "seasons.create", "users.update"
@@ -17,9 +26,6 @@ type AuditInfo struct {
Details map[string]any // Changed fields or additional metadata Details map[string]any // Changed fields or additional metadata
} }
// AuditCallback is called after successful database operations to log changes
type AuditCallback func(ctx context.Context, tx bun.Tx, info *AuditInfo, r *http.Request) error
// extractTableName gets the bun table name from a model type using reflection // extractTableName gets the bun table name from a model type using reflection
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons" // Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
func extractTableName[T any]() string { func extractTableName[T any]() string {
@@ -27,7 +33,7 @@ func extractTableName[T any]() string {
t := reflect.TypeOf(model) t := reflect.TypeOf(model)
// Handle pointer types // Handle pointer types
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Pointer {
t = t.Elem() t = t.Elem()
} }
@@ -38,11 +44,9 @@ func extractTableName[T any]() string {
bunTag := field.Tag.Get("bun") bunTag := field.Tag.Get("bun")
if bunTag != "" { if bunTag != "" {
// Parse tag: "table:seasons,alias:s" -> "seasons" // Parse tag: "table:seasons,alias:s" -> "seasons"
parts := strings.Split(bunTag, ",") for part := range strings.SplitSeq(bunTag, ",") {
for _, part := range parts { part, _ := strings.CutPrefix(part, "table:")
if strings.HasPrefix(part, "table:") { return part
return strings.TrimPrefix(part, "table:")
}
} }
} }
} }
@@ -81,7 +85,7 @@ func extractPrimaryKey[T any](model *T) any {
} }
v := reflect.ValueOf(model) v := reflect.ValueOf(model)
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Pointer {
v = v.Elem() v = v.Elem()
} }
@@ -110,7 +114,7 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any {
result := make(map[string]any) result := make(map[string]any)
v := reflect.ValueOf(model) v := reflect.ValueOf(model)
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Pointer {
v = v.Elem() v = v.Elem()
} }
@@ -142,5 +146,3 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any {
return result return result
} }
// Note: We don't need getTxFromQuery since we store the tx directly in our helper structs

View File

@@ -0,0 +1,91 @@
package db
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// LogSuccess logs a successful permission-protected action
func LogSuccess(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
) error {
return log(ctx, tx, meta, info, "success", nil)
}
// LogError logs a failed action due to an error
func LogError(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
err error,
) error {
errMsg := err.Error()
return log(ctx, tx, meta, info, "error", &errMsg)
}
func log(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
result string,
errorMessage *string,
) error {
if meta == nil {
return errors.New("audit meta cannot be nil for audit logging")
}
if info == nil {
return errors.New("audit info cannot be nil for audit logging")
}
if meta.u == nil {
return errors.New("user cannot be nil for audit logging")
}
if meta.r == nil {
return errors.New("request cannot be nil for audit logging")
}
// Convert resourceID to string
var resourceIDStr *string
if info.ResourceID != nil {
idStr := fmt.Sprintf("%v", info.ResourceID)
resourceIDStr = &idStr
}
// Marshal details to JSON
var detailsJSON json.RawMessage
if info.Details != nil {
jsonBytes, err := json.Marshal(info.Details)
if err != nil {
return errors.Wrap(err, "json.Marshal details")
}
detailsJSON = jsonBytes
}
// Extract IP and User-Agent from request
ipAddress := meta.r.RemoteAddr
userAgent := meta.r.UserAgent()
log := &AuditLog{
UserID: meta.u.ID,
Action: info.Action,
ResourceType: info.ResourceType,
ResourceID: resourceIDStr,
Details: detailsJSON,
IPAddress: ipAddress,
UserAgent: userAgent,
Result: result,
ErrorMessage: errorMessage,
CreatedAt: time.Now().Unix(),
}
return CreateAuditLog(ctx, tx, log)
}

View File

@@ -1,4 +1,4 @@
package backup package db
import ( import (
"context" "context"
@@ -9,14 +9,13 @@ import (
"sort" "sort"
"time" "time"
"git.haelnorr.com/h/oslstats/internal/config"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// CreateBackup creates a compressed PostgreSQL dump before migrations // CreateBackup creates a compressed PostgreSQL dump before migrations
// Returns backup filename and error // Returns backup filename and error
// If pg_dump is not available, returns nil error with warning // If pg_dump is not available, returns nil error with warning
func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (string, error) { func CreateBackup(ctx context.Context, cfg *Config, operation string) (string, error) {
// Check if pg_dump is available // Check if pg_dump is available
if _, err := exec.LookPath("pg_dump"); err != nil { if _, err := exec.LookPath("pg_dump"); err != nil {
fmt.Println("[WARN] pg_dump not found - skipping backup") fmt.Println("[WARN] pg_dump not found - skipping backup")
@@ -28,13 +27,13 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
} }
// Ensure backup directory exists // Ensure backup directory exists
if err := os.MkdirAll(cfg.DB.BackupDir, 0755); err != nil { if err := os.MkdirAll(cfg.BackupDir, 0o755); err != nil {
return "", errors.Wrap(err, "failed to create backup directory") return "", errors.Wrap(err, "failed to create backup directory")
} }
// Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz // Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz
timestamp := time.Now().Format("20060102_150405") timestamp := time.Now().Format("20060102_150405")
filename := filepath.Join(cfg.DB.BackupDir, filename := filepath.Join(cfg.BackupDir,
fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation)) fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation))
// Check if gzip is available // Check if gzip is available
@@ -42,7 +41,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
if _, err := exec.LookPath("gzip"); err != nil { if _, err := exec.LookPath("gzip"); err != nil {
fmt.Println("[WARN] gzip not found - using uncompressed backup") fmt.Println("[WARN] gzip not found - using uncompressed backup")
useGzip = false useGzip = false
filename = filepath.Join(cfg.DB.BackupDir, filename = filepath.Join(cfg.BackupDir,
fmt.Sprintf("%s_pre_%s.sql", timestamp, operation)) fmt.Sprintf("%s_pre_%s.sql", timestamp, operation))
} }
@@ -52,19 +51,19 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
// Use shell to pipe pg_dump through gzip // Use shell to pipe pg_dump through gzip
pgDumpCmd := fmt.Sprintf( pgDumpCmd := fmt.Sprintf(
"pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s", "pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s",
cfg.DB.Host, cfg.Host,
cfg.DB.Port, cfg.Port,
cfg.DB.User, cfg.User,
cfg.DB.DB, cfg.DB,
filename, filename,
) )
cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd) cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd)
} else { } else {
cmd = exec.CommandContext(ctx, "pg_dump", cmd = exec.CommandContext(ctx, "pg_dump",
"-h", cfg.DB.Host, "-h", cfg.Host,
"-p", fmt.Sprint(cfg.DB.Port), "-p", fmt.Sprint(cfg.Port),
"-U", cfg.DB.User, "-U", cfg.User,
"-d", cfg.DB.DB, "-d", cfg.DB,
"-f", filename, "-f", filename,
"--no-owner", "--no-owner",
"--no-acl", "--no-acl",
@@ -75,7 +74,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
// Set password via environment variable // Set password via environment variable
cmd.Env = append(os.Environ(), cmd.Env = append(os.Environ(),
fmt.Sprintf("PGPASSWORD=%s", cfg.DB.Password)) fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
// Run backup // Run backup
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
@@ -95,14 +94,14 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
} }
// CleanOldBackups keeps only the N most recent backups // CleanOldBackups keeps only the N most recent backups
func CleanOldBackups(cfg *config.Config, keepCount int) error { func CleanOldBackups(cfg *Config, keepCount int) error {
// Get all backup files (both .sql and .sql.gz) // Get all backup files (both .sql and .sql.gz)
sqlFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql")) sqlFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql"))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to list .sql backups") return errors.Wrap(err, "failed to list .sql backups")
} }
gzFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql.gz")) gzFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql.gz"))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to list .sql.gz backups") return errors.Wrap(err, "failed to list .sql.gz backups")
} }

View File

@@ -3,18 +3,17 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type deleter[T any] struct { type deleter[T any] struct {
tx bun.Tx tx bun.Tx
q *bun.DeleteQuery q *bun.DeleteQuery
resourceID any // Store ID before deletion for audit resourceID any // Store ID before deletion for audit
auditCallback AuditCallback audit *AuditMeta
auditRequest *http.Request auditInfo *AuditInfo
} }
type systemType interface { type systemType interface {
@@ -39,11 +38,10 @@ func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
} }
// WithAudit enables audit logging for this delete operation // WithAudit enables audit logging for this delete operation
// The callback will be invoked after successful deletion with auto-generated audit info // If the provided *AuditInfo is nil, will use reflection to automatically work out the details
// If the callback returns an error, the transaction will be rolled back func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
func (d *deleter[T]) WithAudit(r *http.Request, callback AuditCallback) *deleter[T] { d.audit = meta
d.auditRequest = r d.auditInfo = info
d.auditCallback = callback
return d return d
} }
@@ -57,21 +55,23 @@ func (d *deleter[T]) Delete(ctx context.Context) error {
} }
// Handle audit logging if enabled // Handle audit logging if enabled
if d.auditCallback != nil && d.auditRequest != nil { if d.audit != nil {
tableName := extractTableName[T]() if d.auditInfo == nil {
resourceType := extractResourceType(tableName) tableName := extractTableName[T]()
action := buildAction(resourceType, "delete") resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "delete")
info := &AuditInfo{ d.auditInfo = &AuditInfo{
Action: action, Action: action,
ResourceType: resourceType, ResourceType: resourceType,
ResourceID: d.resourceID, ResourceID: d.resourceID,
Details: nil, // Delete doesn't need details Details: nil, // Delete doesn't need details
}
} }
// Call audit callback - if it fails, return error to trigger rollback err = LogSuccess(ctx, d.tx, d.audit, d.auditInfo)
if err := d.auditCallback(ctx, d.tx, info, d.auditRequest); err != nil { if err != nil {
return errors.Wrap(err, "audit.callback") return errors.Wrap(err, "LogSuccess")
} }
} }
@@ -82,7 +82,7 @@ func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
return DeleteItem[T](tx).Where("id = ?", id) return DeleteItem[T](tx).Where("id = ?", id)
} }
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error { func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
deleter := DeleteByID[T](tx, id) deleter := DeleteByID[T](tx, id)
item, err := GetByID[T](tx, id).Get(ctx) item, err := GetByID[T](tx, id).Get(ctx)
if err != nil { if err != nil {
@@ -94,5 +94,8 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int)
if (*item).isSystem() { if (*item).isSystem() {
return errors.New("record is system protected") return errors.New("record is system protected")
} }
if audit != nil {
deleter = deleter.WithAudit(audit, nil)
}
return deleter.Delete(ctx) return deleter.Delete(ctx)
} }

View File

@@ -3,7 +3,6 @@ package db
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -11,13 +10,13 @@ import (
) )
type inserter[T any] struct { type inserter[T any] struct {
tx bun.Tx tx bun.Tx
q *bun.InsertQuery q *bun.InsertQuery
model *T model *T
models []*T models []*T
isBulk bool isBulk bool
auditCallback AuditCallback audit *AuditMeta
auditRequest *http.Request auditInfo *AuditInfo
} }
// Insert creates an inserter for a single model // Insert creates an inserter for a single model
@@ -76,11 +75,10 @@ func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
} }
// WithAudit enables audit logging for this insert operation // WithAudit enables audit logging for this insert operation
// The callback will be invoked after successful insert with auto-generated audit info // If the provided *AuditInfo is nil, will use reflection to automatically work out the details
// If the callback returns an error, the transaction will be rolled back func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] {
func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] { i.audit = meta
i.auditRequest = r i.auditInfo = info
i.auditCallback = callback
return i return i
} }
@@ -94,35 +92,29 @@ func (i *inserter[T]) Exec(ctx context.Context) error {
} }
// Handle audit logging if enabled // Handle audit logging if enabled
if i.auditCallback != nil && i.auditRequest != nil { if i.audit != nil {
tableName := extractTableName[T]() if i.auditInfo == nil {
resourceType := extractResourceType(tableName) tableName := extractTableName[T]()
action := buildAction(resourceType, "create") resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "create")
var info *AuditInfo i.auditInfo = &AuditInfo{
if i.isBulk {
// For bulk inserts, log once with count in details
info = &AuditInfo{
Action: action, Action: action,
ResourceType: resourceType, ResourceType: resourceType,
ResourceID: nil, ResourceID: nil,
Details: map[string]any{
"count": len(i.models),
},
}
} else {
// For single insert, log with resource ID
info = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: extractPrimaryKey(i.model),
Details: nil, Details: nil,
} }
if i.isBulk {
i.auditInfo.Details = map[string]any{
"count": len(i.models),
}
} else {
i.auditInfo.ResourceID = extractPrimaryKey(i.model)
}
} }
// Call audit callback - if it fails, return error to trigger rollback err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo)
if err := i.auditCallback(ctx, i.tx, info, i.auditRequest); err != nil { if err != nil {
return errors.Wrap(err, "audit.callback") return errors.Wrap(err, "LogSuccess")
} }
} }

View File

@@ -19,13 +19,6 @@ type League struct {
Teams []Team `bun:"m2m:team_participations,join:League=Team"` Teams []Team `bun:"m2m:team_participations,join:League=Team"`
} }
type SeasonLeague struct {
SeasonID int `bun:",pk"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
LeagueID int `bun:",pk"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
}
func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) { func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) {
return GetList[League](tx).Relation("Seasons").GetAll(ctx) return GetList[League](tx).Relation("Seasons").GetAll(ctx)
} }
@@ -37,41 +30,16 @@ func GetLeague(ctx context.Context, tx bun.Tx, shortname string) (*League, error
return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx) return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx)
} }
// GetSeasonLeague retrieves a specific season-league combination with teams func NewLeague(ctx context.Context, tx bun.Tx, name, shortname, description string, audit *AuditMeta) (*League, error) {
func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) { league := &League{
if seasonShortName == "" { Name: name,
return nil, nil, nil, errors.New("season short_name cannot be empty") ShortName: shortname,
Description: description,
} }
if leagueShortName == "" { err := Insert(tx, league).
return nil, nil, nil, errors.New("league short_name cannot be empty") WithAudit(audit, nil).Exec(ctx)
}
// Get the season
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil { if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason") return nil, errors.Wrap(err, "db.Insert")
} }
return league, nil
// Get the league
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague")
}
if season == nil || league == nil || !season.HasLeague(league.ID) {
return nil, nil, nil, nil
}
// Get all teams participating in this season+league
var teams []*Team
err = tx.NewSelect().
Model(&teams).
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID).
Order("t.name ASC").
Scan(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "tx.Select teams")
}
return season, league, teams, nil
} }

View File

@@ -1,4 +1,5 @@
package main // Package migrate provides functions for managing database migrations
package migrate
import ( import (
"bufio" "bufio"
@@ -11,20 +12,19 @@ import (
"text/tabwriter" "text/tabwriter"
"time" "time"
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
"git.haelnorr.com/h/oslstats/internal/backup"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/db/migrations"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate"
) )
// runMigrations executes database migrations // RunMigrations executes database migrations
func runMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error { func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error {
conn, close := setupBun(cfg) conn := db.NewDB(cfg.DB)
defer func() { _ = close() }() defer func() { _ = conn.Close() }()
migrator := migrate.NewMigrator(conn, migrations.Migrations) migrator := migrate.NewMigrator(conn.DB, migrations.Migrations)
// Initialize migration tables // Initialize migration tables
if err := migrator.Init(ctx); err != nil { if err := migrator.Init(ctx); err != nil {
@@ -47,15 +47,13 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string, coun
return migrateRollback(ctx, migrator, conn, cfg, countStr) return migrateRollback(ctx, migrator, conn, cfg, countStr)
case "status": case "status":
return migrateStatus(ctx, migrator) return migrateStatus(ctx, migrator)
case "dry-run":
return migrateDryRun(ctx, migrator)
default: default:
return fmt.Errorf("unknown migration command: %s", command) return fmt.Errorf("unknown migration command: %s", command)
} }
} }
// migrateUp runs pending migrations // migrateUp runs pending migrations
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error { func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
// Parse count parameter // Parse count parameter
count, all, err := parseMigrationCount(countStr) count, all, err := parseMigrationCount(countStr)
if err != nil { if err != nil {
@@ -101,13 +99,13 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf
// Create backup unless --no-backup flag is set // Create backup unless --no-backup flag is set
if !cfg.Flags.MigrateNoBackup { if !cfg.Flags.MigrateNoBackup {
fmt.Println("[INFO] Step 3/5: Creating backup...") fmt.Println("[INFO] Step 3/5: Creating backup...")
_, err := backup.CreateBackup(ctx, cfg, "migration") _, err := db.CreateBackup(ctx, cfg.DB, "migration")
if err != nil { if err != nil {
return errors.Wrap(err, "create backup") return errors.Wrap(err, "create backup")
} }
// Clean old backups // Clean old backups
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil { if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
} }
} else { } else {
@@ -143,7 +141,7 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf
} }
// migrateRollback rolls back migrations // migrateRollback rolls back migrations
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error { func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
// Parse count parameter // Parse count parameter
count, all, err := parseMigrationCount(countStr) count, all, err := parseMigrationCount(countStr)
if err != nil { if err != nil {
@@ -182,13 +180,13 @@ func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.
// Create backup unless --no-backup flag is set // Create backup unless --no-backup flag is set
if !cfg.Flags.MigrateNoBackup { if !cfg.Flags.MigrateNoBackup {
fmt.Println("[INFO] Creating backup before rollback...") fmt.Println("[INFO] Creating backup before rollback...")
_, err := backup.CreateBackup(ctx, cfg, "rollback") _, err := db.CreateBackup(ctx, cfg.DB, "rollback")
if err != nil { if err != nil {
return errors.Wrap(err, "create backup") return errors.Wrap(err, "create backup")
} }
// Clean old backups // Clean old backups
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil { if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
} }
} else { } else {
@@ -255,27 +253,6 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
return nil return nil
} }
// migrateDryRun shows what migrations would run without applying them
func migrateDryRun(ctx context.Context, migrator *migrate.Migrator) error {
group, err := migrator.Migrate(ctx, migrate.WithNopMigration())
if err != nil {
return errors.Wrap(err, "dry-run")
}
if group.IsZero() {
fmt.Println("[INFO] No pending migrations")
return nil
}
fmt.Println("[INFO] Pending migrations (dry-run):")
for _, migration := range group.Migrations {
fmt.Printf(" 📋 %s\n", migration.Name)
}
fmt.Printf("[INFO] Would migrate to group %d\n", group.ID)
return nil
}
// validateMigrations ensures migrations compile before running // validateMigrations ensures migrations compile before running
func validateMigrations(ctx context.Context) error { func validateMigrations(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "go", "build", cmd := exec.CommandContext(ctx, "go", "build",
@@ -292,7 +269,7 @@ func validateMigrations(ctx context.Context) error {
} }
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock // acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
func acquireMigrationLock(ctx context.Context, conn *bun.DB) error { func acquireMigrationLock(ctx context.Context, conn *db.DB) error {
const lockID = 1234567890 // Arbitrary unique ID for migration lock const lockID = 1234567890 // Arbitrary unique ID for migration lock
const timeoutSeconds = 300 // 5 minutes const timeoutSeconds = 300 // 5 minutes
@@ -318,7 +295,7 @@ func acquireMigrationLock(ctx context.Context, conn *bun.DB) error {
} }
// releaseMigrationLock releases the migration lock // releaseMigrationLock releases the migration lock
func releaseMigrationLock(ctx context.Context, conn *bun.DB) { func releaseMigrationLock(ctx context.Context, conn *db.DB) {
const lockID = 1234567890 const lockID = 1234567890
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx) _, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
@@ -329,8 +306,8 @@ func releaseMigrationLock(ctx context.Context, conn *bun.DB) {
} }
} }
// createMigration generates a new migration file // CreateMigration generates a new migration file
func createMigration(name string) error { func CreateMigration(name string) error {
if name == "" { if name == "" {
return errors.New("migration name cannot be empty") return errors.New("migration name cannot be empty")
} }
@@ -340,7 +317,7 @@ func createMigration(name string) error {
// Generate timestamp // Generate timestamp
timestamp := time.Now().Format("20060102150405") timestamp := time.Now().Format("20060102150405")
filename := fmt.Sprintf("cmd/oslstats/migrations/%s_%s.go", timestamp, name) filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name)
// Template // Template
template := `package migrations template := `package migrations
@@ -502,8 +479,8 @@ func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migr
return rolledBack, nil return rolledBack, nil
} }
// resetDatabase drops and recreates all tables (destructive) // ResetDatabase drops and recreates all tables (destructive)
func resetDatabase(ctx context.Context, cfg *config.Config) error { func ResetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!") fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
fmt.Print("Type 'yes' to continue: ") fmt.Print("Type 'yes' to continue: ")
@@ -518,10 +495,10 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("❌ Reset cancelled") fmt.Println("❌ Reset cancelled")
return nil return nil
} }
conn, close := setupBun(cfg) conn := db.NewDB(cfg.DB)
defer func() { _ = close() }() defer func() { _ = conn.Close() }()
models := registerDBModels(conn) models := conn.RegisterModels()
for _, model := range models { for _, model := range models {
if err := conn.ResetModel(ctx, model); err != nil { if err := conn.ResetModel(ctx, model); err != nil {

View File

@@ -10,9 +10,9 @@ import (
func init() { func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP: Create initial tables (users, discord_tokens) // UP: Create initial tables (users, discord_tokens)
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Create users table // Create users table
_, err := dbConn.NewCreateTable(). _, err := conn.NewCreateTable().
Model((*db.User)(nil)). Model((*db.User)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -20,15 +20,15 @@ func init() {
} }
// Create discord_tokens table // Create discord_tokens table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.DiscordToken)(nil)). Model((*db.DiscordToken)(nil)).
Exec(ctx) Exec(ctx)
return err return err
}, },
// DOWN: Drop tables in reverse order // DOWN: Drop tables in reverse order
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Drop discord_tokens first (has foreign key to users) // Drop discord_tokens first (has foreign key to users)
_, err := dbConn.NewDropTable(). _, err := conn.NewDropTable().
Model((*db.DiscordToken)(nil)). Model((*db.DiscordToken)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)
@@ -37,7 +37,7 @@ func init() {
} }
// Drop users table // Drop users table
_, err = dbConn.NewDropTable(). _, err = conn.NewDropTable().
Model((*db.User)(nil)). Model((*db.User)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)

View File

@@ -10,8 +10,8 @@ import (
func init() { func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
_, err := dbConn.NewCreateTable(). _, err := conn.NewCreateTable().
Model((*db.Season)(nil)). Model((*db.Season)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -20,8 +20,8 @@ func init() {
return nil return nil
}, },
// DOWN migration // DOWN migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
_, err := dbConn.NewDropTable(). _, err := conn.NewDropTable().
Model((*db.Season)(nil)). Model((*db.Season)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)

View File

@@ -12,10 +12,10 @@ import (
func init() { func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil)) conn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
// Create permissions table // Create permissions table
_, err := dbConn.NewCreateTable(). _, err := conn.NewCreateTable().
Model((*db.Role)(nil)). Model((*db.Role)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -23,7 +23,7 @@ func init() {
} }
// Create permissions table // Create permissions table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.Permission)(nil)). Model((*db.Permission)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -31,7 +31,7 @@ func init() {
} }
// Create indexes for permissions // Create indexes for permissions
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.Permission)(nil)). Model((*db.Permission)(nil)).
Index("idx_permissions_resource"). Index("idx_permissions_resource").
Column("resource"). Column("resource").
@@ -40,7 +40,7 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.Permission)(nil)). Model((*db.Permission)(nil)).
Index("idx_permissions_action"). Index("idx_permissions_action").
Column("action"). Column("action").
@@ -49,21 +49,21 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.RolePermission)(nil)). Model((*db.RolePermission)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
_, err = dbConn.ExecContext(ctx, ` _, err = conn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id) CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
`) `)
if err != nil { if err != nil {
return err return err
} }
_, err = dbConn.ExecContext(ctx, ` _, err = conn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id) CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
`) `)
if err != nil { if err != nil {
@@ -71,7 +71,7 @@ func init() {
} }
// Create user_roles table // Create user_roles table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.UserRole)(nil)). Model((*db.UserRole)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -79,7 +79,7 @@ func init() {
} }
// Create indexes for user_roles // Create indexes for user_roles
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.UserRole)(nil)). Model((*db.UserRole)(nil)).
Index("idx_user_roles_user"). Index("idx_user_roles_user").
Column("user_id"). Column("user_id").
@@ -88,7 +88,7 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.UserRole)(nil)). Model((*db.UserRole)(nil)).
Index("idx_user_roles_role"). Index("idx_user_roles_role").
Column("role_id"). Column("role_id").
@@ -98,7 +98,7 @@ func init() {
} }
// Create audit_log table // Create audit_log table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -106,7 +106,7 @@ func init() {
} }
// Create indexes for audit_log // Create indexes for audit_log
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_user"). Index("idx_audit_log_user").
Column("user_id"). Column("user_id").
@@ -115,7 +115,7 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_action"). Index("idx_audit_log_action").
Column("action"). Column("action").
@@ -124,7 +124,7 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_resource"). Index("idx_audit_log_resource").
Column("resource_type", "resource_id"). Column("resource_type", "resource_id").
@@ -133,7 +133,7 @@ func init() {
return err return err
} }
_, err = dbConn.NewCreateIndex(). _, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_created"). Index("idx_audit_log_created").
Column("created_at"). Column("created_at").
@@ -142,7 +142,7 @@ func init() {
return err return err
} }
err = seedSystemRBAC(ctx, dbConn) err = seedSystemRBAC(ctx, conn)
if err != nil { if err != nil {
return err return err
} }
@@ -173,7 +173,7 @@ func init() {
) )
} }
func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error { func seedSystemRBAC(ctx context.Context, conn *bun.DB) error {
// Seed system roles // Seed system roles
now := time.Now().Unix() now := time.Now().Unix()
@@ -185,7 +185,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
CreatedAt: now, CreatedAt: now,
} }
_, err := dbConn.NewInsert(). _, err := conn.NewInsert().
Model(adminRole). Model(adminRole).
Returning("id"). Returning("id").
Exec(ctx) Exec(ctx)
@@ -201,7 +201,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
CreatedAt: now, CreatedAt: now,
} }
_, err = dbConn.NewInsert(). _, err = conn.NewInsert().
Model(userRole). Model(userRole).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -219,7 +219,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
{Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now}, {Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now},
} }
_, err = dbConn.NewInsert(). _, err = conn.NewInsert().
Model(&permissionsData). Model(&permissionsData).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -229,7 +229,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
// Grant wildcard permission to admin role using Bun // Grant wildcard permission to admin role using Bun
// First, get the IDs // First, get the IDs
var wildcardPerm db.Permission var wildcardPerm db.Permission
err = dbConn.NewSelect(). err = conn.NewSelect().
Model(&wildcardPerm). Model(&wildcardPerm).
Where("name = ?", "*"). Where("name = ?", "*").
Scan(ctx) Scan(ctx)
@@ -242,7 +242,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
RoleID: adminRole.ID, RoleID: adminRole.ID,
PermissionID: wildcardPerm.ID, PermissionID: wildcardPerm.ID,
} }
_, err = dbConn.NewInsert(). _, err = conn.NewInsert().
Model(adminRolePerms). Model(adminRolePerms).
On("CONFLICT (role_id, permission_id) DO NOTHING"). On("CONFLICT (role_id, permission_id) DO NOTHING").
Exec(ctx) Exec(ctx)

View File

@@ -11,9 +11,9 @@ import (
func init() { func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Add slap_version column to seasons table // Add slap_version column to seasons table
_, err := dbConn.NewAddColumn(). _, err := conn.NewAddColumn().
Model((*db.Season)(nil)). Model((*db.Season)(nil)).
ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'"). ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'").
IfNotExists(). IfNotExists().
@@ -23,7 +23,7 @@ func init() {
} }
// Create leagues table // Create leagues table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.League)(nil)). Model((*db.League)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -31,15 +31,15 @@ func init() {
} }
// Create season_leagues join table // Create season_leagues join table
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.SeasonLeague)(nil)). Model((*db.SeasonLeague)(nil)).
Exec(ctx) Exec(ctx)
return err return err
}, },
// DOWN migration // DOWN migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Drop season_leagues join table first // Drop season_leagues join table first
_, err := dbConn.NewDropTable(). _, err := conn.NewDropTable().
Model((*db.SeasonLeague)(nil)). Model((*db.SeasonLeague)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)
@@ -48,7 +48,7 @@ func init() {
} }
// Drop leagues table // Drop leagues table
_, err = dbConn.NewDropTable(). _, err = conn.NewDropTable().
Model((*db.League)(nil)). Model((*db.League)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)
@@ -57,7 +57,7 @@ func init() {
} }
// Remove slap_version column from seasons table // Remove slap_version column from seasons table
_, err = dbConn.NewDropColumn(). _, err = conn.NewDropColumn().
Model((*db.Season)(nil)). Model((*db.Season)(nil)).
ColumnExpr("slap_version"). ColumnExpr("slap_version").
Exec(ctx) Exec(ctx)

View File

@@ -10,15 +10,15 @@ import (
func init() { func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here // Add your migration code here
_, err := dbConn.NewCreateTable(). _, err := conn.NewCreateTable().
Model((*db.Team)(nil)). Model((*db.Team)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
_, err = dbConn.NewCreateTable(). _, err = conn.NewCreateTable().
Model((*db.TeamParticipation)(nil)). Model((*db.TeamParticipation)(nil)).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@@ -27,16 +27,16 @@ func init() {
return nil return nil
}, },
// DOWN migration // DOWN migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here // Add your rollback code here
_, err := dbConn.NewDropTable(). _, err := conn.NewDropTable().
Model((*db.TeamParticipation)(nil)). Model((*db.TeamParticipation)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
_, err = dbConn.NewDropTable(). _, err = conn.NewDropTable().
Model((*db.Team)(nil)). Model((*db.Team)(nil)).
IfExists(). IfExists().
Exec(ctx) Exec(ctx)

View File

@@ -1,8 +1,11 @@
package db package db
import ( import (
"net/http"
"strings" "strings"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@@ -19,6 +22,41 @@ type OrderOpts struct {
Label string Label string
} }
func GetPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request) (*PageOpts, bool) {
var getter validation.Getter
switch r.Method {
case "GET":
getter = validation.NewQueryGetter(r)
case "POST":
var ok bool
getter, ok = validation.ParseFormOrError(s, w, r)
if !ok {
return nil, false
}
default:
return nil, false
}
return getPageOpts(s, w, r, getter), true
}
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *PageOpts {
page := g.Int("page").Optional().Min(1).Value
perPage := g.Int("per_page").Optional().Min(1).Max(100).Value
order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value
orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value
valid := g.ValidateAndError(s, w, r)
if !valid {
return nil
}
pageOpts := &PageOpts{
Page: page,
PerPage: perPage,
Order: bun.Order(order),
OrderBy: orderBy,
}
return pageOpts
}
func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) { func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
if p == nil { if p == nil {
p = new(PageOpts) p = new(PageOpts)

View File

@@ -92,5 +92,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 { if id <= 0 {
return errors.New("id must be positive") return errors.New("id must be positive")
} }
return DeleteWithProtection[Permission](ctx, tx, id) return DeleteWithProtection[Permission](ctx, tx, id, nil)
} }

View File

@@ -25,13 +25,6 @@ type Role struct {
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"` Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
} }
type RolePermission struct {
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
PermissionID int `bun:",pk"`
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
}
func (r Role) isSystem() bool { func (r Role) isSystem() bool {
return r.IsSystem return r.IsSystem
} }
@@ -42,17 +35,12 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
} }
return GetByField[Role](tx, "name", name).Get(ctx) return GetByField[Role](tx, "name", name).Relation("Permissions").Get(ctx)
} }
// GetRoleByID queries the database for a role matching the given ID // GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found // Returns nil, nil if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Get(ctx)
}
// GetRoleWithPermissions loads a role and all its permissions
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx) return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
} }
@@ -73,7 +61,7 @@ func GetRoles(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Role],
} }
// CreateRole creates a new role // CreateRole creates a new role
func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { func CreateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
if role == nil { if role == nil {
return errors.New("role cannot be nil") return errors.New("role cannot be nil")
} }
@@ -81,6 +69,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
err := Insert(tx, role). err := Insert(tx, role).
Returning("id"). Returning("id").
WithAudit(audit, nil).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Insert") return errors.Wrap(err, "db.Insert")
@@ -90,7 +79,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
} }
// UpdateRole updates an existing role // UpdateRole updates an existing role
func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { func UpdateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
if role == nil { if role == nil {
return errors.New("role cannot be nil") return errors.New("role cannot be nil")
} }
@@ -100,6 +89,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
err := Update(tx, role). err := Update(tx, role).
WherePK(). WherePK().
WithAudit(audit, nil).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Update") return errors.Wrap(err, "db.Update")
@@ -110,7 +100,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
// DeleteRole deletes a role (checks IsSystem protection) // DeleteRole deletes a role (checks IsSystem protection)
// Also cleans up join table entries in role_permissions and user_roles // Also cleans up join table entries in role_permissions and user_roles
func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
if id <= 0 { if id <= 0 {
return errors.New("id must be positive") return errors.New("id must be positive")
} }
@@ -146,47 +136,5 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
} }
// Finally delete the role // Finally delete the role
return DeleteWithProtection[Role](ctx, tx, id) return DeleteWithProtection[Role](ctx, tx, id, audit)
}
// AddPermissionToRole grants a permission to a role
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
if roleID <= 0 {
return errors.New("roleID must be positive")
}
if permissionID <= 0 {
return errors.New("permissionID must be positive")
}
rolePerm := &RolePermission{
RoleID: roleID,
PermissionID: permissionID,
}
err := Insert(tx, rolePerm).
ConflictNothing("role_id", "permission_id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
// RemovePermissionFromRole revokes a permission from a role
func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
if roleID <= 0 {
return errors.New("roleID must be positive")
}
if permissionID <= 0 {
return errors.New("permissionID must be positive")
}
err := DeleteItem[RolePermission](tx).
Where("role_id = ?", roleID).
Where("permission_id = ?", permissionID).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteItem")
}
return nil
} }

View File

@@ -0,0 +1,99 @@
package db
import (
"context"
"slices"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type RolePermission struct {
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
PermissionID int `bun:",pk"`
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
}
func (r *Role) UpdatePermissions(ctx context.Context, tx bun.Tx, newPermissionsIDs []int, audit *AuditMeta) error {
addPerms, removePerms, err := detectChangedPermissions(ctx, tx, r, newPermissionsIDs)
if err != nil {
return errors.Wrap(err, "detectChangedPermissions")
}
addedPerms := []string{}
removedPerms := []string{}
for _, perm := range addPerms {
rolePerm := &RolePermission{
RoleID: r.ID,
PermissionID: perm.ID,
}
err := Insert(tx, rolePerm).
ConflictNothing("role_id", "permission_id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
addedPerms = append(addedPerms, perm.Name.String())
}
for _, perm := range removePerms {
err := DeleteItem[RolePermission](tx).
Where("role_id = ?", r.ID).
Where("permission_id = ?", perm.ID).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteItem")
}
removedPerms = append(removedPerms, perm.Name.String())
}
// Log the permission changes
if len(addedPerms) > 0 || len(removedPerms) > 0 {
details := map[string]any{
"role_name": string(r.Name),
}
if len(addedPerms) > 0 {
details["added_permissions"] = addedPerms
}
if len(removedPerms) > 0 {
details["removed_permissions"] = removedPerms
}
info := &AuditInfo{
"roles.update_permissions",
"role",
r.ID,
details,
}
err = LogSuccess(ctx, tx, audit, info)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
}
return nil
}
func detectChangedPermissions(ctx context.Context, tx bun.Tx, role *Role, permissionIDs []int) ([]*Permission, []*Permission, error) {
allPermissions, err := ListAllPermissions(ctx, tx)
if err != nil {
return nil, nil, errors.Wrap(err, "ListAllPermissions")
}
// Build map of current permissions
currentPermIDs := make(map[int]bool)
for _, perm := range role.Permissions {
currentPermIDs[perm.ID] = true
}
var addedPerms []*Permission
var removedPerms []*Permission
// Determine what to add and remove
for _, perm := range allPermissions {
hasNow := currentPermIDs[perm.ID]
shouldHave := slices.Contains(permissionIDs, perm.ID)
if shouldHave && !hasNow {
addedPerms = append(addedPerms, perm)
} else if !shouldHave && hasNow {
removedPerms = append(removedPerms, perm)
}
}
return addedPerms, removedPerms, nil
}

View File

@@ -25,15 +25,22 @@ type Season struct {
Teams []Team `bun:"m2m:team_participations,join:Season=Team"` Teams []Team `bun:"m2m:team_participations,join:Season=Team"`
} }
// NewSeason returns a new season. It does not add it to the database // NewSeason creats a new season
func NewSeason(name, version, shortname string, start time.Time) *Season { func NewSeason(ctx context.Context, tx bun.Tx, name, version, shortname string,
start time.Time, audit *AuditMeta,
) (*Season, error) {
season := &Season{ season := &Season{
Name: name, Name: name,
ShortName: strings.ToUpper(shortname), ShortName: strings.ToUpper(shortname),
StartDate: start.Truncate(time.Hour * 24), StartDate: start.Truncate(time.Hour * 24),
SlapVersion: version, SlapVersion: version,
} }
return season err := Insert(tx, season).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.WithMessage(err, "db.Insert")
}
return season, nil
} }
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) { func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
@@ -54,7 +61,9 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
} }
// Update updates the season struct. It does not insert to the database // Update updates the season struct. It does not insert to the database
func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time.Time) { func (s *Season) Update(ctx context.Context, tx bun.Tx, version string,
start, end, finalsStart, finalsEnd time.Time, audit *AuditMeta,
) error {
s.SlapVersion = version s.SlapVersion = version
s.StartDate = start.Truncate(time.Hour * 24) s.StartDate = start.Truncate(time.Hour * 24)
if !end.IsZero() { if !end.IsZero() {
@@ -66,6 +75,9 @@ func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time.
if !finalsEnd.IsZero() { if !finalsEnd.IsZero() {
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24) s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
} }
return Update(tx, s).WherePK().
Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date").
WithAudit(audit, nil).Exec(ctx)
} }
func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) { func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) {

118
internal/db/seasonleague.go Normal file
View File

@@ -0,0 +1,118 @@
package db
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type SeasonLeague struct {
SeasonID int `bun:",pk"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
LeagueID int `bun:",pk"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
}
// GetSeasonLeague retrieves a specific season-league combination with teams
func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) {
if seasonShortName == "" {
return nil, nil, nil, errors.New("season short_name cannot be empty")
}
if leagueShortName == "" {
return nil, nil, nil, errors.New("league short_name cannot be empty")
}
// Get the season
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason")
}
// Get the league
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague")
}
if season == nil || league == nil || !season.HasLeague(league.ID) {
return nil, nil, nil, nil
}
// Get all teams participating in this season+league
var teams []*Team
err = tx.NewSelect().
Model(&teams).
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID).
Order("t.name ASC").
Scan(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "tx.Select teams")
}
return season, league, teams, nil
}
func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string, audit *AuditMeta) error {
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return errors.Wrap(err, "GetSeason")
}
if season == nil {
return errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return errors.Wrap(err, "GetLeague")
}
if league == nil {
return errors.New("league not found")
}
if season.HasLeague(league.ID) {
return errors.New("league already added to season")
}
seasonLeague := &SeasonLeague{
SeasonID: season.ID,
LeagueID: league.ID,
}
info := &AuditInfo{
string(permissions.SeasonsAddLeague),
"season",
season.ID,
map[string]any{"league_id": league.ID},
}
err = Insert(tx, seasonLeague).WithAudit(audit, info).Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName string, audit *AuditMeta) error {
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return errors.Wrap(err, "GetLeague")
}
if league == nil {
return errors.New("league not found")
}
if !s.HasLeague(league.ID) {
return errors.New("league not in season")
}
info := &AuditInfo{
string(permissions.SeasonsRemoveLeague),
"season",
s.ID,
map[string]any{"league_id": league.ID},
}
err = DeleteItem[SeasonLeague](tx).
Where("season_id = ?", s.ID).
Where("league_id = ?", league.ID).
WithAudit(audit, info).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "db.DeleteItem")
}
return nil
}

55
internal/db/setup.go Normal file
View File

@@ -0,0 +1,55 @@
package db
import (
"database/sql"
"fmt"
"time"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
type DB struct {
*bun.DB
}
func (db *DB) Close() error {
return db.DB.Close()
}
func (db *DB) RegisterModels() []any {
models := []any{
(*RolePermission)(nil),
(*UserRole)(nil),
(*SeasonLeague)(nil),
(*TeamParticipation)(nil),
(*User)(nil),
(*DiscordToken)(nil),
(*Season)(nil),
(*League)(nil),
(*Team)(nil),
(*Role)(nil),
(*Permission)(nil),
(*AuditLog)(nil),
}
db.RegisterModel(models...)
return models
}
func NewDB(cfg *Config) *DB {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
sqldb.SetMaxOpenConns(25)
sqldb.SetMaxIdleConns(10)
sqldb.SetConnMaxLifetime(5 * time.Minute)
sqldb.SetConnMaxIdleTime(5 * time.Minute)
db := &DB{
bun.NewDB(sqldb, pgdialect.New()),
}
db.RegisterModels()
return db
}

View File

@@ -19,13 +19,19 @@ type Team struct {
Leagues []League `bun:"m2m:team_participations,join:Team=League"` Leagues []League `bun:"m2m:team_participations,join:Team=League"`
} }
type TeamParticipation struct { func NewTeam(ctx context.Context, tx bun.Tx, name, shortName, altShortName, color string, audit *AuditMeta) (*Team, error) {
SeasonID int `bun:",pk,unique:season_team"` team := &Team{
Season *Season `bun:"rel:belongs-to,join:season_id=id"` Name: name,
LeagueID int `bun:",pk"` ShortName: shortName,
League *League `bun:"rel:belongs-to,join:league_id=id"` AltShortName: altShortName,
TeamID int `bun:",pk,unique:season_team"` Color: color,
Team *Team `bun:"rel:belongs-to,join:team_id=id"` }
err := Insert(tx, team).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "db.Insert")
}
return team, nil
} }
func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) { func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) {
@@ -45,14 +51,23 @@ func GetTeam(ctx context.Context, tx bun.Tx, id int) (*Team, error) {
return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx) return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx)
} }
func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortname, altshortname string) (bool, error) { func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortName, altShortName string) (bool, error) {
// Check if this combination of short_name and alt_short_name exists // Check if this combination of short_name and alt_short_name exists
count, err := tx.NewSelect(). count, err := tx.NewSelect().
Model((*Team)(nil)). Model((*Team)(nil)).
Where("short_name = ? AND alt_short_name = ?", shortname, altshortname). Where("short_name = ? AND alt_short_name = ?", shortName, altShortName).
Count(ctx) Count(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "tx.Select") return false, errors.Wrap(err, "tx.Select")
} }
return count == 0, nil return count == 0, nil
} }
func (t *Team) InSeason(seasonID int) bool {
for _, season := range t.Seasons {
if season.ID == seasonID {
return true
}
}
return false
}

View File

@@ -0,0 +1,67 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type TeamParticipation struct {
SeasonID int `bun:",pk,unique:season_team"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
LeagueID int `bun:",pk"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
TeamID int `bun:",pk,unique:season_team"`
Team *Team `bun:"rel:belongs-to,join:team_id=id"`
}
func NewTeamParticipation(ctx context.Context, tx bun.Tx,
seasonShortName, leagueShortName string, teamID int, audit *AuditMeta,
) (*Team, *Season, *League, error) {
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason")
}
if season == nil {
return nil, nil, nil, errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague")
}
if league == nil {
return nil, nil, nil, errors.New("league not found")
}
if !season.HasLeague(league.ID) {
return nil, nil, nil, errors.New("league is not assigned to the season")
}
team, err := GetTeam(ctx, tx, teamID)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetTeam")
}
if team == nil {
return nil, nil, nil, errors.New("team not found")
}
if team.InSeason(season.ID) {
return nil, nil, nil, errors.New("team already in season")
}
participation := &TeamParticipation{
SeasonID: season.ID,
LeagueID: league.ID,
TeamID: team.ID,
}
info := &AuditInfo{
"teams.join_season",
"team",
teamID,
map[string]any{"season_id": season.ID, "league_id": league.ID},
}
err = Insert(tx, participation).
WithAudit(audit, info).Exec(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "db.Insert")
}
return team, season, league, nil
}

View File

@@ -22,16 +22,15 @@ var timeout = 15 * time.Second
// WithReadTx executes a read-only transaction with automatic rollback // WithReadTx executes a read-only transaction with automatic rollback
// Returns true if successful, false if error was thrown to client // Returns true if successful, false if error was thrown to client
func WithReadTx( func (db *DB) WithReadTx(
s *hws.Server, s *hws.Server,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
conn *bun.DB,
fn TxFunc, fn TxFunc,
) bool { ) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout) ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel() defer cancel()
ok, err := withTx(ctx, conn, fn, false) ok, err := db.withTx(ctx, fn, false)
if err != nil { if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err) throw.InternalServiceError(s, w, r, "Database error", err)
} }
@@ -41,31 +40,29 @@ func WithReadTx(
// WithTxFailSilently executes a transaction with automatic rollback // WithTxFailSilently executes a transaction with automatic rollback
// Returns true if successful, false if error occured. // Returns true if successful, false if error occured.
// Does not throw any errors to the client. // Does not throw any errors to the client.
func WithTxFailSilently( func (db *DB) WithTxFailSilently(
ctx context.Context, ctx context.Context,
conn *bun.DB,
fn TxFuncSilent, fn TxFuncSilent,
) error { ) error {
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) { fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
err := fn(ctx, tx) err := fn(ctx, tx)
return err == nil, err return err == nil, err
} }
_, err := withTx(ctx, conn, fnc, true) _, err := db.withTx(ctx, fnc, true)
return err return err
} }
// WithWriteTx executes a write transaction with automatic rollback on error // WithWriteTx executes a write transaction with automatic rollback on error
// Commits only if fn returns nil. Returns true if successful. // Commits only if fn returns nil. Returns true if successful.
func WithWriteTx( func (db *DB) WithWriteTx(
s *hws.Server, s *hws.Server,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
conn *bun.DB,
fn TxFunc, fn TxFunc,
) bool { ) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout) ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel() defer cancel()
ok, err := withTx(ctx, conn, fn, true) ok, err := db.withTx(ctx, fn, true)
if err != nil { if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err) throw.InternalServiceError(s, w, r, "Database error", err)
} }
@@ -74,16 +71,15 @@ func WithWriteTx(
// WithNotifyTx executes a transaction with notification-based error handling // WithNotifyTx executes a transaction with notification-based error handling
// Uses notifyInternalServiceError instead of throwInternalServiceError // Uses notifyInternalServiceError instead of throwInternalServiceError
func WithNotifyTx( func (db *DB) WithNotifyTx(
s *hws.Server, s *hws.Server,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
conn *bun.DB,
fn TxFunc, fn TxFunc,
) bool { ) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout) ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel() defer cancel()
ok, err := withTx(ctx, conn, fn, true) ok, err := db.withTx(ctx, fn, true)
if err != nil { if err != nil {
notify.InternalServiceError(s, w, r, "Database error", err) notify.InternalServiceError(s, w, r, "Database error", err)
} }
@@ -91,13 +87,12 @@ func WithNotifyTx(
} }
// withTx executes a transaction with automatic rollback on error // withTx executes a transaction with automatic rollback on error
func withTx( func (db *DB) withTx(
ctx context.Context, ctx context.Context,
conn *bun.DB,
fn TxFunc, fn TxFunc,
write bool, write bool,
) (bool, error) { ) (bool, error) {
tx, err := conn.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return false, errors.Wrap(err, "conn.BeginTx") return false, errors.Wrap(err, "conn.BeginTx")
} }

View File

@@ -2,19 +2,18 @@ package db
import ( import (
"context" "context"
"net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type updater[T any] struct { type updater[T any] struct {
tx bun.Tx tx bun.Tx
q *bun.UpdateQuery q *bun.UpdateQuery
model *T model *T
columns []string columns []string
auditCallback AuditCallback audit *AuditMeta
auditRequest *http.Request auditInfo *AuditInfo
} }
// Update creates an updater for a model // Update creates an updater for a model
@@ -69,11 +68,10 @@ func (u *updater[T]) Set(query string, args ...any) *updater[T] {
} }
// WithAudit enables audit logging for this update operation // WithAudit enables audit logging for this update operation
// The callback will be invoked after successful update with auto-generated audit info // If the provided *AuditInfo is nil, will use reflection to automatically work out the details
// If the callback returns an error, the transaction will be rolled back func (u *updater[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *updater[T] {
func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater[T] { u.audit = meta
u.auditRequest = r u.auditInfo = info
u.auditCallback = callback
return u return u
} }
@@ -82,7 +80,7 @@ func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater
func (u *updater[T]) Exec(ctx context.Context) error { func (u *updater[T]) Exec(ctx context.Context) error {
// Build audit details BEFORE update (captures changed fields) // Build audit details BEFORE update (captures changed fields)
var details map[string]any var details map[string]any
if u.auditCallback != nil && len(u.columns) > 0 { if u.audit != nil && len(u.columns) > 0 {
details = extractChangedFields(u.model, u.columns) details = extractChangedFields(u.model, u.columns)
} }
@@ -93,21 +91,22 @@ func (u *updater[T]) Exec(ctx context.Context) error {
} }
// Handle audit logging if enabled // Handle audit logging if enabled
if u.auditCallback != nil && u.auditRequest != nil { if u.audit != nil {
tableName := extractTableName[T]() if u.auditInfo == nil {
resourceType := extractResourceType(tableName) tableName := extractTableName[T]()
action := buildAction(resourceType, "update") resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "update")
info := &AuditInfo{ u.auditInfo = &AuditInfo{
Action: action, Action: action,
ResourceType: resourceType, ResourceType: resourceType,
ResourceID: extractPrimaryKey(u.model), ResourceID: extractPrimaryKey(u.model),
Details: details, // Changed fields only Details: details, // Changed fields only
}
} }
err = LogSuccess(ctx, u.tx, u.audit, u.auditInfo)
// Call audit callback - if it fails, return error to trigger rollback if err != nil {
if err := u.auditCallback(ctx, u.tx, info, u.auditRequest); err != nil { return errors.Wrap(err, "LogSuccess")
return errors.Wrap(err, "audit.callback")
} }
} }

View File

@@ -12,8 +12,6 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
var CurrentUser hwsauth.ContextLoader[*User]
type User struct { type User struct {
bun.BaseModel `bun:"table:users,alias:u"` bun.BaseModel `bun:"table:users,alias:u"`
@@ -29,8 +27,10 @@ func (u *User) GetID() int {
return u.ID return u.ID
} }
var CurrentUser hwsauth.ContextLoader[*User]
// CreateUser creates a new user with the given username and password // CreateUser creates a new user with the given username and password
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User, audit *AuditMeta) (*User, error) {
if discorduser == nil { if discorduser == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
@@ -39,8 +39,10 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
DiscordID: discorduser.ID, DiscordID: discorduser.ID,
} }
audit.u = user
err := Insert(tx, user). err := Insert(tx, user).
WithAudit(audit, nil).
Returning("id"). Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package db
import ( import (
"context" "context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -16,7 +17,7 @@ type UserRole struct {
} }
// AssignRole grants a role to a user // AssignRole grants a role to a user
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
if userID <= 0 { if userID <= 0 {
return errors.New("userID must be positive") return errors.New("userID must be positive")
} }
@@ -28,8 +29,20 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
UserID: userID, UserID: userID,
RoleID: roleID, RoleID: roleID,
} }
details := map[string]any{
"action": "grant",
"role_id": roleID,
}
info := &AuditInfo{
string(permissions.UsersManageRoles),
"user",
userID,
details,
}
err := Insert(tx, userRole). err := Insert(tx, userRole).
ConflictNothing("user_id", "role_id").Exec(ctx) ConflictNothing("user_id", "role_id").
WithAudit(audit, info).
Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Insert") return errors.Wrap(err, "db.Insert")
} }
@@ -38,7 +51,7 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
} }
// RevokeRole removes a role from a user // RevokeRole removes a role from a user
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
if userID <= 0 { if userID <= 0 {
return errors.New("userID must be positive") return errors.New("userID must be positive")
} }
@@ -46,9 +59,20 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
return errors.New("roleID must be positive") return errors.New("roleID must be positive")
} }
details := map[string]any{
"action": "revoke",
"role_id": roleID,
}
info := &AuditInfo{
string(permissions.UsersManageRoles),
"user",
userID,
details,
}
err := DeleteItem[UserRole](tx). err := DeleteItem[UserRole](tx).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Where("role_id = ?", roleID). Where("role_id = ?", roleID).
WithAudit(audit, info).
Delete(ctx) Delete(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "DeleteItem") return errors.Wrap(err, "DeleteItem")

View File

@@ -12,10 +12,10 @@ import (
var embeddedFiles embed.FS var embeddedFiles embed.FS
// GetEmbeddedFS gets the embedded files // GetEmbeddedFS gets the embedded files
func GetEmbeddedFS() (fs.FS, error) { func GetEmbeddedFS() (*fs.FS, error) {
subFS, err := fs.Sub(embeddedFiles, "web") subFS, err := fs.Sub(embeddedFiles, "web")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fs.Sub") return nil, errors.Wrap(err, "fs.Sub")
} }
return subFS, nil return &subFS, nil
} }

View File

@@ -17,10 +17,10 @@ import (
) )
// AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request) // AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request)
func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler { func AdminAuditLogsPage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromQuery(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
@@ -29,7 +29,7 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler {
var actions []string var actions []string
var resourceTypes []string var resourceTypes []string
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
// Get filters from query // Get filters from query
@@ -73,10 +73,10 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX) // AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX)
func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler { func AdminAuditLogsList(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromForm(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
@@ -85,7 +85,7 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler {
var actions []string var actions []string
var resourceTypes []string var resourceTypes []string
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
// Get filters from form // Get filters from form
@@ -129,16 +129,16 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates // AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates
func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler { func AdminAuditLogsFilter(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromForm(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var logs *db.List[db.AuditLog] var logs *db.List[db.AuditLog]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
// Get filters from form // Get filters from form
@@ -164,7 +164,7 @@ func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminAuditLogDetail shows details for a single audit log entry // AdminAuditLogDetail shows details for a single audit log entry
func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler { func AdminAuditLogDetail(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get ID from path // Get ID from path
idStr := r.PathValue("id") idStr := r.PathValue("id")
@@ -181,7 +181,7 @@ func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler {
var log *db.AuditLog var log *db.AuditLog
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
log, err = db.GetAuditLogByID(ctx, tx, id) log, err = db.GetAuditLogByID(ctx, tx, id)
if err != nil { if err != nil {

View File

@@ -12,10 +12,10 @@ import (
) )
// AdminDashboard renders the full admin dashboard page (defaults to users section) // AdminDashboard renders the full admin dashboard page (defaults to users section)
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { func AdminDashboard(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var users *db.List[db.User] var users *db.List[db.User]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
users, err = db.GetUsersWithRoles(ctx, tx, nil) users, err = db.GetUsersWithRoles(ctx, tx, nil)
if err != nil { if err != nil {

View File

@@ -1,25 +0,0 @@
package handlers
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"github.com/uptrace/bun"
)
// AdminPermissionsPage renders the full admin dashboard page with permissions section
func AdminPermissionsPage(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// TODO: Load permissions from database
renderSafely(adminview.PermissionsPage(), s, r, w)
})
}
// AdminPermissionsList shows all permissions (HTMX content replacement)
func AdminPermissionsList(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// TODO: Load permissions from database
renderSafely(adminview.PermissionsList(), s, r, w)
})
}

View File

@@ -6,7 +6,6 @@ import (
"strconv" "strconv"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
@@ -16,7 +15,7 @@ import (
) )
// AdminPreviewRoleStart starts preview mode for a specific role // AdminPreviewRoleStart starts preview mode for a specific role
func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http.Handler { func AdminPreviewRoleStart(s *hws.Server, conn *db.DB, ssl bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get role ID from URL // Get role ID from URL
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
@@ -28,7 +27,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http
// Verify role exists and is not admin // Verify role exists and is not admin
var role *db.Role var role *db.Role
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
@@ -49,7 +48,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http
} }
// Set preview role cookie // Set preview role cookie
rbac.SetPreviewRoleCookie(w, roleID, cfg.HWSAuth.SSL) rbac.SetPreviewRoleCookie(w, roleID, ssl)
// Redirect to home page // Redirect to home page
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)

View File

@@ -8,10 +8,8 @@ import (
"time" "time"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview" adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -19,20 +17,15 @@ import (
) )
// AdminRoles renders the full admin dashboard page with roles section // AdminRoles renders the full admin dashboard page with roles section
func AdminRoles(s *hws.Server, conn *bun.DB) http.Handler { func AdminRoles(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var pageOpts *db.PageOpts pageOpts, ok := db.GetPageOpts(s, w, r)
if r.Method == "GET" { if !ok {
pageOpts = pageOptsFromQuery(s, w, r)
} else {
pageOpts = pageOptsFromForm(s, w, r)
}
if pageOpts == nil {
return return
} }
var rolesList *db.List[db.Role] var rolesList *db.List[db.Role]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
rolesList, err = db.GetRoles(ctx, tx, pageOpts) rolesList, err = db.GetRoles(ctx, tx, pageOpts)
if err != nil { if err != nil {
@@ -59,7 +52,7 @@ func AdminRoleCreateForm(s *hws.Server) http.Handler {
} }
// AdminRoleCreate creates a new role // AdminRoleCreate creates a new role
func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { func AdminRoleCreate(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok { if !ok {
@@ -74,14 +67,14 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
return return
} }
pageOpts := pageOptsFromForm(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var rolesList *db.List[db.Role] var rolesList *db.List[db.Role]
var newRole *db.Role var newRole *db.Role
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
newRole = &db.Role{ newRole = &db.Role{
Name: roles.Role(name), Name: roles.Role(name),
DisplayName: displayName, DisplayName: displayName,
@@ -90,9 +83,9 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
} }
err := db.Insert(tx, newRole).WithAudit(r, audit.Callback()).Exec(ctx) err := db.CreateRole(ctx, tx, newRole, db.NewAudit(r, nil))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.Insert") return false, errors.Wrap(err, "db.CreateRole")
} }
rolesList, err = db.GetRoles(ctx, tx, pageOpts) rolesList, err = db.GetRoles(ctx, tx, pageOpts)
@@ -110,7 +103,7 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
} }
// AdminRoleManage shows the role management modal with details and actions // AdminRoleManage shows the role management modal with details and actions
func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler { func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
@@ -120,7 +113,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler {
} }
var role *db.Role var role *db.Role
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
@@ -139,7 +132,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminRoleDeleteConfirm shows the delete confirmation dialog // AdminRoleDeleteConfirm shows the delete confirmation dialog
func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler { func AdminRoleDeleteConfirm(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
@@ -149,7 +142,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler {
} }
var role *db.Role var role *db.Role
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
@@ -168,7 +161,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminRoleDelete deletes a role // AdminRoleDelete deletes a role
func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
@@ -177,13 +170,13 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
return return
} }
pageOpts := pageOptsFromForm(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var rolesList *db.List[db.Role] var rolesList *db.List[db.Role]
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// First check if role exists and get its details // First check if role exists and get its details
role, err := db.GetRoleByID(ctx, tx, roleID) role, err := db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
@@ -199,9 +192,9 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
} }
// Delete the role with audit logging // Delete the role with audit logging
err = db.DeleteByID[db.Role](tx, roleID).WithAudit(r, audit.Callback()).Delete(ctx) err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.DeleteByID") return false, errors.Wrap(err, "db.DeleteRole")
} }
// Reload roles // Reload roles
@@ -220,7 +213,7 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
} }
// AdminRolePermissionsModal shows the permissions management modal for a role // AdminRolePermissionsModal shows the permissions management modal for a role
func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler { func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
@@ -234,12 +227,12 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler {
var groupedPerms []adminview.PermissionsByResource var groupedPerms []adminview.PermissionsByResource
var rolePermIDs map[int]bool var rolePermIDs map[int]bool
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// Load role with permissions // Load role with permissions
var err error var err error
role, err = db.GetRoleWithPermissions(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetRoleWithPermissions") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil { if role == nil {
return false, errors.New("role not found") return false, errors.New("role not found")
@@ -283,7 +276,7 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler {
} }
// AdminRolePermissionsUpdate updates the permissions for a role // AdminRolePermissionsUpdate updates the permissions for a role
func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler { func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
@@ -291,7 +284,6 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
user := db.CurrentUser(r.Context())
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok { if !ok {
@@ -304,80 +296,24 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log
return return
} }
selectedPermIDs := make(map[int]bool) pageOpts, ok := db.GetPageOpts(s, w, r)
for _, id := range permissionIDs { if !ok {
selectedPermIDs[id] = true
}
pageOpts := pageOptsFromForm(s, w, r)
if pageOpts == nil {
return return
} }
var rolesList *db.List[db.Role] var rolesList *db.List[db.Role]
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// Get role with current permissions role, err := db.GetRoleByID(ctx, tx, roleID)
role, err := db.GetRoleWithPermissions(ctx, tx, roleID)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetRoleWithPermissions") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil { if role == nil {
throw.NotFound(s, w, r, "Role not found") w.WriteHeader(http.StatusBadRequest)
return false, nil return false, nil
} }
err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil))
// Get all permissions to know what exists
allPermissions, err := db.ListAllPermissions(ctx, tx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.ListAllPermissions") return false, errors.Wrap(err, "role.UpdatePermissions")
}
// Build map of current permissions
currentPermIDs := make(map[int]bool)
for _, perm := range role.Permissions {
currentPermIDs[perm.ID] = true
}
var addedPerms []string
var removedPerms []string
// Determine what to add and remove
for _, perm := range allPermissions {
hasNow := currentPermIDs[perm.ID]
shouldHave := selectedPermIDs[perm.ID]
if shouldHave && !hasNow {
// Add permission
err := db.AddPermissionToRole(ctx, tx, roleID, perm.ID)
if err != nil {
return false, errors.Wrap(err, "db.AddPermissionToRole")
}
addedPerms = append(addedPerms, string(perm.Name))
} else if !shouldHave && hasNow {
// Remove permission
err := db.RemovePermissionFromRole(ctx, tx, roleID, perm.ID)
if err != nil {
return false, errors.Wrap(err, "db.RemovePermissionFromRole")
}
removedPerms = append(removedPerms, string(perm.Name))
}
}
// Log the permission changes
if len(addedPerms) > 0 || len(removedPerms) > 0 {
details := map[string]any{
"role_name": string(role.Name),
}
if len(addedPerms) > 0 {
details["added_permissions"] = addedPerms
}
if len(removedPerms) > 0 {
details["removed_permissions"] = removedPerms
}
err = audit.LogSuccess(ctx, tx, user, "update", "role_permissions", roleID, details, r)
if err != nil {
return false, errors.Wrap(err, "audit.LogSuccess")
}
} }
// Reload roles // Reload roles

View File

@@ -12,20 +12,15 @@ import (
) )
// AdminUsersPage renders the full admin dashboard page with users section // AdminUsersPage renders the full admin dashboard page with users section
func AdminUsersPage(s *hws.Server, conn *bun.DB) http.Handler { func AdminUsersPage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var pageOpts *db.PageOpts pageOpts, ok := db.GetPageOpts(s, w, r)
if r.Method == "GET" { if !ok {
pageOpts = pageOptsFromQuery(s, w, r)
} else {
pageOpts = pageOptsFromForm(s, w, r)
}
if pageOpts == nil {
return return
} }
var users *db.List[db.User] var users *db.List[db.User]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
users, err = db.GetUsersWithRoles(ctx, tx, pageOpts) users, err = db.GetUsersWithRoles(ctx, tx, pageOpts)
if err != nil { if err != nil {

View File

@@ -47,7 +47,7 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error
} }
// Grant admin role // Grant admin role
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID) err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "db.AssignRole") return errors.Wrap(err, "db.AssignRole")
} }

View File

@@ -22,7 +22,7 @@ import (
func Callback( func Callback(
s *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
cfg *config.Config, cfg *config.Config,
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
@@ -70,7 +70,7 @@ func Callback(
switch data { switch data {
case "login": case "login":
var redirect func() var redirect func()
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil { if err != nil {
throw.InternalServiceError(s, w, r, "OAuth login failed", err) throw.InternalServiceError(s, w, r, "OAuth login failed", err)

View File

@@ -15,7 +15,7 @@ import (
// Returns 200 OK if unique, 409 Conflict if not unique // Returns 200 OK if unique, 409 Conflict if not unique
func IsUnique( func IsUnique(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
model any, model any,
field string, field string,
) http.Handler { ) http.Handler {
@@ -31,7 +31,7 @@ func IsUnique(
return return
} }
unique := false unique := false
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, model, field, value) unique, err = db.IsUnique(ctx, tx, model, field, value)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.IsUnique") return false, errors.Wrap(err, "db.IsUnique")

View File

@@ -14,11 +14,11 @@ import (
func LeaguesList( func LeaguesList(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var leagues []*db.League var leagues []*db.League
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
leagues, err = db.GetLeagues(ctx, tx) leagues, err = db.GetLeagues(ctx, tx)
if err != nil { if err != nil {

View File

@@ -9,7 +9,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
@@ -18,20 +17,15 @@ import (
func NewLeague( func NewLeague(
s *hws.Server, s *hws.Server,
conn *bun.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" { renderSafely(leaguesview.NewPage(), s, r, w)
renderSafely(leaguesview.NewPage(), s, r, w)
return
}
}) })
} }
func NewLeagueSubmit( func NewLeagueSubmit(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
@@ -53,7 +47,7 @@ func NewLeagueSubmit(
nameUnique := false nameUnique := false
shortNameUnique := false shortNameUnique := false
var league *db.League var league *db.League
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name) nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name)
if err != nil { if err != nil {
@@ -66,14 +60,9 @@ func NewLeagueSubmit(
if !nameUnique || !shortNameUnique { if !nameUnique || !shortNameUnique {
return true, nil return true, nil
} }
league = &db.League{ league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAudit(r, nil))
Name: name,
ShortName: shortname,
Description: description,
}
err = db.Insert(tx, league).WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.Insert") return false, errors.Wrap(err, "db.NewLeague")
} }
return true, nil return true, nil
}); !ok { }); !ok {

View File

@@ -7,9 +7,9 @@ import (
"git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
@@ -19,7 +19,7 @@ import (
func Login( func Login(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
cfg *config.Config, cfg *config.Config,
st *store.Store, st *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,

View File

@@ -16,7 +16,7 @@ import (
func Logout( func Logout(
s *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
@@ -27,7 +27,7 @@ func Logout(
w.Header().Set("HX-Redirect", "/") w.Header().Set("HX-Redirect", "/")
return return
} }
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
token, err := user.DeleteDiscordTokens(ctx, tx) token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "user.DeleteDiscordTokens") return false, errors.Wrap(err, "user.DeleteDiscordTokens")

View File

@@ -1,45 +0,0 @@
package handlers
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/uptrace/bun"
)
// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata.
// It renders a Bad Request error page on fail
// PageOpts will be nil on fail
func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
getter, ok := validation.ParseFormOrError(s, w, r)
if !ok {
return nil
}
return getPageOpts(s, w, r, getter)
}
// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail
// PageOpts will be nil on fail
func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
return getPageOpts(s, w, r, validation.NewQueryGetter(r))
}
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts {
page := g.Int("page").Optional().Min(1).Value
perPage := g.Int("per_page").Optional().Min(1).Max(100).Value
order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value
orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value
valid := g.ValidateAndError(s, w, r)
if !valid {
return nil
}
pageOpts := &db.PageOpts{
Page: page,
PerPage: perPage,
Order: bun.Order(order),
OrderBy: orderBy,
}
return pageOpts
}

View File

@@ -20,7 +20,7 @@ import (
func Register( func Register(
s *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
cfg *config.Config, cfg *config.Config,
store *store.Store, store *store.Store,
) http.Handler { ) http.Handler {
@@ -55,7 +55,7 @@ func Register(
username := r.FormValue("username") username := r.FormValue("username")
unique := false unique := false
var user *db.User var user *db.User
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username) unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.IsUsernameUnique") return false, errors.Wrap(err, "db.IsUsernameUnique")
@@ -63,7 +63,7 @@ func Register(
if !unique { if !unique {
return true, nil return true, nil
} }
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser) user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAudit(r, nil))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.CreateUser") return false, errors.Wrap(err, "db.CreateUser")
} }

View File

@@ -14,14 +14,14 @@ import (
func SeasonPage( func SeasonPage(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
var season *db.Season var season *db.Season
var leaguesWithTeams []db.LeagueWithTeams var leaguesWithTeams []db.LeagueWithTeams
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {

View File

@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
@@ -19,13 +18,13 @@ import (
func SeasonEditPage( func SeasonEditPage(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
var season *db.Season var season *db.Season
var allLeagues []*db.League var allLeagues []*db.League
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
@@ -49,8 +48,7 @@ func SeasonEditPage(
func SeasonEditSubmit( func SeasonEditSubmit(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
@@ -77,7 +75,7 @@ func SeasonEditSubmit(
} }
var season *db.Season var season *db.Season
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
@@ -86,12 +84,9 @@ func SeasonEditSubmit(
if season == nil { if season == nil {
return false, errors.New("season does not exist") return false, errors.New("season does not exist")
} }
season.Update(version, start, end, finalsStart, finalsEnd) err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil))
err = db.Update(tx, season).WherePK().
Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date").
WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.Update") return false, errors.Wrap(err, "season.Update")
} }
return true, nil return true, nil
}); !ok { }); !ok {

View File

@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
@@ -16,8 +15,7 @@ import (
func SeasonLeagueAddTeam( func SeasonLeagueAddTeam(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
@@ -36,73 +34,12 @@ func SeasonLeagueAddTeam(
var league *db.League var league *db.League
var team *db.Team var team *db.Team
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil))
// Get season
season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.NewTeamParticipation")
} }
if season == nil {
notify.Warn(s, w, r, "Not Found", "Season not found.", nil)
return false, nil
}
// Get league
league, err = db.GetLeague(ctx, tx, leagueStr)
if err != nil {
return false, errors.Wrap(err, "db.GetLeague")
}
if league == nil {
notify.Warn(s, w, r, "Not Found", "League not found.", nil)
return false, nil
}
if !season.HasLeague(league.ID) {
notify.Warn(s, w, r, "Invalid League", "This league is not associated with this season.", nil)
return false, nil
}
// Get team
team, err = db.GetTeam(ctx, tx, teamID)
if err != nil {
return false, errors.Wrap(err, "db.GetTeam")
}
if team == nil {
notify.Warn(s, w, r, "Not Found", "Team not found.", nil)
return false, nil
}
// Check if team is already in this season (in any league)
var tpCount int
tpCount, err = tx.NewSelect().
Model((*db.TeamParticipation)(nil)).
Where("season_id = ? AND team_id = ?", season.ID, team.ID).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
if tpCount > 0 {
notify.Warn(s, w, r, "Already In Season", fmt.Sprintf(
"Team '%s' is already participating in this season.",
team.Name,
), nil)
return false, nil
}
// Add team to league
participation := &db.TeamParticipation{
SeasonID: season.ID,
LeagueID: league.ID,
TeamID: team.ID,
}
err = db.Insert(tx, participation).WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil {
return false, errors.Wrap(err, "db.Insert")
}
return true, nil return true, nil
}); !ok { }); !ok {
return return

View File

@@ -14,7 +14,7 @@ import (
func SeasonLeaguePage( func SeasonLeaguePage(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
@@ -25,7 +25,7 @@ func SeasonLeaguePage(
var teams []*db.Team var teams []*db.Team
var allTeams []*db.Team var allTeams []*db.Team
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {

View File

@@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/view/seasonsview" "git.haelnorr.com/h/oslstats/internal/view/seasonsview"
@@ -16,8 +15,7 @@ import (
func SeasonAddLeague( func SeasonAddLeague(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
@@ -25,32 +23,10 @@ func SeasonAddLeague(
var season *db.Season var season *db.Season
var allLeagues []*db.League var allLeagues []*db.League
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil))
season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.NewSeasonLeague")
}
if season == nil {
return false, errors.New("season not found")
}
league, err := db.GetLeague(ctx, tx, leagueStr)
if err != nil {
return false, errors.Wrap(err, "db.GetLeague")
}
if league == nil {
return false, errors.New("league not found")
}
// Create the many-to-many relationship
seasonLeague := &db.SeasonLeague{
SeasonID: season.ID,
LeagueID: league.ID,
}
err = db.Insert(tx, seasonLeague).WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil {
return false, errors.Wrap(err, "db.Insert")
} }
// Reload season with updated leagues // Reload season with updated leagues
@@ -76,8 +52,7 @@ func SeasonAddLeague(
func SeasonRemoveLeague( func SeasonRemoveLeague(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonStr := r.PathValue("season_short_name")
@@ -85,7 +60,7 @@ func SeasonRemoveLeague(
var season *db.Season var season *db.Season
var allLeagues []*db.League var allLeagues []*db.League
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
@@ -94,22 +69,9 @@ func SeasonRemoveLeague(
if season == nil { if season == nil {
return false, errors.New("season not found") return false, errors.New("season not found")
} }
err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil))
league, err := db.GetLeague(ctx, tx, leagueStr)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetLeague") return false, errors.Wrap(err, "season.RemoveLeague")
}
if league == nil {
return false, errors.New("league not found")
}
// Delete the many-to-many relationship
err = db.DeleteItem[db.SeasonLeague](tx).
Where("season_id = ? AND league_id = ?", season.ID, league.ID).
WithAudit(r, audit.Callback()).
Delete(ctx)
if err != nil {
return false, errors.Wrap(err, "db.DeleteItem")
} }
// Reload season with updated leagues // Reload season with updated leagues

View File

@@ -11,18 +11,18 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
// SeasonsPage renders the full page with the seasons list, for use with GET requests // SeasonsPage renders the season list. On GET it returns the full page, otherwise it just returns the list
func SeasonsPage( func SeasonsPage(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromQuery(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var seasons *db.List[db.Season] var seasons *db.List[db.Season]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts) seasons, err = db.ListSeasons(ctx, tx, pageOpts)
if err != nil { if err != nil {
@@ -32,31 +32,10 @@ func SeasonsPage(
}); !ok { }); !ok {
return return
} }
renderSafely(seasonsview.ListPage(seasons), s, r, w) if r.Method == "GET" {
}) renderSafely(seasonsview.ListPage(seasons), s, r, w)
} } else {
renderSafely(seasonsview.SeasonsList(seasons), s, r, w)
// SeasonsList renders just the seasons list, for use with POST requests and HTMX }
func SeasonsList(
s *hws.Server,
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromForm(s, w, r)
if pageOpts == nil {
return
}
var seasons *db.List[db.Season]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.ListSeasons")
}
return true, nil
}); !ok {
return
}
renderSafely(seasonsview.SeasonsList(seasons), s, r, w)
}) })
} }

View File

@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
@@ -18,20 +17,15 @@ import (
func NewSeason( func NewSeason(
s *hws.Server, s *hws.Server,
conn *bun.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" { renderSafely(seasonsview.NewPage(), s, r, w)
renderSafely(seasonsview.NewPage(), s, r, w)
return
}
}) })
} }
func NewSeasonSubmit( func NewSeasonSubmit(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
@@ -58,7 +52,7 @@ func NewSeasonSubmit(
nameUnique := false nameUnique := false
shortNameUnique := false shortNameUnique := false
var season *db.Season var season *db.Season
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name) nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
if err != nil { if err != nil {
@@ -71,10 +65,9 @@ func NewSeasonSubmit(
if !nameUnique || !shortNameUnique { if !nameUnique || !shortNameUnique {
return true, nil return true, nil
} }
season = db.NewSeason(name, version, shortname, start) season, err = db.NewSeason(ctx, tx, name, version, shortname, start, db.NewAudit(r, nil))
err = db.Insert(tx, season).WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.Insert") return false, errors.Wrap(err, "db.NewSeason")
} }
return true, nil return true, nil
}); !ok { }); !ok {

View File

@@ -15,7 +15,7 @@ import (
// and also validates that they are different from each other // and also validates that they are different from each other
func IsTeamShortNamesUnique( func IsTeamShortNamesUnique(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, err := validation.ParseForm(r) getter, err := validation.ParseForm(r)
@@ -38,7 +38,7 @@ func IsTeamShortNamesUnique(
} }
var isUnique bool var isUnique bool
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
isUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName) isUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.TeamShortNamesUnique") return false, errors.Wrap(err, "db.TeamShortNamesUnique")

View File

@@ -14,15 +14,15 @@ import (
// TeamsPage renders the full page with the teams list, for use with GET requests // TeamsPage renders the full page with the teams list, for use with GET requests
func TeamsPage( func TeamsPage(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromQuery(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var teams *db.List[db.Team] var teams *db.List[db.Team]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
teams, err = db.ListTeams(ctx, tx, pageOpts) teams, err = db.ListTeams(ctx, tx, pageOpts)
if err != nil { if err != nil {
@@ -39,15 +39,15 @@ func TeamsPage(
// TeamsList renders just the teams list, for use with POST requests and HTMX // TeamsList renders just the teams list, for use with POST requests and HTMX
func TeamsList( func TeamsList(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts := pageOptsFromForm(s, w, r) pageOpts, ok := db.GetPageOpts(s, w, r)
if pageOpts == nil { if !ok {
return return
} }
var teams *db.List[db.Team] var teams *db.List[db.Team]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
teams, err = db.ListTeams(ctx, tx, pageOpts) teams, err = db.ListTeams(ctx, tx, pageOpts)
if err != nil { if err != nil {

View File

@@ -9,7 +9,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
@@ -18,7 +17,6 @@ import (
func NewTeamPage( func NewTeamPage(
s *hws.Server, s *hws.Server,
conn *bun.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
renderSafely(teamsview.NewPage(), s, r, w) renderSafely(teamsview.NewPage(), s, r, w)
@@ -27,8 +25,7 @@ func NewTeamPage(
func NewTeamSubmit( func NewTeamSubmit(
s *hws.Server, s *hws.Server,
conn *bun.DB, conn *db.DB,
audit *auditlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
@@ -38,10 +35,10 @@ func NewTeamSubmit(
name := getter.String("name"). name := getter.String("name").
TrimSpace().Required(). TrimSpace().Required().
MaxLength(25).MinLength(3).Value MaxLength(25).MinLength(3).Value
shortname := getter.String("short_name"). shortName := getter.String("short_name").
TrimSpace().Required(). TrimSpace().Required().
MaxLength(3).MinLength(3).Value MaxLength(3).MinLength(3).Value
altShortname := getter.String("alt_short_name"). altShortName := getter.String("alt_short_name").
TrimSpace().Required(). TrimSpace().Required().
MaxLength(3).MinLength(3).Value MaxLength(3).MinLength(3).Value
color := getter.String("color"). color := getter.String("color").
@@ -51,22 +48,21 @@ func NewTeamSubmit(
} }
// Check that short names are different // Check that short names are different
if shortname == altShortname { if shortName == altShortName {
notify.Warn(s, w, r, "Invalid Short Names", "Short name and alternative short name must be different.", nil) notify.Warn(s, w, r, "Invalid Short Names", "Short name and alternative short name must be different.", nil)
return return
} }
nameUnique := false nameUnique := false
shortNameComboUnique := false shortNameComboUnique := false
var team *db.Team if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
nameUnique, err = db.IsUnique(ctx, tx, (*db.Team)(nil), "name", name) nameUnique, err = db.IsUnique(ctx, tx, (*db.Team)(nil), "name", name)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.IsTeamNameUnique") return false, errors.Wrap(err, "db.IsTeamNameUnique")
} }
shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortname, altShortname) shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.TeamShortNamesUnique") return false, errors.Wrap(err, "db.TeamShortNamesUnique")
} }
@@ -74,15 +70,9 @@ func NewTeamSubmit(
if !nameUnique || !shortNameComboUnique { if !nameUnique || !shortNameComboUnique {
return true, nil return true, nil
} }
team = &db.Team{ _, err = db.NewTeam(ctx, tx, name, shortName, altShortName, color, db.NewAudit(r, nil))
Name: name,
ShortName: shortname,
AltShortName: altShortname,
Color: color,
}
err = db.Insert(tx, team).WithAudit(r, audit.Callback()).Exec(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.Insert") return false, errors.Wrap(err, "db.NewTeam")
} }
return true, nil return true, nil
}); !ok { }); !ok {

View File

@@ -39,12 +39,12 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
var roles_ []*db.Role var roles_ []*db.Role
var perms []*db.Permission var perms []*db.Permission
if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error { if err := c.conn.WithTxFailSilently(r.Context(), func(ctx context.Context, tx bun.Tx) error {
var err error var err error
if previewRole != nil { if previewRole != nil {
// In preview mode: use the preview role instead of user's roles // In preview mode: use the preview role instead of user's roles
role, err := db.GetRoleWithPermissions(ctx, tx, previewRole.ID) role, err := db.GetRoleByID(ctx, tx, previewRole.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "db.GetRoleWithPermissions") return errors.Wrap(err, "db.GetRoleWithPermissions")
} }

View File

@@ -13,11 +13,11 @@ import (
) )
type Checker struct { type Checker struct {
conn *bun.DB conn *db.DB
s *hws.Server s *hws.Server
} }
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) { func NewChecker(conn *db.DB, s *hws.Server) (*Checker, error) {
if conn == nil { if conn == nil {
return nil, errors.New("conn cannot be nil") return nil, errors.New("conn cannot be nil")
} }
@@ -56,7 +56,7 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi
// Not in preview mode: fallback to database for actual user permissions // Not in preview mode: fallback to database for actual user permissions
var has bool var has bool
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error {
var err error var err error
has, err = user.HasPermission(ctx, tx, permission) has, err = user.HasPermission(ctx, tx, permission)
if err != nil { if err != nil {
@@ -94,7 +94,7 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol
// Not in preview mode: fallback to database for actual user roles // Not in preview mode: fallback to database for actual user roles
var has bool var has bool
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error {
var err error var err error
has, err = user.HasRole(ctx, tx, role) has, err = user.HasRole(ctx, tx, role)
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ import (
// LoadPreviewRoleMiddleware loads the preview role from the session cookie if present // LoadPreviewRoleMiddleware loads the preview role from the session cookie if present
// and adds it to the request context. This must run after authentication but before // and adds it to the request context. This must run after authentication but before
// the RBAC cache middleware. // the RBAC cache middleware.
func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) http.Handler { func LoadPreviewRoleMiddleware(s *hws.Server, conn *db.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if there's a preview role in the cookie // Check if there's a preview role in the cookie
@@ -26,10 +26,25 @@ func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) h
return return
} }
user := db.CurrentUser(r.Context())
if user == nil {
// User not logged in,
ClearPreviewRoleCookie(w)
next.ServeHTTP(w, r)
return
}
// Load the preview role from the database // Load the preview role from the database
var previewRole *db.Role var previewRole *db.Role
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error isAdmin, err := user.IsAdmin(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "user.IsAdmin")
}
if !isAdmin {
ClearPreviewRoleCookie(w)
return true, nil
}
previewRole, err = db.GetRoleByID(ctx, tx, roleID) previewRole, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")

View File

@@ -90,7 +90,7 @@ func (c *Checker) RequireActualAdmin(s *hws.Server) func(http.Handler) http.Hand
// Check user's ACTUAL role in database, bypassing preview mode // Check user's ACTUAL role in database, bypassing preview mode
var hasAdmin bool var hasAdmin bool
if ok := db.WithReadTx(s, w, r, c.conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := c.conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
hasAdmin, err = user.HasRole(ctx, tx, roles.Admin) hasAdmin, err = user.HasRole(ctx, tx, roles.Admin)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package main package server
import ( import (
"context" "context"
@@ -15,7 +15,7 @@ import (
func setupAuth( func setupAuth(
cfg *hwsauth.Config, cfg *hwsauth.Config,
logger *hlog.Logger, logger *hlog.Logger,
conn *bun.DB, conn *db.DB,
server *hws.Server, server *hws.Server,
ignoredPaths []string, ignoredPaths []string,
) (*hwsauth.Authenticator[*db.User, bun.Tx], error) { ) (*hwsauth.Authenticator[*db.User, bun.Tx], error) {
@@ -30,7 +30,7 @@ func setupAuth(
beginTx, beginTx,
logger, logger,
handlers.ErrorPage, handlers.ErrorPage,
conn.DB, conn.DB.DB,
) )
if err != nil { if err != nil {
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")

View File

@@ -1,4 +1,4 @@
package main package server
import ( import (
"context" "context"
@@ -27,7 +27,7 @@ func addMiddleware(
perms *rbac.Checker, perms *rbac.Checker,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
store *store.Store, store *store.Store,
conn *bun.DB, conn *db.DB,
) error { ) error {
err := server.AddMiddleware( err := server.AddMiddleware(
auth.Authenticate(tokenRefresh(auth, discordAPI, store)), auth.Authenticate(tokenRefresh(auth, discordAPI, store)),

View File

@@ -1,4 +1,4 @@
package main package server
import ( import (
"net/http" "net/http"
@@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
@@ -22,12 +21,11 @@ func addRoutes(
s *hws.Server, s *hws.Server,
staticFS *http.FileSystem, staticFS *http.FileSystem,
cfg *config.Config, cfg *config.Config,
conn *bun.DB, conn *db.DB,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
perms *rbac.Checker, perms *rbac.Checker,
audit *auditlog.Logger,
) error { ) error {
// Create the routes // Create the routes
baseRoutes := []hws.Route{ baseRoutes := []hws.Route{
@@ -69,23 +67,18 @@ func addRoutes(
seasonRoutes := []hws.Route{ seasonRoutes := []hws.Route{
{ {
Path: "/seasons", Path: "/seasons",
Method: hws.MethodGET, Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: handlers.SeasonsPage(s, conn), Handler: handlers.SeasonsPage(s, conn),
}, },
{
Path: "/seasons",
Method: hws.MethodPOST,
Handler: handlers.SeasonsList(s, conn),
},
{ {
Path: "/seasons/new", Path: "/seasons/new",
Method: hws.MethodGET, Method: hws.MethodGET,
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s, conn)), Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s)),
}, },
{ {
Path: "/seasons/new", Path: "/seasons/new",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn)),
}, },
{ {
Path: "/seasons/{season_short_name}", Path: "/seasons/{season_short_name}",
@@ -100,7 +93,7 @@ func addRoutes(
{ {
Path: "/seasons/{season_short_name}/edit", Path: "/seasons/{season_short_name}/edit",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn)),
}, },
{ {
Path: "/seasons/{season_short_name}/leagues/{league_short_name}", Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
@@ -110,17 +103,17 @@ func addRoutes(
{ {
Path: "/seasons/{season_short_name}/leagues/add/{league_short_name}", Path: "/seasons/{season_short_name}/leagues/add/{league_short_name}",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn)),
}, },
{ {
Path: "/seasons/{season_short_name}/leagues/{league_short_name}", Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
Method: hws.MethodDELETE, Method: hws.MethodDELETE,
Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn)),
}, },
{ {
Path: "/seasons/{season_short_name}/leagues/{league_short_name}/teams/add", Path: "/seasons/{season_short_name}/leagues/{league_short_name}/teams/add",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn)),
}, },
} }
@@ -133,12 +126,12 @@ func addRoutes(
{ {
Path: "/leagues/new", Path: "/leagues/new",
Method: hws.MethodGET, Method: hws.MethodGET,
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s, conn)), Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s)),
}, },
{ {
Path: "/leagues/new", Path: "/leagues/new",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn)),
}, },
} }
@@ -156,12 +149,12 @@ func addRoutes(
{ {
Path: "/teams/new", Path: "/teams/new",
Method: hws.MethodGET, Method: hws.MethodGET,
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s, conn)), Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s)),
}, },
{ {
Path: "/teams/new", Path: "/teams/new",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn, audit)), Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn)),
}, },
} }
@@ -234,21 +227,11 @@ func addRoutes(
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: perms.RequireAdmin(s)(handlers.AdminRoles(s, conn)), Handler: perms.RequireAdmin(s)(handlers.AdminRoles(s, conn)),
}, },
{
Path: "/admin/permissions",
Method: hws.MethodGET,
Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsPage(s, conn)),
},
{ {
Path: "/admin/audit", Path: "/admin/audit",
Method: hws.MethodGET, Method: hws.MethodGET,
Handler: perms.RequireAdmin(s)(handlers.AdminAuditLogsPage(s, conn)), Handler: perms.RequireAdmin(s)(handlers.AdminAuditLogsPage(s, conn)),
}, },
{
Path: "/admin/permissions",
Method: hws.MethodPOST,
Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsList(s, conn)),
},
{ {
Path: "/admin/audit", Path: "/admin/audit",
Method: hws.MethodPOST, Method: hws.MethodPOST,
@@ -263,7 +246,7 @@ func addRoutes(
{ {
Path: "/admin/roles/create", Path: "/admin/roles/create",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn, audit)), Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn)),
}, },
{ {
Path: "/admin/roles/{id}/manage", Path: "/admin/roles/{id}/manage",
@@ -273,7 +256,7 @@ func addRoutes(
{ {
Path: "/admin/roles/{id}", Path: "/admin/roles/{id}",
Method: hws.MethodDELETE, Method: hws.MethodDELETE,
Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn, audit)), Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn)),
}, },
{ {
Path: "/admin/roles/{id}/delete-confirm", Path: "/admin/roles/{id}/delete-confirm",
@@ -288,12 +271,12 @@ func addRoutes(
{ {
Path: "/admin/roles/{id}/permissions", Path: "/admin/roles/{id}/permissions",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn, audit)), Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn)),
}, },
{ {
Path: "/admin/roles/{id}/preview-start", Path: "/admin/roles/{id}/preview-start",
Method: hws.MethodPOST, Method: hws.MethodPOST,
Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg)), Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg.HWSAuth.SSL)),
}, },
{ {
Path: "/admin/roles/preview-stop", Path: "/admin/roles/preview-stop",

View File

@@ -1,4 +1,5 @@
package main // Package server provides setup utilities for the HTTP server
package server
import ( import (
"io/fs" "io/fs"
@@ -7,21 +8,20 @@ import (
"git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers" "git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
) )
func setupHTTPServer( func Setup(
staticFS *fs.FS, staticFS *fs.FS,
cfg *config.Config, cfg *config.Config,
logger *hlog.Logger, logger *hlog.Logger,
bun *bun.DB, conn *db.DB,
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
) (server *hws.Server, err error) { ) (server *hws.Server, err error) {
@@ -41,7 +41,7 @@ func setupHTTPServer(
} }
auth, err := setupAuth( auth, err := setupAuth(
cfg.HWSAuth, logger, bun, httpServer, ignoredPaths) cfg.HWSAuth, logger, conn, httpServer, ignoredPaths)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "setupAuth") return nil, errors.Wrap(err, "setupAuth")
} }
@@ -62,20 +62,17 @@ func setupHTTPServer(
} }
// Initialize permissions checker // Initialize permissions checker
perms, err := rbac.NewChecker(bun, httpServer) perms, err := rbac.NewChecker(conn, httpServer)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "rbac.NewChecker") return nil, errors.Wrap(err, "rbac.NewChecker")
} }
// Initialize audit logger err = addRoutes(httpServer, &fs, cfg, conn, auth, store, discordAPI, perms)
audit := auditlog.NewLogger(bun)
err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI, perms, audit)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "addRoutes") return nil, errors.Wrap(err, "addRoutes")
} }
err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, bun) err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "addMiddleware") return nil, errors.Wrap(err, "addMiddleware")
} }

View File

@@ -136,38 +136,38 @@ templ profileDropdown(user *db.User, items []ProfileItem) {
x-on:click.away="isActive = false" x-on:click.away="isActive = false"
x-on:keydown.escape.window="isActive = false" x-on:keydown.escape.window="isActive = false"
> >
<!-- Preview Mode Stop Buttons --> <!-- Preview Mode Stop Buttons -->
if previewRole != nil { if previewRole != nil {
<div class="p-2 bg-yellow/10 border-b border-yellow/30 space-y-2"> <div class="p-2 bg-yellow/10 border-b border-yellow/30 space-y-2">
<p class="text-xs text-yellow/80 px-2 font-semibold"> <p class="text-xs text-yellow/80 px-2 font-semibold">
Viewing as: { previewRole.DisplayName } Viewing as: { previewRole.DisplayName }
</p> </p>
<div class="flex gap-2"> <div class="flex gap-2">
<form method="POST" action="/admin/roles/preview-stop?stay=true" class="flex-1"> <form method="POST" action="/admin/roles/preview-stop?stay=true" class="flex-1">
<button <button
type="submit" type="submit"
class="w-full rounded-lg px-3 py-2 class="w-full rounded-lg px-3 py-2
text-sm text-mantle bg-green font-semibold hover:bg-teal hover:cursor-pointer transition" text-sm text-mantle bg-green font-semibold hover:bg-teal hover:cursor-pointer transition"
role="menuitem" role="menuitem"
@click="isActive=false" @click="isActive=false"
> >
Stop Preview Stop Preview
</button> </button>
</form> </form>
<form method="POST" action="/admin/roles/preview-stop" class="flex-1"> <form method="POST" action="/admin/roles/preview-stop" class="flex-1">
<button <button
type="submit" type="submit"
class="w-full rounded-lg px-3 py-2 class="w-full rounded-lg px-3 py-2
text-sm text-mantle bg-blue font-semibold hover:bg-sky hover:cursor-pointer transition" text-sm text-mantle bg-blue font-semibold hover:bg-sky hover:cursor-pointer transition"
role="menuitem" role="menuitem"
@click="isActive=false" @click="isActive=false"
> >
Return to Admin Return to Admin
</button> </button>
</form> </form>
</div>
</div> </div>
</div> }
}
<!-- Profile links --> <!-- Profile links -->
<div class="p-2"> <div class="p-2">
for _, item := range items { for _, item := range items {