big ole refactor
This commit is contained in:
@@ -1,48 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
"github.com/uptrace/bun/dialect/pgdialect"
|
|
||||||
"github.com/uptrace/bun/driver/pgdriver"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) {
|
|
||||||
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
|
|
||||||
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
|
|
||||||
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
|
||||||
|
|
||||||
sqldb.SetMaxOpenConns(25)
|
|
||||||
sqldb.SetMaxIdleConns(10)
|
|
||||||
sqldb.SetConnMaxLifetime(5 * time.Minute)
|
|
||||||
sqldb.SetConnMaxIdleTime(5 * time.Minute)
|
|
||||||
|
|
||||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
|
||||||
registerDBModels(conn)
|
|
||||||
close = sqldb.Close
|
|
||||||
return conn, close
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerDBModels(conn *bun.DB) []any {
|
|
||||||
models := []any{
|
|
||||||
(*db.RolePermission)(nil),
|
|
||||||
(*db.UserRole)(nil),
|
|
||||||
(*db.SeasonLeague)(nil),
|
|
||||||
(*db.TeamParticipation)(nil),
|
|
||||||
(*db.User)(nil),
|
|
||||||
(*db.DiscordToken)(nil),
|
|
||||||
(*db.Season)(nil),
|
|
||||||
(*db.League)(nil),
|
|
||||||
(*db.Team)(nil),
|
|
||||||
(*db.Role)(nil),
|
|
||||||
(*db.Permission)(nil),
|
|
||||||
(*db.AuditLog)(nil),
|
|
||||||
}
|
|
||||||
conn.RegisterModel(models...)
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db/migrate"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,7 +49,7 @@ func main() {
|
|||||||
|
|
||||||
// Handle migration file creation (doesn't need DB connection)
|
// Handle migration file creation (doesn't need DB connection)
|
||||||
if flags.MigrateCreate != "" {
|
if flags.MigrateCreate != "" {
|
||||||
if err := createMigration(flags.MigrateCreate); err != nil {
|
if err := migrate.CreateMigration(flags.MigrateCreate); err != nil {
|
||||||
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration")
|
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -59,17 +60,21 @@ func main() {
|
|||||||
flags.MigrateStatus || flags.MigrateDryRun ||
|
flags.MigrateStatus || flags.MigrateDryRun ||
|
||||||
flags.ResetDB {
|
flags.ResetDB {
|
||||||
|
|
||||||
|
var command, countStr string
|
||||||
// Route to appropriate command
|
// Route to appropriate command
|
||||||
if flags.MigrateUp != "" {
|
if flags.MigrateUp != "" {
|
||||||
err = runMigrations(ctx, cfg, "up", flags.MigrateUp)
|
command = "up"
|
||||||
|
countStr = flags.MigrateUp
|
||||||
} else if flags.MigrateRollback != "" {
|
} else if flags.MigrateRollback != "" {
|
||||||
err = runMigrations(ctx, cfg, "rollback", flags.MigrateRollback)
|
command = "rollback"
|
||||||
|
countStr = flags.MigrateRollback
|
||||||
} else if flags.MigrateStatus {
|
} else if flags.MigrateStatus {
|
||||||
err = runMigrations(ctx, cfg, "status", "")
|
command = "status"
|
||||||
} else if flags.MigrateDryRun {
|
}
|
||||||
err = runMigrations(ctx, cfg, "dry-run", "")
|
if flags.ResetDB {
|
||||||
} else if flags.ResetDB {
|
err = migrate.ResetDatabase(ctx, cfg)
|
||||||
err = resetDatabase(ctx, cfg)
|
} else {
|
||||||
|
err = migrate.RunMigrations(ctx, cfg, command, countStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/embedfs"
|
"git.haelnorr.com/h/oslstats/internal/embedfs"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/server"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,8 +27,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
|
|||||||
// Setup the database connection
|
// Setup the database connection
|
||||||
logger.Debug().Msg("Config loaded and logger started")
|
logger.Debug().Msg("Config loaded and logger started")
|
||||||
logger.Debug().Msg("Connecting to database")
|
logger.Debug().Msg("Connecting to database")
|
||||||
bun, closedb := setupBun(cfg)
|
conn := db.NewDB(cfg.DB)
|
||||||
// registerDBModels(bun)
|
|
||||||
|
|
||||||
// Setup embedded files
|
// Setup embedded files
|
||||||
logger.Debug().Msg("Getting embedded files")
|
logger.Debug().Msg("Getting embedded files")
|
||||||
@@ -47,7 +48,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug().Msg("Setting up HTTP server")
|
logger.Debug().Msg("Setting up HTTP server")
|
||||||
httpServer, err := setupHTTPServer(&staticFS, cfg, logger, bun, store, discordAPI)
|
httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "setupHttpServer")
|
return errors.Wrap(err, "setupHttpServer")
|
||||||
}
|
}
|
||||||
@@ -71,7 +72,7 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown")
|
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown")
|
||||||
}
|
}
|
||||||
err = closedb()
|
err = conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close")
|
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,187 +0,0 @@
|
|||||||
// Package auditlog provides a system for logging events that require permissions to the audit log
|
|
||||||
package auditlog
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
conn *bun.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger(conn *bun.DB) *Logger {
|
|
||||||
return &Logger{conn: conn}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogSuccess logs a successful permission-protected action
|
|
||||||
func (l *Logger) LogSuccess(
|
|
||||||
ctx context.Context,
|
|
||||||
tx bun.Tx,
|
|
||||||
user *db.User,
|
|
||||||
action string,
|
|
||||||
resourceType string,
|
|
||||||
resourceID any, // Can be int, string, or nil
|
|
||||||
details map[string]any,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
return l.log(ctx, tx, user, action, resourceType, resourceID, details, "success", nil, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogError logs a failed action due to an error
|
|
||||||
func (l *Logger) LogError(
|
|
||||||
ctx context.Context,
|
|
||||||
tx bun.Tx,
|
|
||||||
user *db.User,
|
|
||||||
action string,
|
|
||||||
resourceType string,
|
|
||||||
resourceID any,
|
|
||||||
err error,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
errMsg := err.Error()
|
|
||||||
return l.log(ctx, tx, user, action, resourceType, resourceID, nil, "error", &errMsg, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) log(
|
|
||||||
ctx context.Context,
|
|
||||||
tx bun.Tx,
|
|
||||||
user *db.User,
|
|
||||||
action string,
|
|
||||||
resourceType string,
|
|
||||||
resourceID any,
|
|
||||||
details map[string]any,
|
|
||||||
result string,
|
|
||||||
errorMessage *string,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
if user == nil {
|
|
||||||
return errors.New("user cannot be nil for audit logging")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert resourceID to string
|
|
||||||
var resourceIDStr *string
|
|
||||||
if resourceID != nil {
|
|
||||||
idStr := fmt.Sprintf("%v", resourceID)
|
|
||||||
resourceIDStr = &idStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal details to JSON
|
|
||||||
var detailsJSON json.RawMessage
|
|
||||||
if details != nil {
|
|
||||||
jsonBytes, err := json.Marshal(details)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "json.Marshal details")
|
|
||||||
}
|
|
||||||
detailsJSON = jsonBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract IP and User-Agent from request
|
|
||||||
ipAddress := r.RemoteAddr
|
|
||||||
userAgent := r.UserAgent()
|
|
||||||
|
|
||||||
log := &db.AuditLog{
|
|
||||||
UserID: user.ID,
|
|
||||||
Action: action,
|
|
||||||
ResourceType: resourceType,
|
|
||||||
ResourceID: resourceIDStr,
|
|
||||||
Details: detailsJSON,
|
|
||||||
IPAddress: ipAddress,
|
|
||||||
UserAgent: userAgent,
|
|
||||||
Result: result,
|
|
||||||
ErrorMessage: errorMessage,
|
|
||||||
CreatedAt: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.CreateAuditLog(ctx, tx, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRecentLogs retrieves recent audit logs with pagination
|
|
||||||
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
|
|
||||||
var logs *db.List[db.AuditLog]
|
|
||||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
|
||||||
var err error
|
|
||||||
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.GetAuditLogs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
|
||||||
}
|
|
||||||
return logs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogsByUser retrieves audit logs for a specific user
|
|
||||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
|
|
||||||
var logs *db.List[db.AuditLog]
|
|
||||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
|
||||||
var err error
|
|
||||||
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.GetAuditLogsByUser")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
|
||||||
}
|
|
||||||
return logs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanupOldLogs deletes audit logs older than the specified number of days
|
|
||||||
func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error) {
|
|
||||||
if daysToKeep <= 0 {
|
|
||||||
return 0, errors.New("daysToKeep must be positive")
|
|
||||||
}
|
|
||||||
|
|
||||||
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
|
|
||||||
|
|
||||||
var count int
|
|
||||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
|
||||||
var err error
|
|
||||||
count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.CleanupOldAuditLogs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return 0, errors.Wrap(err, "db.WithTxFailSilently")
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Callback returns a db.AuditCallback that logs to this Logger
|
|
||||||
// This is used with the generic database helpers (Insert, Update, Delete)
|
|
||||||
//
|
|
||||||
// Usage:
|
|
||||||
//
|
|
||||||
// audit := auditlog.NewLogger(conn)
|
|
||||||
// err := db.Insert(tx, season).
|
|
||||||
// WithAudit(r, audit.Callback()).
|
|
||||||
// Exec(ctx)
|
|
||||||
func (l *Logger) Callback() db.AuditCallback {
|
|
||||||
return func(ctx context.Context, tx bun.Tx, info *db.AuditInfo, r *http.Request) error {
|
|
||||||
user := db.CurrentUser(ctx)
|
|
||||||
if user == nil {
|
|
||||||
return errors.New("no user in context for audit logging")
|
|
||||||
}
|
|
||||||
|
|
||||||
return l.LogSuccess(
|
|
||||||
ctx,
|
|
||||||
tx,
|
|
||||||
user,
|
|
||||||
info.Action,
|
|
||||||
info.ResourceType,
|
|
||||||
info.ResourceID,
|
|
||||||
info.Details,
|
|
||||||
r,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,14 +1,23 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AuditMeta struct {
|
||||||
|
r *http.Request
|
||||||
|
u *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAudit(r *http.Request, u *User) *AuditMeta {
|
||||||
|
if u == nil {
|
||||||
|
u = CurrentUser(r.Context())
|
||||||
|
}
|
||||||
|
return &AuditMeta{r, u}
|
||||||
|
}
|
||||||
|
|
||||||
// AuditInfo contains metadata for audit logging
|
// AuditInfo contains metadata for audit logging
|
||||||
type AuditInfo struct {
|
type AuditInfo struct {
|
||||||
Action string // e.g., "seasons.create", "users.update"
|
Action string // e.g., "seasons.create", "users.update"
|
||||||
@@ -17,9 +26,6 @@ type AuditInfo struct {
|
|||||||
Details map[string]any // Changed fields or additional metadata
|
Details map[string]any // Changed fields or additional metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuditCallback is called after successful database operations to log changes
|
|
||||||
type AuditCallback func(ctx context.Context, tx bun.Tx, info *AuditInfo, r *http.Request) error
|
|
||||||
|
|
||||||
// extractTableName gets the bun table name from a model type using reflection
|
// extractTableName gets the bun table name from a model type using reflection
|
||||||
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
|
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
|
||||||
func extractTableName[T any]() string {
|
func extractTableName[T any]() string {
|
||||||
@@ -27,7 +33,7 @@ func extractTableName[T any]() string {
|
|||||||
t := reflect.TypeOf(model)
|
t := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Handle pointer types
|
// Handle pointer types
|
||||||
if t.Kind() == reflect.Ptr {
|
if t.Kind() == reflect.Pointer {
|
||||||
t = t.Elem()
|
t = t.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,11 +44,9 @@ func extractTableName[T any]() string {
|
|||||||
bunTag := field.Tag.Get("bun")
|
bunTag := field.Tag.Get("bun")
|
||||||
if bunTag != "" {
|
if bunTag != "" {
|
||||||
// Parse tag: "table:seasons,alias:s" -> "seasons"
|
// Parse tag: "table:seasons,alias:s" -> "seasons"
|
||||||
parts := strings.Split(bunTag, ",")
|
for part := range strings.SplitSeq(bunTag, ",") {
|
||||||
for _, part := range parts {
|
part, _ := strings.CutPrefix(part, "table:")
|
||||||
if strings.HasPrefix(part, "table:") {
|
return part
|
||||||
return strings.TrimPrefix(part, "table:")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,7 +85,7 @@ func extractPrimaryKey[T any](model *T) any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
v := reflect.ValueOf(model)
|
v := reflect.ValueOf(model)
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Pointer {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +114,7 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any {
|
|||||||
|
|
||||||
result := make(map[string]any)
|
result := make(map[string]any)
|
||||||
v := reflect.ValueOf(model)
|
v := reflect.ValueOf(model)
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Pointer {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,5 +146,3 @@ func extractChangedFields[T any](model *T, columns []string) map[string]any {
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: We don't need getTxFromQuery since we store the tx directly in our helper structs
|
|
||||||
|
|||||||
91
internal/db/auditlogger.go
Normal file
91
internal/db/auditlogger.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogSuccess logs a successful permission-protected action
|
||||||
|
func LogSuccess(
|
||||||
|
ctx context.Context,
|
||||||
|
tx bun.Tx,
|
||||||
|
meta *AuditMeta,
|
||||||
|
info *AuditInfo,
|
||||||
|
) error {
|
||||||
|
return log(ctx, tx, meta, info, "success", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogError logs a failed action due to an error
|
||||||
|
func LogError(
|
||||||
|
ctx context.Context,
|
||||||
|
tx bun.Tx,
|
||||||
|
meta *AuditMeta,
|
||||||
|
info *AuditInfo,
|
||||||
|
err error,
|
||||||
|
) error {
|
||||||
|
errMsg := err.Error()
|
||||||
|
return log(ctx, tx, meta, info, "error", &errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func log(
|
||||||
|
ctx context.Context,
|
||||||
|
tx bun.Tx,
|
||||||
|
meta *AuditMeta,
|
||||||
|
info *AuditInfo,
|
||||||
|
result string,
|
||||||
|
errorMessage *string,
|
||||||
|
) error {
|
||||||
|
if meta == nil {
|
||||||
|
return errors.New("audit meta cannot be nil for audit logging")
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return errors.New("audit info cannot be nil for audit logging")
|
||||||
|
}
|
||||||
|
if meta.u == nil {
|
||||||
|
return errors.New("user cannot be nil for audit logging")
|
||||||
|
}
|
||||||
|
if meta.r == nil {
|
||||||
|
return errors.New("request cannot be nil for audit logging")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert resourceID to string
|
||||||
|
var resourceIDStr *string
|
||||||
|
if info.ResourceID != nil {
|
||||||
|
idStr := fmt.Sprintf("%v", info.ResourceID)
|
||||||
|
resourceIDStr = &idStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal details to JSON
|
||||||
|
var detailsJSON json.RawMessage
|
||||||
|
if info.Details != nil {
|
||||||
|
jsonBytes, err := json.Marshal(info.Details)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "json.Marshal details")
|
||||||
|
}
|
||||||
|
detailsJSON = jsonBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract IP and User-Agent from request
|
||||||
|
ipAddress := meta.r.RemoteAddr
|
||||||
|
userAgent := meta.r.UserAgent()
|
||||||
|
|
||||||
|
log := &AuditLog{
|
||||||
|
UserID: meta.u.ID,
|
||||||
|
Action: info.Action,
|
||||||
|
ResourceType: info.ResourceType,
|
||||||
|
ResourceID: resourceIDStr,
|
||||||
|
Details: detailsJSON,
|
||||||
|
IPAddress: ipAddress,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
Result: result,
|
||||||
|
ErrorMessage: errorMessage,
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return CreateAuditLog(ctx, tx, log)
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package backup
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,14 +9,13 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateBackup creates a compressed PostgreSQL dump before migrations
|
// CreateBackup creates a compressed PostgreSQL dump before migrations
|
||||||
// Returns backup filename and error
|
// Returns backup filename and error
|
||||||
// If pg_dump is not available, returns nil error with warning
|
// If pg_dump is not available, returns nil error with warning
|
||||||
func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (string, error) {
|
func CreateBackup(ctx context.Context, cfg *Config, operation string) (string, error) {
|
||||||
// Check if pg_dump is available
|
// Check if pg_dump is available
|
||||||
if _, err := exec.LookPath("pg_dump"); err != nil {
|
if _, err := exec.LookPath("pg_dump"); err != nil {
|
||||||
fmt.Println("[WARN] pg_dump not found - skipping backup")
|
fmt.Println("[WARN] pg_dump not found - skipping backup")
|
||||||
@@ -28,13 +27,13 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure backup directory exists
|
// Ensure backup directory exists
|
||||||
if err := os.MkdirAll(cfg.DB.BackupDir, 0755); err != nil {
|
if err := os.MkdirAll(cfg.BackupDir, 0o755); err != nil {
|
||||||
return "", errors.Wrap(err, "failed to create backup directory")
|
return "", errors.Wrap(err, "failed to create backup directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz
|
// Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz
|
||||||
timestamp := time.Now().Format("20060102_150405")
|
timestamp := time.Now().Format("20060102_150405")
|
||||||
filename := filepath.Join(cfg.DB.BackupDir,
|
filename := filepath.Join(cfg.BackupDir,
|
||||||
fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation))
|
fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation))
|
||||||
|
|
||||||
// Check if gzip is available
|
// Check if gzip is available
|
||||||
@@ -42,7 +41,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
|
|||||||
if _, err := exec.LookPath("gzip"); err != nil {
|
if _, err := exec.LookPath("gzip"); err != nil {
|
||||||
fmt.Println("[WARN] gzip not found - using uncompressed backup")
|
fmt.Println("[WARN] gzip not found - using uncompressed backup")
|
||||||
useGzip = false
|
useGzip = false
|
||||||
filename = filepath.Join(cfg.DB.BackupDir,
|
filename = filepath.Join(cfg.BackupDir,
|
||||||
fmt.Sprintf("%s_pre_%s.sql", timestamp, operation))
|
fmt.Sprintf("%s_pre_%s.sql", timestamp, operation))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,19 +51,19 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
|
|||||||
// Use shell to pipe pg_dump through gzip
|
// Use shell to pipe pg_dump through gzip
|
||||||
pgDumpCmd := fmt.Sprintf(
|
pgDumpCmd := fmt.Sprintf(
|
||||||
"pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s",
|
"pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s",
|
||||||
cfg.DB.Host,
|
cfg.Host,
|
||||||
cfg.DB.Port,
|
cfg.Port,
|
||||||
cfg.DB.User,
|
cfg.User,
|
||||||
cfg.DB.DB,
|
cfg.DB,
|
||||||
filename,
|
filename,
|
||||||
)
|
)
|
||||||
cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd)
|
cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd)
|
||||||
} else {
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, "pg_dump",
|
cmd = exec.CommandContext(ctx, "pg_dump",
|
||||||
"-h", cfg.DB.Host,
|
"-h", cfg.Host,
|
||||||
"-p", fmt.Sprint(cfg.DB.Port),
|
"-p", fmt.Sprint(cfg.Port),
|
||||||
"-U", cfg.DB.User,
|
"-U", cfg.User,
|
||||||
"-d", cfg.DB.DB,
|
"-d", cfg.DB,
|
||||||
"-f", filename,
|
"-f", filename,
|
||||||
"--no-owner",
|
"--no-owner",
|
||||||
"--no-acl",
|
"--no-acl",
|
||||||
@@ -75,7 +74,7 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
|
|||||||
|
|
||||||
// Set password via environment variable
|
// Set password via environment variable
|
||||||
cmd.Env = append(os.Environ(),
|
cmd.Env = append(os.Environ(),
|
||||||
fmt.Sprintf("PGPASSWORD=%s", cfg.DB.Password))
|
fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
|
||||||
|
|
||||||
// Run backup
|
// Run backup
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
@@ -95,14 +94,14 @@ func CreateBackup(ctx context.Context, cfg *config.Config, operation string) (st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CleanOldBackups keeps only the N most recent backups
|
// CleanOldBackups keeps only the N most recent backups
|
||||||
func CleanOldBackups(cfg *config.Config, keepCount int) error {
|
func CleanOldBackups(cfg *Config, keepCount int) error {
|
||||||
// Get all backup files (both .sql and .sql.gz)
|
// Get all backup files (both .sql and .sql.gz)
|
||||||
sqlFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql"))
|
sqlFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to list .sql backups")
|
return errors.Wrap(err, "failed to list .sql backups")
|
||||||
}
|
}
|
||||||
|
|
||||||
gzFiles, err := filepath.Glob(filepath.Join(cfg.DB.BackupDir, "*.sql.gz"))
|
gzFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql.gz"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to list .sql.gz backups")
|
return errors.Wrap(err, "failed to list .sql.gz backups")
|
||||||
}
|
}
|
||||||
@@ -3,18 +3,17 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type deleter[T any] struct {
|
type deleter[T any] struct {
|
||||||
tx bun.Tx
|
tx bun.Tx
|
||||||
q *bun.DeleteQuery
|
q *bun.DeleteQuery
|
||||||
resourceID any // Store ID before deletion for audit
|
resourceID any // Store ID before deletion for audit
|
||||||
auditCallback AuditCallback
|
audit *AuditMeta
|
||||||
auditRequest *http.Request
|
auditInfo *AuditInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
type systemType interface {
|
type systemType interface {
|
||||||
@@ -39,11 +38,10 @@ func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithAudit enables audit logging for this delete operation
|
// WithAudit enables audit logging for this delete operation
|
||||||
// The callback will be invoked after successful deletion with auto-generated audit info
|
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
|
||||||
// If the callback returns an error, the transaction will be rolled back
|
func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
|
||||||
func (d *deleter[T]) WithAudit(r *http.Request, callback AuditCallback) *deleter[T] {
|
d.audit = meta
|
||||||
d.auditRequest = r
|
d.auditInfo = info
|
||||||
d.auditCallback = callback
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,21 +55,23 @@ func (d *deleter[T]) Delete(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle audit logging if enabled
|
// Handle audit logging if enabled
|
||||||
if d.auditCallback != nil && d.auditRequest != nil {
|
if d.audit != nil {
|
||||||
tableName := extractTableName[T]()
|
if d.auditInfo == nil {
|
||||||
resourceType := extractResourceType(tableName)
|
tableName := extractTableName[T]()
|
||||||
action := buildAction(resourceType, "delete")
|
resourceType := extractResourceType(tableName)
|
||||||
|
action := buildAction(resourceType, "delete")
|
||||||
|
|
||||||
info := &AuditInfo{
|
d.auditInfo = &AuditInfo{
|
||||||
Action: action,
|
Action: action,
|
||||||
ResourceType: resourceType,
|
ResourceType: resourceType,
|
||||||
ResourceID: d.resourceID,
|
ResourceID: d.resourceID,
|
||||||
Details: nil, // Delete doesn't need details
|
Details: nil, // Delete doesn't need details
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call audit callback - if it fails, return error to trigger rollback
|
err = LogSuccess(ctx, d.tx, d.audit, d.auditInfo)
|
||||||
if err := d.auditCallback(ctx, d.tx, info, d.auditRequest); err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "audit.callback")
|
return errors.Wrap(err, "LogSuccess")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
|
|||||||
return DeleteItem[T](tx).Where("id = ?", id)
|
return DeleteItem[T](tx).Where("id = ?", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error {
|
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
|
||||||
deleter := DeleteByID[T](tx, id)
|
deleter := DeleteByID[T](tx, id)
|
||||||
item, err := GetByID[T](tx, id).Get(ctx)
|
item, err := GetByID[T](tx, id).Get(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,5 +94,8 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int)
|
|||||||
if (*item).isSystem() {
|
if (*item).isSystem() {
|
||||||
return errors.New("record is system protected")
|
return errors.New("record is system protected")
|
||||||
}
|
}
|
||||||
|
if audit != nil {
|
||||||
|
deleter = deleter.WithAudit(audit, nil)
|
||||||
|
}
|
||||||
return deleter.Delete(ctx)
|
return deleter.Delete(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -11,13 +10,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type inserter[T any] struct {
|
type inserter[T any] struct {
|
||||||
tx bun.Tx
|
tx bun.Tx
|
||||||
q *bun.InsertQuery
|
q *bun.InsertQuery
|
||||||
model *T
|
model *T
|
||||||
models []*T
|
models []*T
|
||||||
isBulk bool
|
isBulk bool
|
||||||
auditCallback AuditCallback
|
audit *AuditMeta
|
||||||
auditRequest *http.Request
|
auditInfo *AuditInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert creates an inserter for a single model
|
// Insert creates an inserter for a single model
|
||||||
@@ -76,11 +75,10 @@ func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithAudit enables audit logging for this insert operation
|
// WithAudit enables audit logging for this insert operation
|
||||||
// The callback will be invoked after successful insert with auto-generated audit info
|
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
|
||||||
// If the callback returns an error, the transaction will be rolled back
|
func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] {
|
||||||
func (i *inserter[T]) WithAudit(r *http.Request, callback AuditCallback) *inserter[T] {
|
i.audit = meta
|
||||||
i.auditRequest = r
|
i.auditInfo = info
|
||||||
i.auditCallback = callback
|
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,35 +92,29 @@ func (i *inserter[T]) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle audit logging if enabled
|
// Handle audit logging if enabled
|
||||||
if i.auditCallback != nil && i.auditRequest != nil {
|
if i.audit != nil {
|
||||||
tableName := extractTableName[T]()
|
if i.auditInfo == nil {
|
||||||
resourceType := extractResourceType(tableName)
|
tableName := extractTableName[T]()
|
||||||
action := buildAction(resourceType, "create")
|
resourceType := extractResourceType(tableName)
|
||||||
|
action := buildAction(resourceType, "create")
|
||||||
var info *AuditInfo
|
i.auditInfo = &AuditInfo{
|
||||||
if i.isBulk {
|
|
||||||
// For bulk inserts, log once with count in details
|
|
||||||
info = &AuditInfo{
|
|
||||||
Action: action,
|
Action: action,
|
||||||
ResourceType: resourceType,
|
ResourceType: resourceType,
|
||||||
ResourceID: nil,
|
ResourceID: nil,
|
||||||
Details: map[string]any{
|
|
||||||
"count": len(i.models),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// For single insert, log with resource ID
|
|
||||||
info = &AuditInfo{
|
|
||||||
Action: action,
|
|
||||||
ResourceType: resourceType,
|
|
||||||
ResourceID: extractPrimaryKey(i.model),
|
|
||||||
Details: nil,
|
Details: nil,
|
||||||
}
|
}
|
||||||
|
if i.isBulk {
|
||||||
|
i.auditInfo.Details = map[string]any{
|
||||||
|
"count": len(i.models),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
i.auditInfo.ResourceID = extractPrimaryKey(i.model)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call audit callback - if it fails, return error to trigger rollback
|
err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo)
|
||||||
if err := i.auditCallback(ctx, i.tx, info, i.auditRequest); err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "audit.callback")
|
return errors.Wrap(err, "LogSuccess")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,13 +19,6 @@ type League struct {
|
|||||||
Teams []Team `bun:"m2m:team_participations,join:League=Team"`
|
Teams []Team `bun:"m2m:team_participations,join:League=Team"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SeasonLeague struct {
|
|
||||||
SeasonID int `bun:",pk"`
|
|
||||||
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
|
|
||||||
LeagueID int `bun:",pk"`
|
|
||||||
League *League `bun:"rel:belongs-to,join:league_id=id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) {
|
func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) {
|
||||||
return GetList[League](tx).Relation("Seasons").GetAll(ctx)
|
return GetList[League](tx).Relation("Seasons").GetAll(ctx)
|
||||||
}
|
}
|
||||||
@@ -37,41 +30,16 @@ func GetLeague(ctx context.Context, tx bun.Tx, shortname string) (*League, error
|
|||||||
return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx)
|
return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSeasonLeague retrieves a specific season-league combination with teams
|
func NewLeague(ctx context.Context, tx bun.Tx, name, shortname, description string, audit *AuditMeta) (*League, error) {
|
||||||
func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) {
|
league := &League{
|
||||||
if seasonShortName == "" {
|
Name: name,
|
||||||
return nil, nil, nil, errors.New("season short_name cannot be empty")
|
ShortName: shortname,
|
||||||
|
Description: description,
|
||||||
}
|
}
|
||||||
if leagueShortName == "" {
|
err := Insert(tx, league).
|
||||||
return nil, nil, nil, errors.New("league short_name cannot be empty")
|
WithAudit(audit, nil).Exec(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
// Get the season
|
|
||||||
season, err := GetSeason(ctx, tx, seasonShortName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, errors.Wrap(err, "GetSeason")
|
return nil, errors.Wrap(err, "db.Insert")
|
||||||
}
|
}
|
||||||
|
return league, nil
|
||||||
// Get the league
|
|
||||||
league, err := GetLeague(ctx, tx, leagueShortName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, errors.Wrap(err, "GetLeague")
|
|
||||||
}
|
|
||||||
if season == nil || league == nil || !season.HasLeague(league.ID) {
|
|
||||||
return nil, nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all teams participating in this season+league
|
|
||||||
var teams []*Team
|
|
||||||
err = tx.NewSelect().
|
|
||||||
Model(&teams).
|
|
||||||
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
|
|
||||||
Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID).
|
|
||||||
Order("t.name ASC").
|
|
||||||
Scan(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, errors.Wrap(err, "tx.Select teams")
|
|
||||||
}
|
|
||||||
|
|
||||||
return season, league, teams, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
package main
|
// Package migrate provides functions for managing database migrations
|
||||||
|
package migrate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
@@ -11,20 +12,19 @@ import (
|
|||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/backup"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db/migrations"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
"github.com/uptrace/bun/migrate"
|
"github.com/uptrace/bun/migrate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// runMigrations executes database migrations
|
// RunMigrations executes database migrations
|
||||||
func runMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error {
|
func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error {
|
||||||
conn, close := setupBun(cfg)
|
conn := db.NewDB(cfg.DB)
|
||||||
defer func() { _ = close() }()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
migrator := migrate.NewMigrator(conn, migrations.Migrations)
|
migrator := migrate.NewMigrator(conn.DB, migrations.Migrations)
|
||||||
|
|
||||||
// Initialize migration tables
|
// Initialize migration tables
|
||||||
if err := migrator.Init(ctx); err != nil {
|
if err := migrator.Init(ctx); err != nil {
|
||||||
@@ -47,15 +47,13 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string, coun
|
|||||||
return migrateRollback(ctx, migrator, conn, cfg, countStr)
|
return migrateRollback(ctx, migrator, conn, cfg, countStr)
|
||||||
case "status":
|
case "status":
|
||||||
return migrateStatus(ctx, migrator)
|
return migrateStatus(ctx, migrator)
|
||||||
case "dry-run":
|
|
||||||
return migrateDryRun(ctx, migrator)
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown migration command: %s", command)
|
return fmt.Errorf("unknown migration command: %s", command)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateUp runs pending migrations
|
// migrateUp runs pending migrations
|
||||||
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error {
|
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
|
||||||
// Parse count parameter
|
// Parse count parameter
|
||||||
count, all, err := parseMigrationCount(countStr)
|
count, all, err := parseMigrationCount(countStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -101,13 +99,13 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf
|
|||||||
// Create backup unless --no-backup flag is set
|
// Create backup unless --no-backup flag is set
|
||||||
if !cfg.Flags.MigrateNoBackup {
|
if !cfg.Flags.MigrateNoBackup {
|
||||||
fmt.Println("[INFO] Step 3/5: Creating backup...")
|
fmt.Println("[INFO] Step 3/5: Creating backup...")
|
||||||
_, err := backup.CreateBackup(ctx, cfg, "migration")
|
_, err := db.CreateBackup(ctx, cfg.DB, "migration")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "create backup")
|
return errors.Wrap(err, "create backup")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean old backups
|
// Clean old backups
|
||||||
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil {
|
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
|
||||||
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -143,7 +141,7 @@ func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// migrateRollback rolls back migrations
|
// migrateRollback rolls back migrations
|
||||||
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config, countStr string) error {
|
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
|
||||||
// Parse count parameter
|
// Parse count parameter
|
||||||
count, all, err := parseMigrationCount(countStr)
|
count, all, err := parseMigrationCount(countStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -182,13 +180,13 @@ func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.
|
|||||||
// Create backup unless --no-backup flag is set
|
// Create backup unless --no-backup flag is set
|
||||||
if !cfg.Flags.MigrateNoBackup {
|
if !cfg.Flags.MigrateNoBackup {
|
||||||
fmt.Println("[INFO] Creating backup before rollback...")
|
fmt.Println("[INFO] Creating backup before rollback...")
|
||||||
_, err := backup.CreateBackup(ctx, cfg, "rollback")
|
_, err := db.CreateBackup(ctx, cfg.DB, "rollback")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "create backup")
|
return errors.Wrap(err, "create backup")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean old backups
|
// Clean old backups
|
||||||
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil {
|
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
|
||||||
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -255,27 +253,6 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateDryRun shows what migrations would run without applying them
|
|
||||||
func migrateDryRun(ctx context.Context, migrator *migrate.Migrator) error {
|
|
||||||
group, err := migrator.Migrate(ctx, migrate.WithNopMigration())
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "dry-run")
|
|
||||||
}
|
|
||||||
|
|
||||||
if group.IsZero() {
|
|
||||||
fmt.Println("[INFO] No pending migrations")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("[INFO] Pending migrations (dry-run):")
|
|
||||||
for _, migration := range group.Migrations {
|
|
||||||
fmt.Printf(" 📋 %s\n", migration.Name)
|
|
||||||
}
|
|
||||||
fmt.Printf("[INFO] Would migrate to group %d\n", group.ID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateMigrations ensures migrations compile before running
|
// validateMigrations ensures migrations compile before running
|
||||||
func validateMigrations(ctx context.Context) error {
|
func validateMigrations(ctx context.Context) error {
|
||||||
cmd := exec.CommandContext(ctx, "go", "build",
|
cmd := exec.CommandContext(ctx, "go", "build",
|
||||||
@@ -292,7 +269,7 @@ func validateMigrations(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
|
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
|
||||||
func acquireMigrationLock(ctx context.Context, conn *bun.DB) error {
|
func acquireMigrationLock(ctx context.Context, conn *db.DB) error {
|
||||||
const lockID = 1234567890 // Arbitrary unique ID for migration lock
|
const lockID = 1234567890 // Arbitrary unique ID for migration lock
|
||||||
const timeoutSeconds = 300 // 5 minutes
|
const timeoutSeconds = 300 // 5 minutes
|
||||||
|
|
||||||
@@ -318,7 +295,7 @@ func acquireMigrationLock(ctx context.Context, conn *bun.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// releaseMigrationLock releases the migration lock
|
// releaseMigrationLock releases the migration lock
|
||||||
func releaseMigrationLock(ctx context.Context, conn *bun.DB) {
|
func releaseMigrationLock(ctx context.Context, conn *db.DB) {
|
||||||
const lockID = 1234567890
|
const lockID = 1234567890
|
||||||
|
|
||||||
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
|
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
|
||||||
@@ -329,8 +306,8 @@ func releaseMigrationLock(ctx context.Context, conn *bun.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// createMigration generates a new migration file
|
// CreateMigration generates a new migration file
|
||||||
func createMigration(name string) error {
|
func CreateMigration(name string) error {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return errors.New("migration name cannot be empty")
|
return errors.New("migration name cannot be empty")
|
||||||
}
|
}
|
||||||
@@ -340,7 +317,7 @@ func createMigration(name string) error {
|
|||||||
|
|
||||||
// Generate timestamp
|
// Generate timestamp
|
||||||
timestamp := time.Now().Format("20060102150405")
|
timestamp := time.Now().Format("20060102150405")
|
||||||
filename := fmt.Sprintf("cmd/oslstats/migrations/%s_%s.go", timestamp, name)
|
filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name)
|
||||||
|
|
||||||
// Template
|
// Template
|
||||||
template := `package migrations
|
template := `package migrations
|
||||||
@@ -502,8 +479,8 @@ func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migr
|
|||||||
return rolledBack, nil
|
return rolledBack, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resetDatabase drops and recreates all tables (destructive)
|
// ResetDatabase drops and recreates all tables (destructive)
|
||||||
func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
func ResetDatabase(ctx context.Context, cfg *config.Config) error {
|
||||||
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
|
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
|
||||||
fmt.Print("Type 'yes' to continue: ")
|
fmt.Print("Type 'yes' to continue: ")
|
||||||
|
|
||||||
@@ -518,10 +495,10 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
|||||||
fmt.Println("❌ Reset cancelled")
|
fmt.Println("❌ Reset cancelled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
conn, close := setupBun(cfg)
|
conn := db.NewDB(cfg.DB)
|
||||||
defer func() { _ = close() }()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
models := registerDBModels(conn)
|
models := conn.RegisterModels()
|
||||||
|
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
if err := conn.ResetModel(ctx, model); err != nil {
|
if err := conn.ResetModel(ctx, model); err != nil {
|
||||||
@@ -10,9 +10,9 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP: Create initial tables (users, discord_tokens)
|
// UP: Create initial tables (users, discord_tokens)
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Create users table
|
// Create users table
|
||||||
_, err := dbConn.NewCreateTable().
|
_, err := conn.NewCreateTable().
|
||||||
Model((*db.User)(nil)).
|
Model((*db.User)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -20,15 +20,15 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create discord_tokens table
|
// Create discord_tokens table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.DiscordToken)(nil)).
|
Model((*db.DiscordToken)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
// DOWN: Drop tables in reverse order
|
// DOWN: Drop tables in reverse order
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Drop discord_tokens first (has foreign key to users)
|
// Drop discord_tokens first (has foreign key to users)
|
||||||
_, err := dbConn.NewDropTable().
|
_, err := conn.NewDropTable().
|
||||||
Model((*db.DiscordToken)(nil)).
|
Model((*db.DiscordToken)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -37,7 +37,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drop users table
|
// Drop users table
|
||||||
_, err = dbConn.NewDropTable().
|
_, err = conn.NewDropTable().
|
||||||
Model((*db.User)(nil)).
|
Model((*db.User)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
_, err := dbConn.NewCreateTable().
|
_, err := conn.NewCreateTable().
|
||||||
Model((*db.Season)(nil)).
|
Model((*db.Season)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -20,8 +20,8 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
// DOWN migration
|
// DOWN migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
_, err := dbConn.NewDropTable().
|
_, err := conn.NewDropTable().
|
||||||
Model((*db.Season)(nil)).
|
Model((*db.Season)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
|
conn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
|
||||||
// Create permissions table
|
// Create permissions table
|
||||||
_, err := dbConn.NewCreateTable().
|
_, err := conn.NewCreateTable().
|
||||||
Model((*db.Role)(nil)).
|
Model((*db.Role)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -23,7 +23,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create permissions table
|
// Create permissions table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.Permission)(nil)).
|
Model((*db.Permission)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,7 +31,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for permissions
|
// Create indexes for permissions
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.Permission)(nil)).
|
Model((*db.Permission)(nil)).
|
||||||
Index("idx_permissions_resource").
|
Index("idx_permissions_resource").
|
||||||
Column("resource").
|
Column("resource").
|
||||||
@@ -40,7 +40,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.Permission)(nil)).
|
Model((*db.Permission)(nil)).
|
||||||
Index("idx_permissions_action").
|
Index("idx_permissions_action").
|
||||||
Column("action").
|
Column("action").
|
||||||
@@ -49,21 +49,21 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.RolePermission)(nil)).
|
Model((*db.RolePermission)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
_, err = conn.ExecContext(ctx, `
|
||||||
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
|
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
_, err = conn.ExecContext(ctx, `
|
||||||
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
|
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,7 +71,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create user_roles table
|
// Create user_roles table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.UserRole)(nil)).
|
Model((*db.UserRole)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,7 +79,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for user_roles
|
// Create indexes for user_roles
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.UserRole)(nil)).
|
Model((*db.UserRole)(nil)).
|
||||||
Index("idx_user_roles_user").
|
Index("idx_user_roles_user").
|
||||||
Column("user_id").
|
Column("user_id").
|
||||||
@@ -88,7 +88,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.UserRole)(nil)).
|
Model((*db.UserRole)(nil)).
|
||||||
Index("idx_user_roles_role").
|
Index("idx_user_roles_role").
|
||||||
Column("role_id").
|
Column("role_id").
|
||||||
@@ -98,7 +98,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create audit_log table
|
// Create audit_log table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -106,7 +106,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for audit_log
|
// Create indexes for audit_log
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_user").
|
Index("idx_audit_log_user").
|
||||||
Column("user_id").
|
Column("user_id").
|
||||||
@@ -115,7 +115,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_action").
|
Index("idx_audit_log_action").
|
||||||
Column("action").
|
Column("action").
|
||||||
@@ -124,7 +124,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_resource").
|
Index("idx_audit_log_resource").
|
||||||
Column("resource_type", "resource_id").
|
Column("resource_type", "resource_id").
|
||||||
@@ -133,7 +133,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = conn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_created").
|
Index("idx_audit_log_created").
|
||||||
Column("created_at").
|
Column("created_at").
|
||||||
@@ -142,7 +142,7 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = seedSystemRBAC(ctx, dbConn)
|
err = seedSystemRBAC(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -173,7 +173,7 @@ func init() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
func seedSystemRBAC(ctx context.Context, conn *bun.DB) error {
|
||||||
// Seed system roles
|
// Seed system roles
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := dbConn.NewInsert().
|
_, err := conn.NewInsert().
|
||||||
Model(adminRole).
|
Model(adminRole).
|
||||||
Returning("id").
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -201,7 +201,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewInsert().
|
_, err = conn.NewInsert().
|
||||||
Model(userRole).
|
Model(userRole).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -219,7 +219,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
|||||||
{Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now},
|
{Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewInsert().
|
_, err = conn.NewInsert().
|
||||||
Model(&permissionsData).
|
Model(&permissionsData).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -229,7 +229,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
|||||||
// Grant wildcard permission to admin role using Bun
|
// Grant wildcard permission to admin role using Bun
|
||||||
// First, get the IDs
|
// First, get the IDs
|
||||||
var wildcardPerm db.Permission
|
var wildcardPerm db.Permission
|
||||||
err = dbConn.NewSelect().
|
err = conn.NewSelect().
|
||||||
Model(&wildcardPerm).
|
Model(&wildcardPerm).
|
||||||
Where("name = ?", "*").
|
Where("name = ?", "*").
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
@@ -242,7 +242,7 @@ func seedSystemRBAC(ctx context.Context, dbConn *bun.DB) error {
|
|||||||
RoleID: adminRole.ID,
|
RoleID: adminRole.ID,
|
||||||
PermissionID: wildcardPerm.ID,
|
PermissionID: wildcardPerm.ID,
|
||||||
}
|
}
|
||||||
_, err = dbConn.NewInsert().
|
_, err = conn.NewInsert().
|
||||||
Model(adminRolePerms).
|
Model(adminRolePerms).
|
||||||
On("CONFLICT (role_id, permission_id) DO NOTHING").
|
On("CONFLICT (role_id, permission_id) DO NOTHING").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Add slap_version column to seasons table
|
// Add slap_version column to seasons table
|
||||||
_, err := dbConn.NewAddColumn().
|
_, err := conn.NewAddColumn().
|
||||||
Model((*db.Season)(nil)).
|
Model((*db.Season)(nil)).
|
||||||
ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'").
|
ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'").
|
||||||
IfNotExists().
|
IfNotExists().
|
||||||
@@ -23,7 +23,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create leagues table
|
// Create leagues table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.League)(nil)).
|
Model((*db.League)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,15 +31,15 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create season_leagues join table
|
// Create season_leagues join table
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.SeasonLeague)(nil)).
|
Model((*db.SeasonLeague)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
// DOWN migration
|
// DOWN migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Drop season_leagues join table first
|
// Drop season_leagues join table first
|
||||||
_, err := dbConn.NewDropTable().
|
_, err := conn.NewDropTable().
|
||||||
Model((*db.SeasonLeague)(nil)).
|
Model((*db.SeasonLeague)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -48,7 +48,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drop leagues table
|
// Drop leagues table
|
||||||
_, err = dbConn.NewDropTable().
|
_, err = conn.NewDropTable().
|
||||||
Model((*db.League)(nil)).
|
Model((*db.League)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -57,7 +57,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove slap_version column from seasons table
|
// Remove slap_version column from seasons table
|
||||||
_, err = dbConn.NewDropColumn().
|
_, err = conn.NewDropColumn().
|
||||||
Model((*db.Season)(nil)).
|
Model((*db.Season)(nil)).
|
||||||
ColumnExpr("slap_version").
|
ColumnExpr("slap_version").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -10,15 +10,15 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Add your migration code here
|
// Add your migration code here
|
||||||
_, err := dbConn.NewCreateTable().
|
_, err := conn.NewCreateTable().
|
||||||
Model((*db.Team)(nil)).
|
Model((*db.Team)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = dbConn.NewCreateTable().
|
_, err = conn.NewCreateTable().
|
||||||
Model((*db.TeamParticipation)(nil)).
|
Model((*db.TeamParticipation)(nil)).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -27,16 +27,16 @@ func init() {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
// DOWN migration
|
// DOWN migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, conn *bun.DB) error {
|
||||||
// Add your rollback code here
|
// Add your rollback code here
|
||||||
_, err := dbConn.NewDropTable().
|
_, err := conn.NewDropTable().
|
||||||
Model((*db.TeamParticipation)(nil)).
|
Model((*db.TeamParticipation)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = dbConn.NewDropTable().
|
_, err = conn.NewDropTable().
|
||||||
Model((*db.Team)(nil)).
|
Model((*db.Team)(nil)).
|
||||||
IfExists().
|
IfExists().
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +22,41 @@ type OrderOpts struct {
|
|||||||
Label string
|
Label string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request) (*PageOpts, bool) {
|
||||||
|
var getter validation.Getter
|
||||||
|
switch r.Method {
|
||||||
|
case "GET":
|
||||||
|
getter = validation.NewQueryGetter(r)
|
||||||
|
case "POST":
|
||||||
|
var ok bool
|
||||||
|
getter, ok = validation.ParseFormOrError(s, w, r)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return getPageOpts(s, w, r, getter), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *PageOpts {
|
||||||
|
page := g.Int("page").Optional().Min(1).Value
|
||||||
|
perPage := g.Int("per_page").Optional().Min(1).Max(100).Value
|
||||||
|
order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value
|
||||||
|
orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value
|
||||||
|
valid := g.ValidateAndError(s, w, r)
|
||||||
|
if !valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pageOpts := &PageOpts{
|
||||||
|
Page: page,
|
||||||
|
PerPage: perPage,
|
||||||
|
Order: bun.Order(order),
|
||||||
|
OrderBy: orderBy,
|
||||||
|
}
|
||||||
|
return pageOpts
|
||||||
|
}
|
||||||
|
|
||||||
func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
|
func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = new(PageOpts)
|
p = new(PageOpts)
|
||||||
|
|||||||
@@ -92,5 +92,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
|
|||||||
if id <= 0 {
|
if id <= 0 {
|
||||||
return errors.New("id must be positive")
|
return errors.New("id must be positive")
|
||||||
}
|
}
|
||||||
return DeleteWithProtection[Permission](ctx, tx, id)
|
return DeleteWithProtection[Permission](ctx, tx, id, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,13 +25,6 @@ type Role struct {
|
|||||||
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
|
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RolePermission struct {
|
|
||||||
RoleID int `bun:",pk"`
|
|
||||||
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
|
||||||
PermissionID int `bun:",pk"`
|
|
||||||
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Role) isSystem() bool {
|
func (r Role) isSystem() bool {
|
||||||
return r.IsSystem
|
return r.IsSystem
|
||||||
}
|
}
|
||||||
@@ -42,17 +35,12 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
|
|||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, errors.New("name cannot be empty")
|
return nil, errors.New("name cannot be empty")
|
||||||
}
|
}
|
||||||
return GetByField[Role](tx, "name", name).Get(ctx)
|
return GetByField[Role](tx, "name", name).Relation("Permissions").Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoleByID queries the database for a role matching the given ID
|
// GetRoleByID queries the database for a role matching the given ID
|
||||||
// Returns nil, nil if no role is found
|
// Returns nil, nil if no role is found
|
||||||
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
||||||
return GetByID[Role](tx, id).Get(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRoleWithPermissions loads a role and all its permissions
|
|
||||||
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
|
||||||
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
|
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +61,7 @@ func GetRoles(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Role],
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRole creates a new role
|
// CreateRole creates a new role
|
||||||
func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
func CreateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return errors.New("role cannot be nil")
|
return errors.New("role cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -81,6 +69,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
|
|
||||||
err := Insert(tx, role).
|
err := Insert(tx, role).
|
||||||
Returning("id").
|
Returning("id").
|
||||||
|
WithAudit(audit, nil).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.Insert")
|
return errors.Wrap(err, "db.Insert")
|
||||||
@@ -90,7 +79,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRole updates an existing role
|
// UpdateRole updates an existing role
|
||||||
func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
func UpdateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return errors.New("role cannot be nil")
|
return errors.New("role cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -100,6 +89,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
|
|
||||||
err := Update(tx, role).
|
err := Update(tx, role).
|
||||||
WherePK().
|
WherePK().
|
||||||
|
WithAudit(audit, nil).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.Update")
|
return errors.Wrap(err, "db.Update")
|
||||||
@@ -110,7 +100,7 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
|
|
||||||
// DeleteRole deletes a role (checks IsSystem protection)
|
// DeleteRole deletes a role (checks IsSystem protection)
|
||||||
// Also cleans up join table entries in role_permissions and user_roles
|
// Also cleans up join table entries in role_permissions and user_roles
|
||||||
func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
|
func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
|
||||||
if id <= 0 {
|
if id <= 0 {
|
||||||
return errors.New("id must be positive")
|
return errors.New("id must be positive")
|
||||||
}
|
}
|
||||||
@@ -146,47 +136,5 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finally delete the role
|
// Finally delete the role
|
||||||
return DeleteWithProtection[Role](ctx, tx, id)
|
return DeleteWithProtection[Role](ctx, tx, id, audit)
|
||||||
}
|
|
||||||
|
|
||||||
// AddPermissionToRole grants a permission to a role
|
|
||||||
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
|
|
||||||
if roleID <= 0 {
|
|
||||||
return errors.New("roleID must be positive")
|
|
||||||
}
|
|
||||||
if permissionID <= 0 {
|
|
||||||
return errors.New("permissionID must be positive")
|
|
||||||
}
|
|
||||||
rolePerm := &RolePermission{
|
|
||||||
RoleID: roleID,
|
|
||||||
PermissionID: permissionID,
|
|
||||||
}
|
|
||||||
err := Insert(tx, rolePerm).
|
|
||||||
ConflictNothing("role_id", "permission_id").
|
|
||||||
Exec(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.Insert")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePermissionFromRole revokes a permission from a role
|
|
||||||
func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
|
|
||||||
if roleID <= 0 {
|
|
||||||
return errors.New("roleID must be positive")
|
|
||||||
}
|
|
||||||
if permissionID <= 0 {
|
|
||||||
return errors.New("permissionID must be positive")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := DeleteItem[RolePermission](tx).
|
|
||||||
Where("role_id = ?", roleID).
|
|
||||||
Where("permission_id = ?", permissionID).
|
|
||||||
Delete(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "DeleteItem")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
99
internal/db/rolepermission.go
Normal file
99
internal/db/rolepermission.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RolePermission struct {
|
||||||
|
RoleID int `bun:",pk"`
|
||||||
|
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
||||||
|
PermissionID int `bun:",pk"`
|
||||||
|
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Role) UpdatePermissions(ctx context.Context, tx bun.Tx, newPermissionsIDs []int, audit *AuditMeta) error {
|
||||||
|
addPerms, removePerms, err := detectChangedPermissions(ctx, tx, r, newPermissionsIDs)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "detectChangedPermissions")
|
||||||
|
}
|
||||||
|
addedPerms := []string{}
|
||||||
|
removedPerms := []string{}
|
||||||
|
for _, perm := range addPerms {
|
||||||
|
rolePerm := &RolePermission{
|
||||||
|
RoleID: r.ID,
|
||||||
|
PermissionID: perm.ID,
|
||||||
|
}
|
||||||
|
err := Insert(tx, rolePerm).
|
||||||
|
ConflictNothing("role_id", "permission_id").
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.Insert")
|
||||||
|
}
|
||||||
|
addedPerms = append(addedPerms, perm.Name.String())
|
||||||
|
}
|
||||||
|
for _, perm := range removePerms {
|
||||||
|
err := DeleteItem[RolePermission](tx).
|
||||||
|
Where("role_id = ?", r.ID).
|
||||||
|
Where("permission_id = ?", perm.ID).
|
||||||
|
Delete(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "DeleteItem")
|
||||||
|
}
|
||||||
|
removedPerms = append(removedPerms, perm.Name.String())
|
||||||
|
}
|
||||||
|
// Log the permission changes
|
||||||
|
if len(addedPerms) > 0 || len(removedPerms) > 0 {
|
||||||
|
details := map[string]any{
|
||||||
|
"role_name": string(r.Name),
|
||||||
|
}
|
||||||
|
if len(addedPerms) > 0 {
|
||||||
|
details["added_permissions"] = addedPerms
|
||||||
|
}
|
||||||
|
if len(removedPerms) > 0 {
|
||||||
|
details["removed_permissions"] = removedPerms
|
||||||
|
}
|
||||||
|
info := &AuditInfo{
|
||||||
|
"roles.update_permissions",
|
||||||
|
"role",
|
||||||
|
r.ID,
|
||||||
|
details,
|
||||||
|
}
|
||||||
|
err = LogSuccess(ctx, tx, audit, info)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "LogSuccess")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectChangedPermissions(ctx context.Context, tx bun.Tx, role *Role, permissionIDs []int) ([]*Permission, []*Permission, error) {
|
||||||
|
allPermissions, err := ListAllPermissions(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "ListAllPermissions")
|
||||||
|
}
|
||||||
|
// Build map of current permissions
|
||||||
|
currentPermIDs := make(map[int]bool)
|
||||||
|
for _, perm := range role.Permissions {
|
||||||
|
currentPermIDs[perm.ID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var addedPerms []*Permission
|
||||||
|
var removedPerms []*Permission
|
||||||
|
|
||||||
|
// Determine what to add and remove
|
||||||
|
for _, perm := range allPermissions {
|
||||||
|
hasNow := currentPermIDs[perm.ID]
|
||||||
|
shouldHave := slices.Contains(permissionIDs, perm.ID)
|
||||||
|
|
||||||
|
if shouldHave && !hasNow {
|
||||||
|
addedPerms = append(addedPerms, perm)
|
||||||
|
} else if !shouldHave && hasNow {
|
||||||
|
removedPerms = append(removedPerms, perm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addedPerms, removedPerms, nil
|
||||||
|
}
|
||||||
@@ -25,15 +25,22 @@ type Season struct {
|
|||||||
Teams []Team `bun:"m2m:team_participations,join:Season=Team"`
|
Teams []Team `bun:"m2m:team_participations,join:Season=Team"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSeason returns a new season. It does not add it to the database
|
// NewSeason creats a new season
|
||||||
func NewSeason(name, version, shortname string, start time.Time) *Season {
|
func NewSeason(ctx context.Context, tx bun.Tx, name, version, shortname string,
|
||||||
|
start time.Time, audit *AuditMeta,
|
||||||
|
) (*Season, error) {
|
||||||
season := &Season{
|
season := &Season{
|
||||||
Name: name,
|
Name: name,
|
||||||
ShortName: strings.ToUpper(shortname),
|
ShortName: strings.ToUpper(shortname),
|
||||||
StartDate: start.Truncate(time.Hour * 24),
|
StartDate: start.Truncate(time.Hour * 24),
|
||||||
SlapVersion: version,
|
SlapVersion: version,
|
||||||
}
|
}
|
||||||
return season
|
err := Insert(tx, season).
|
||||||
|
WithAudit(audit, nil).Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithMessage(err, "db.Insert")
|
||||||
|
}
|
||||||
|
return season, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
|
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
|
||||||
@@ -54,7 +61,9 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update updates the season struct. It does not insert to the database
|
// Update updates the season struct. It does not insert to the database
|
||||||
func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time.Time) {
|
func (s *Season) Update(ctx context.Context, tx bun.Tx, version string,
|
||||||
|
start, end, finalsStart, finalsEnd time.Time, audit *AuditMeta,
|
||||||
|
) error {
|
||||||
s.SlapVersion = version
|
s.SlapVersion = version
|
||||||
s.StartDate = start.Truncate(time.Hour * 24)
|
s.StartDate = start.Truncate(time.Hour * 24)
|
||||||
if !end.IsZero() {
|
if !end.IsZero() {
|
||||||
@@ -66,6 +75,9 @@ func (s *Season) Update(version string, start, end, finalsStart, finalsEnd time.
|
|||||||
if !finalsEnd.IsZero() {
|
if !finalsEnd.IsZero() {
|
||||||
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
|
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
|
||||||
}
|
}
|
||||||
|
return Update(tx, s).WherePK().
|
||||||
|
Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date").
|
||||||
|
WithAudit(audit, nil).Exec(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) {
|
func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) {
|
||||||
|
|||||||
118
internal/db/seasonleague.go
Normal file
118
internal/db/seasonleague.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/permissions"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SeasonLeague struct {
|
||||||
|
SeasonID int `bun:",pk"`
|
||||||
|
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
|
||||||
|
LeagueID int `bun:",pk"`
|
||||||
|
League *League `bun:"rel:belongs-to,join:league_id=id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSeasonLeague retrieves a specific season-league combination with teams
|
||||||
|
func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) {
|
||||||
|
if seasonShortName == "" {
|
||||||
|
return nil, nil, nil, errors.New("season short_name cannot be empty")
|
||||||
|
}
|
||||||
|
if leagueShortName == "" {
|
||||||
|
return nil, nil, nil, errors.New("league short_name cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the season
|
||||||
|
season, err := GetSeason(ctx, tx, seasonShortName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "GetSeason")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the league
|
||||||
|
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "GetLeague")
|
||||||
|
}
|
||||||
|
if season == nil || league == nil || !season.HasLeague(league.ID) {
|
||||||
|
return nil, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all teams participating in this season+league
|
||||||
|
var teams []*Team
|
||||||
|
err = tx.NewSelect().
|
||||||
|
Model(&teams).
|
||||||
|
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
|
||||||
|
Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID).
|
||||||
|
Order("t.name ASC").
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "tx.Select teams")
|
||||||
|
}
|
||||||
|
|
||||||
|
return season, league, teams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string, audit *AuditMeta) error {
|
||||||
|
season, err := GetSeason(ctx, tx, seasonShortName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "GetSeason")
|
||||||
|
}
|
||||||
|
if season == nil {
|
||||||
|
return errors.New("season not found")
|
||||||
|
}
|
||||||
|
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "GetLeague")
|
||||||
|
}
|
||||||
|
if league == nil {
|
||||||
|
return errors.New("league not found")
|
||||||
|
}
|
||||||
|
if season.HasLeague(league.ID) {
|
||||||
|
return errors.New("league already added to season")
|
||||||
|
}
|
||||||
|
seasonLeague := &SeasonLeague{
|
||||||
|
SeasonID: season.ID,
|
||||||
|
LeagueID: league.ID,
|
||||||
|
}
|
||||||
|
info := &AuditInfo{
|
||||||
|
string(permissions.SeasonsAddLeague),
|
||||||
|
"season",
|
||||||
|
season.ID,
|
||||||
|
map[string]any{"league_id": league.ID},
|
||||||
|
}
|
||||||
|
err = Insert(tx, seasonLeague).WithAudit(audit, info).Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.Insert")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName string, audit *AuditMeta) error {
|
||||||
|
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "GetLeague")
|
||||||
|
}
|
||||||
|
if league == nil {
|
||||||
|
return errors.New("league not found")
|
||||||
|
}
|
||||||
|
if !s.HasLeague(league.ID) {
|
||||||
|
return errors.New("league not in season")
|
||||||
|
}
|
||||||
|
info := &AuditInfo{
|
||||||
|
string(permissions.SeasonsRemoveLeague),
|
||||||
|
"season",
|
||||||
|
s.ID,
|
||||||
|
map[string]any{"league_id": league.ID},
|
||||||
|
}
|
||||||
|
err = DeleteItem[SeasonLeague](tx).
|
||||||
|
Where("season_id = ?", s.ID).
|
||||||
|
Where("league_id = ?", league.ID).
|
||||||
|
WithAudit(audit, info).
|
||||||
|
Delete(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.DeleteItem")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
55
internal/db/setup.go
Normal file
55
internal/db/setup.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"github.com/uptrace/bun/dialect/pgdialect"
|
||||||
|
"github.com/uptrace/bun/driver/pgdriver"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
*bun.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
return db.DB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) RegisterModels() []any {
|
||||||
|
models := []any{
|
||||||
|
(*RolePermission)(nil),
|
||||||
|
(*UserRole)(nil),
|
||||||
|
(*SeasonLeague)(nil),
|
||||||
|
(*TeamParticipation)(nil),
|
||||||
|
(*User)(nil),
|
||||||
|
(*DiscordToken)(nil),
|
||||||
|
(*Season)(nil),
|
||||||
|
(*League)(nil),
|
||||||
|
(*Team)(nil),
|
||||||
|
(*Role)(nil),
|
||||||
|
(*Permission)(nil),
|
||||||
|
(*AuditLog)(nil),
|
||||||
|
}
|
||||||
|
db.RegisterModel(models...)
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDB(cfg *Config) *DB {
|
||||||
|
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
|
||||||
|
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
|
||||||
|
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
||||||
|
|
||||||
|
sqldb.SetMaxOpenConns(25)
|
||||||
|
sqldb.SetMaxIdleConns(10)
|
||||||
|
sqldb.SetConnMaxLifetime(5 * time.Minute)
|
||||||
|
sqldb.SetConnMaxIdleTime(5 * time.Minute)
|
||||||
|
|
||||||
|
db := &DB{
|
||||||
|
bun.NewDB(sqldb, pgdialect.New()),
|
||||||
|
}
|
||||||
|
db.RegisterModels()
|
||||||
|
return db
|
||||||
|
}
|
||||||
@@ -19,13 +19,19 @@ type Team struct {
|
|||||||
Leagues []League `bun:"m2m:team_participations,join:Team=League"`
|
Leagues []League `bun:"m2m:team_participations,join:Team=League"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TeamParticipation struct {
|
func NewTeam(ctx context.Context, tx bun.Tx, name, shortName, altShortName, color string, audit *AuditMeta) (*Team, error) {
|
||||||
SeasonID int `bun:",pk,unique:season_team"`
|
team := &Team{
|
||||||
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
|
Name: name,
|
||||||
LeagueID int `bun:",pk"`
|
ShortName: shortName,
|
||||||
League *League `bun:"rel:belongs-to,join:league_id=id"`
|
AltShortName: altShortName,
|
||||||
TeamID int `bun:",pk,unique:season_team"`
|
Color: color,
|
||||||
Team *Team `bun:"rel:belongs-to,join:team_id=id"`
|
}
|
||||||
|
err := Insert(tx, team).
|
||||||
|
WithAudit(audit, nil).Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "db.Insert")
|
||||||
|
}
|
||||||
|
return team, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) {
|
func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) {
|
||||||
@@ -45,14 +51,23 @@ func GetTeam(ctx context.Context, tx bun.Tx, id int) (*Team, error) {
|
|||||||
return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx)
|
return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortname, altshortname string) (bool, error) {
|
func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortName, altShortName string) (bool, error) {
|
||||||
// Check if this combination of short_name and alt_short_name exists
|
// Check if this combination of short_name and alt_short_name exists
|
||||||
count, err := tx.NewSelect().
|
count, err := tx.NewSelect().
|
||||||
Model((*Team)(nil)).
|
Model((*Team)(nil)).
|
||||||
Where("short_name = ? AND alt_short_name = ?", shortname, altshortname).
|
Where("short_name = ? AND alt_short_name = ?", shortName, altShortName).
|
||||||
Count(ctx)
|
Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "tx.Select")
|
return false, errors.Wrap(err, "tx.Select")
|
||||||
}
|
}
|
||||||
return count == 0, nil
|
return count == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Team) InSeason(seasonID int) bool {
|
||||||
|
for _, season := range t.Seasons {
|
||||||
|
if season.ID == seasonID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
67
internal/db/teamparticipation.go
Normal file
67
internal/db/teamparticipation.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TeamParticipation struct {
|
||||||
|
SeasonID int `bun:",pk,unique:season_team"`
|
||||||
|
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
|
||||||
|
LeagueID int `bun:",pk"`
|
||||||
|
League *League `bun:"rel:belongs-to,join:league_id=id"`
|
||||||
|
TeamID int `bun:",pk,unique:season_team"`
|
||||||
|
Team *Team `bun:"rel:belongs-to,join:team_id=id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTeamParticipation(ctx context.Context, tx bun.Tx,
|
||||||
|
seasonShortName, leagueShortName string, teamID int, audit *AuditMeta,
|
||||||
|
) (*Team, *Season, *League, error) {
|
||||||
|
season, err := GetSeason(ctx, tx, seasonShortName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "GetSeason")
|
||||||
|
}
|
||||||
|
if season == nil {
|
||||||
|
return nil, nil, nil, errors.New("season not found")
|
||||||
|
}
|
||||||
|
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "GetLeague")
|
||||||
|
}
|
||||||
|
if league == nil {
|
||||||
|
return nil, nil, nil, errors.New("league not found")
|
||||||
|
}
|
||||||
|
if !season.HasLeague(league.ID) {
|
||||||
|
return nil, nil, nil, errors.New("league is not assigned to the season")
|
||||||
|
}
|
||||||
|
team, err := GetTeam(ctx, tx, teamID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "GetTeam")
|
||||||
|
}
|
||||||
|
if team == nil {
|
||||||
|
return nil, nil, nil, errors.New("team not found")
|
||||||
|
}
|
||||||
|
if team.InSeason(season.ID) {
|
||||||
|
return nil, nil, nil, errors.New("team already in season")
|
||||||
|
}
|
||||||
|
participation := &TeamParticipation{
|
||||||
|
SeasonID: season.ID,
|
||||||
|
LeagueID: league.ID,
|
||||||
|
TeamID: team.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
info := &AuditInfo{
|
||||||
|
"teams.join_season",
|
||||||
|
"team",
|
||||||
|
teamID,
|
||||||
|
map[string]any{"season_id": season.ID, "league_id": league.ID},
|
||||||
|
}
|
||||||
|
err = Insert(tx, participation).
|
||||||
|
WithAudit(audit, info).Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, errors.Wrap(err, "db.Insert")
|
||||||
|
}
|
||||||
|
return team, season, league, nil
|
||||||
|
}
|
||||||
@@ -22,16 +22,15 @@ var timeout = 15 * time.Second
|
|||||||
|
|
||||||
// WithReadTx executes a read-only transaction with automatic rollback
|
// WithReadTx executes a read-only transaction with automatic rollback
|
||||||
// Returns true if successful, false if error was thrown to client
|
// Returns true if successful, false if error was thrown to client
|
||||||
func WithReadTx(
|
func (db *DB) WithReadTx(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
conn *bun.DB,
|
|
||||||
fn TxFunc,
|
fn TxFunc,
|
||||||
) bool {
|
) bool {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ok, err := withTx(ctx, conn, fn, false)
|
ok, err := db.withTx(ctx, fn, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throw.InternalServiceError(s, w, r, "Database error", err)
|
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||||
}
|
}
|
||||||
@@ -41,31 +40,29 @@ func WithReadTx(
|
|||||||
// WithTxFailSilently executes a transaction with automatic rollback
|
// WithTxFailSilently executes a transaction with automatic rollback
|
||||||
// Returns true if successful, false if error occured.
|
// Returns true if successful, false if error occured.
|
||||||
// Does not throw any errors to the client.
|
// Does not throw any errors to the client.
|
||||||
func WithTxFailSilently(
|
func (db *DB) WithTxFailSilently(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
conn *bun.DB,
|
|
||||||
fn TxFuncSilent,
|
fn TxFuncSilent,
|
||||||
) error {
|
) error {
|
||||||
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
err := fn(ctx, tx)
|
err := fn(ctx, tx)
|
||||||
return err == nil, err
|
return err == nil, err
|
||||||
}
|
}
|
||||||
_, err := withTx(ctx, conn, fnc, true)
|
_, err := db.withTx(ctx, fnc, true)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithWriteTx executes a write transaction with automatic rollback on error
|
// WithWriteTx executes a write transaction with automatic rollback on error
|
||||||
// Commits only if fn returns nil. Returns true if successful.
|
// Commits only if fn returns nil. Returns true if successful.
|
||||||
func WithWriteTx(
|
func (db *DB) WithWriteTx(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
conn *bun.DB,
|
|
||||||
fn TxFunc,
|
fn TxFunc,
|
||||||
) bool {
|
) bool {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ok, err := withTx(ctx, conn, fn, true)
|
ok, err := db.withTx(ctx, fn, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throw.InternalServiceError(s, w, r, "Database error", err)
|
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||||
}
|
}
|
||||||
@@ -74,16 +71,15 @@ func WithWriteTx(
|
|||||||
|
|
||||||
// WithNotifyTx executes a transaction with notification-based error handling
|
// WithNotifyTx executes a transaction with notification-based error handling
|
||||||
// Uses notifyInternalServiceError instead of throwInternalServiceError
|
// Uses notifyInternalServiceError instead of throwInternalServiceError
|
||||||
func WithNotifyTx(
|
func (db *DB) WithNotifyTx(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
conn *bun.DB,
|
|
||||||
fn TxFunc,
|
fn TxFunc,
|
||||||
) bool {
|
) bool {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ok, err := withTx(ctx, conn, fn, true)
|
ok, err := db.withTx(ctx, fn, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
notify.InternalServiceError(s, w, r, "Database error", err)
|
notify.InternalServiceError(s, w, r, "Database error", err)
|
||||||
}
|
}
|
||||||
@@ -91,13 +87,12 @@ func WithNotifyTx(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// withTx executes a transaction with automatic rollback on error
|
// withTx executes a transaction with automatic rollback on error
|
||||||
func withTx(
|
func (db *DB) withTx(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
conn *bun.DB,
|
|
||||||
fn TxFunc,
|
fn TxFunc,
|
||||||
write bool,
|
write bool,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "conn.BeginTx")
|
return false, errors.Wrap(err, "conn.BeginTx")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,19 +2,18 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
type updater[T any] struct {
|
type updater[T any] struct {
|
||||||
tx bun.Tx
|
tx bun.Tx
|
||||||
q *bun.UpdateQuery
|
q *bun.UpdateQuery
|
||||||
model *T
|
model *T
|
||||||
columns []string
|
columns []string
|
||||||
auditCallback AuditCallback
|
audit *AuditMeta
|
||||||
auditRequest *http.Request
|
auditInfo *AuditInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update creates an updater for a model
|
// Update creates an updater for a model
|
||||||
@@ -69,11 +68,10 @@ func (u *updater[T]) Set(query string, args ...any) *updater[T] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithAudit enables audit logging for this update operation
|
// WithAudit enables audit logging for this update operation
|
||||||
// The callback will be invoked after successful update with auto-generated audit info
|
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
|
||||||
// If the callback returns an error, the transaction will be rolled back
|
func (u *updater[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *updater[T] {
|
||||||
func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater[T] {
|
u.audit = meta
|
||||||
u.auditRequest = r
|
u.auditInfo = info
|
||||||
u.auditCallback = callback
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +80,7 @@ func (u *updater[T]) WithAudit(r *http.Request, callback AuditCallback) *updater
|
|||||||
func (u *updater[T]) Exec(ctx context.Context) error {
|
func (u *updater[T]) Exec(ctx context.Context) error {
|
||||||
// Build audit details BEFORE update (captures changed fields)
|
// Build audit details BEFORE update (captures changed fields)
|
||||||
var details map[string]any
|
var details map[string]any
|
||||||
if u.auditCallback != nil && len(u.columns) > 0 {
|
if u.audit != nil && len(u.columns) > 0 {
|
||||||
details = extractChangedFields(u.model, u.columns)
|
details = extractChangedFields(u.model, u.columns)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,21 +91,22 @@ func (u *updater[T]) Exec(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle audit logging if enabled
|
// Handle audit logging if enabled
|
||||||
if u.auditCallback != nil && u.auditRequest != nil {
|
if u.audit != nil {
|
||||||
tableName := extractTableName[T]()
|
if u.auditInfo == nil {
|
||||||
resourceType := extractResourceType(tableName)
|
tableName := extractTableName[T]()
|
||||||
action := buildAction(resourceType, "update")
|
resourceType := extractResourceType(tableName)
|
||||||
|
action := buildAction(resourceType, "update")
|
||||||
|
|
||||||
info := &AuditInfo{
|
u.auditInfo = &AuditInfo{
|
||||||
Action: action,
|
Action: action,
|
||||||
ResourceType: resourceType,
|
ResourceType: resourceType,
|
||||||
ResourceID: extractPrimaryKey(u.model),
|
ResourceID: extractPrimaryKey(u.model),
|
||||||
Details: details, // Changed fields only
|
Details: details, // Changed fields only
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
err = LogSuccess(ctx, u.tx, u.audit, u.auditInfo)
|
||||||
// Call audit callback - if it fails, return error to trigger rollback
|
if err != nil {
|
||||||
if err := u.auditCallback(ctx, u.tx, info, u.auditRequest); err != nil {
|
return errors.Wrap(err, "LogSuccess")
|
||||||
return errors.Wrap(err, "audit.callback")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
var CurrentUser hwsauth.ContextLoader[*User]
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
bun.BaseModel `bun:"table:users,alias:u"`
|
bun.BaseModel `bun:"table:users,alias:u"`
|
||||||
|
|
||||||
@@ -29,8 +27,10 @@ func (u *User) GetID() int {
|
|||||||
return u.ID
|
return u.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var CurrentUser hwsauth.ContextLoader[*User]
|
||||||
|
|
||||||
// CreateUser creates a new user with the given username and password
|
// CreateUser creates a new user with the given username and password
|
||||||
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) {
|
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User, audit *AuditMeta) (*User, error) {
|
||||||
if discorduser == nil {
|
if discorduser == nil {
|
||||||
return nil, errors.New("user cannot be nil")
|
return nil, errors.New("user cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -39,8 +39,10 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
|
|||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
DiscordID: discorduser.ID,
|
DiscordID: discorduser.ID,
|
||||||
}
|
}
|
||||||
|
audit.u = user
|
||||||
|
|
||||||
err := Insert(tx, user).
|
err := Insert(tx, user).
|
||||||
|
WithAudit(audit, nil).
|
||||||
Returning("id").
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/permissions"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -16,7 +17,7 @@ type UserRole struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AssignRole grants a role to a user
|
// AssignRole grants a role to a user
|
||||||
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
|
||||||
if userID <= 0 {
|
if userID <= 0 {
|
||||||
return errors.New("userID must be positive")
|
return errors.New("userID must be positive")
|
||||||
}
|
}
|
||||||
@@ -28,8 +29,20 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|||||||
UserID: userID,
|
UserID: userID,
|
||||||
RoleID: roleID,
|
RoleID: roleID,
|
||||||
}
|
}
|
||||||
|
details := map[string]any{
|
||||||
|
"action": "grant",
|
||||||
|
"role_id": roleID,
|
||||||
|
}
|
||||||
|
info := &AuditInfo{
|
||||||
|
string(permissions.UsersManageRoles),
|
||||||
|
"user",
|
||||||
|
userID,
|
||||||
|
details,
|
||||||
|
}
|
||||||
err := Insert(tx, userRole).
|
err := Insert(tx, userRole).
|
||||||
ConflictNothing("user_id", "role_id").Exec(ctx)
|
ConflictNothing("user_id", "role_id").
|
||||||
|
WithAudit(audit, info).
|
||||||
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.Insert")
|
return errors.Wrap(err, "db.Insert")
|
||||||
}
|
}
|
||||||
@@ -38,7 +51,7 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRole removes a role from a user
|
// RevokeRole removes a role from a user
|
||||||
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
|
||||||
if userID <= 0 {
|
if userID <= 0 {
|
||||||
return errors.New("userID must be positive")
|
return errors.New("userID must be positive")
|
||||||
}
|
}
|
||||||
@@ -46,9 +59,20 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|||||||
return errors.New("roleID must be positive")
|
return errors.New("roleID must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
details := map[string]any{
|
||||||
|
"action": "revoke",
|
||||||
|
"role_id": roleID,
|
||||||
|
}
|
||||||
|
info := &AuditInfo{
|
||||||
|
string(permissions.UsersManageRoles),
|
||||||
|
"user",
|
||||||
|
userID,
|
||||||
|
details,
|
||||||
|
}
|
||||||
err := DeleteItem[UserRole](tx).
|
err := DeleteItem[UserRole](tx).
|
||||||
Where("user_id = ?", userID).
|
Where("user_id = ?", userID).
|
||||||
Where("role_id = ?", roleID).
|
Where("role_id = ?", roleID).
|
||||||
|
WithAudit(audit, info).
|
||||||
Delete(ctx)
|
Delete(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "DeleteItem")
|
return errors.Wrap(err, "DeleteItem")
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
var embeddedFiles embed.FS
|
var embeddedFiles embed.FS
|
||||||
|
|
||||||
// GetEmbeddedFS gets the embedded files
|
// GetEmbeddedFS gets the embedded files
|
||||||
func GetEmbeddedFS() (fs.FS, error) {
|
func GetEmbeddedFS() (*fs.FS, error) {
|
||||||
subFS, err := fs.Sub(embeddedFiles, "web")
|
subFS, err := fs.Sub(embeddedFiles, "web")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fs.Sub")
|
return nil, errors.Wrap(err, "fs.Sub")
|
||||||
}
|
}
|
||||||
return subFS, nil
|
return &subFS, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request)
|
// AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request)
|
||||||
func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminAuditLogsPage(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromQuery(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
var actions []string
|
var actions []string
|
||||||
var resourceTypes []string
|
var resourceTypes []string
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Get filters from query
|
// Get filters from query
|
||||||
@@ -73,10 +73,10 @@ func AdminAuditLogsPage(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX)
|
// AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX)
|
||||||
func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminAuditLogsList(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
var actions []string
|
var actions []string
|
||||||
var resourceTypes []string
|
var resourceTypes []string
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Get filters from form
|
// Get filters from form
|
||||||
@@ -129,16 +129,16 @@ func AdminAuditLogsList(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates
|
// AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates
|
||||||
func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminAuditLogsFilter(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var logs *db.List[db.AuditLog]
|
var logs *db.List[db.AuditLog]
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Get filters from form
|
// Get filters from form
|
||||||
@@ -164,7 +164,7 @@ func AdminAuditLogsFilter(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminAuditLogDetail shows details for a single audit log entry
|
// AdminAuditLogDetail shows details for a single audit log entry
|
||||||
func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminAuditLogDetail(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get ID from path
|
// Get ID from path
|
||||||
idStr := r.PathValue("id")
|
idStr := r.PathValue("id")
|
||||||
@@ -181,7 +181,7 @@ func AdminAuditLogDetail(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
|
|
||||||
var log *db.AuditLog
|
var log *db.AuditLog
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
log, err = db.GetAuditLogByID(ctx, tx, id)
|
log, err = db.GetAuditLogByID(ctx, tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AdminDashboard renders the full admin dashboard page (defaults to users section)
|
// AdminDashboard renders the full admin dashboard page (defaults to users section)
|
||||||
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminDashboard(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var users *db.List[db.User]
|
var users *db.List[db.User]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
users, err = db.GetUsersWithRoles(ctx, tx, nil)
|
users, err = db.GetUsersWithRoles(ctx, tx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AdminPermissionsPage renders the full admin dashboard page with permissions section
|
|
||||||
func AdminPermissionsPage(s *hws.Server, conn *bun.DB) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// TODO: Load permissions from database
|
|
||||||
renderSafely(adminview.PermissionsPage(), s, r, w)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdminPermissionsList shows all permissions (HTMX content replacement)
|
|
||||||
func AdminPermissionsList(s *hws.Server, conn *bun.DB) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// TODO: Load permissions from database
|
|
||||||
renderSafely(adminview.PermissionsList(), s, r, w)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/rbac"
|
"git.haelnorr.com/h/oslstats/internal/rbac"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
@@ -16,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AdminPreviewRoleStart starts preview mode for a specific role
|
// AdminPreviewRoleStart starts preview mode for a specific role
|
||||||
func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http.Handler {
|
func AdminPreviewRoleStart(s *hws.Server, conn *db.DB, ssl bool) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get role ID from URL
|
// Get role ID from URL
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
@@ -28,7 +27,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http
|
|||||||
|
|
||||||
// Verify role exists and is not admin
|
// Verify role exists and is not admin
|
||||||
var role *db.Role
|
var role *db.Role
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
role, err = db.GetRoleByID(ctx, tx, roleID)
|
role, err = db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,7 +48,7 @@ func AdminPreviewRoleStart(s *hws.Server, conn *bun.DB, cfg *config.Config) http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set preview role cookie
|
// Set preview role cookie
|
||||||
rbac.SetPreviewRoleCookie(w, roleID, cfg.HWSAuth.SSL)
|
rbac.SetPreviewRoleCookie(w, roleID, ssl)
|
||||||
|
|
||||||
// Redirect to home page
|
// Redirect to home page
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
|
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -19,20 +17,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AdminRoles renders the full admin dashboard page with roles section
|
// AdminRoles renders the full admin dashboard page with roles section
|
||||||
func AdminRoles(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminRoles(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var pageOpts *db.PageOpts
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if r.Method == "GET" {
|
if !ok {
|
||||||
pageOpts = pageOptsFromQuery(s, w, r)
|
|
||||||
} else {
|
|
||||||
pageOpts = pageOptsFromForm(s, w, r)
|
|
||||||
}
|
|
||||||
if pageOpts == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var rolesList *db.List[db.Role]
|
var rolesList *db.List[db.Role]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
|
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -59,7 +52,7 @@ func AdminRoleCreateForm(s *hws.Server) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRoleCreate creates a new role
|
// AdminRoleCreate creates a new role
|
||||||
func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler {
|
func AdminRoleCreate(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -74,14 +67,14 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var rolesList *db.List[db.Role]
|
var rolesList *db.List[db.Role]
|
||||||
var newRole *db.Role
|
var newRole *db.Role
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
newRole = &db.Role{
|
newRole = &db.Role{
|
||||||
Name: roles.Role(name),
|
Name: roles.Role(name),
|
||||||
DisplayName: displayName,
|
DisplayName: displayName,
|
||||||
@@ -90,9 +83,9 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := db.Insert(tx, newRole).WithAudit(r, audit.Callback()).Exec(ctx)
|
err := db.CreateRole(ctx, tx, newRole, db.NewAudit(r, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
return false, errors.Wrap(err, "db.CreateRole")
|
||||||
}
|
}
|
||||||
|
|
||||||
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
|
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
|
||||||
@@ -110,7 +103,7 @@ func AdminRoleCreate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRoleManage shows the role management modal with details and actions
|
// AdminRoleManage shows the role management modal with details and actions
|
||||||
func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
roleID, err := strconv.Atoi(roleIDStr)
|
roleID, err := strconv.Atoi(roleIDStr)
|
||||||
@@ -120,7 +113,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var role *db.Role
|
var role *db.Role
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
role, err = db.GetRoleByID(ctx, tx, roleID)
|
role, err = db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -139,7 +132,7 @@ func AdminRoleManage(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRoleDeleteConfirm shows the delete confirmation dialog
|
// AdminRoleDeleteConfirm shows the delete confirmation dialog
|
||||||
func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminRoleDeleteConfirm(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
roleID, err := strconv.Atoi(roleIDStr)
|
roleID, err := strconv.Atoi(roleIDStr)
|
||||||
@@ -149,7 +142,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var role *db.Role
|
var role *db.Role
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
role, err = db.GetRoleByID(ctx, tx, roleID)
|
role, err = db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -168,7 +161,7 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRoleDelete deletes a role
|
// AdminRoleDelete deletes a role
|
||||||
func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler {
|
func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
roleID, err := strconv.Atoi(roleIDStr)
|
roleID, err := strconv.Atoi(roleIDStr)
|
||||||
@@ -177,13 +170,13 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var rolesList *db.List[db.Role]
|
var rolesList *db.List[db.Role]
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
// First check if role exists and get its details
|
// First check if role exists and get its details
|
||||||
role, err := db.GetRoleByID(ctx, tx, roleID)
|
role, err := db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -199,9 +192,9 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete the role with audit logging
|
// Delete the role with audit logging
|
||||||
err = db.DeleteByID[db.Role](tx, roleID).WithAudit(r, audit.Callback()).Delete(ctx)
|
err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.DeleteByID")
|
return false, errors.Wrap(err, "db.DeleteRole")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload roles
|
// Reload roles
|
||||||
@@ -220,7 +213,7 @@ func AdminRoleDelete(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.H
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRolePermissionsModal shows the permissions management modal for a role
|
// AdminRolePermissionsModal shows the permissions management modal for a role
|
||||||
func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
roleID, err := strconv.Atoi(roleIDStr)
|
roleID, err := strconv.Atoi(roleIDStr)
|
||||||
@@ -234,12 +227,12 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
var groupedPerms []adminview.PermissionsByResource
|
var groupedPerms []adminview.PermissionsByResource
|
||||||
var rolePermIDs map[int]bool
|
var rolePermIDs map[int]bool
|
||||||
|
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
// Load role with permissions
|
// Load role with permissions
|
||||||
var err error
|
var err error
|
||||||
role, err = db.GetRoleWithPermissions(ctx, tx, roleID)
|
role, err = db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetRoleWithPermissions")
|
return false, errors.Wrap(err, "db.GetRoleByID")
|
||||||
}
|
}
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return false, errors.New("role not found")
|
return false, errors.New("role not found")
|
||||||
@@ -283,7 +276,7 @@ func AdminRolePermissionsModal(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AdminRolePermissionsUpdate updates the permissions for a role
|
// AdminRolePermissionsUpdate updates the permissions for a role
|
||||||
func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Logger) http.Handler {
|
func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
roleIDStr := r.PathValue("id")
|
roleIDStr := r.PathValue("id")
|
||||||
roleID, err := strconv.Atoi(roleIDStr)
|
roleID, err := strconv.Atoi(roleIDStr)
|
||||||
@@ -291,7 +284,6 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log
|
|||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := db.CurrentUser(r.Context())
|
|
||||||
|
|
||||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -304,80 +296,24 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *bun.DB, audit *auditlog.Log
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedPermIDs := make(map[int]bool)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
for _, id := range permissionIDs {
|
if !ok {
|
||||||
selectedPermIDs[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
|
||||||
if pageOpts == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var rolesList *db.List[db.Role]
|
var rolesList *db.List[db.Role]
|
||||||
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
// Get role with current permissions
|
role, err := db.GetRoleByID(ctx, tx, roleID)
|
||||||
role, err := db.GetRoleWithPermissions(ctx, tx, roleID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetRoleWithPermissions")
|
return false, errors.Wrap(err, "db.GetRoleByID")
|
||||||
}
|
}
|
||||||
if role == nil {
|
if role == nil {
|
||||||
throw.NotFound(s, w, r, "Role not found")
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil))
|
||||||
// Get all permissions to know what exists
|
|
||||||
allPermissions, err := db.ListAllPermissions(ctx, tx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.ListAllPermissions")
|
return false, errors.Wrap(err, "role.UpdatePermissions")
|
||||||
}
|
|
||||||
|
|
||||||
// Build map of current permissions
|
|
||||||
currentPermIDs := make(map[int]bool)
|
|
||||||
for _, perm := range role.Permissions {
|
|
||||||
currentPermIDs[perm.ID] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var addedPerms []string
|
|
||||||
var removedPerms []string
|
|
||||||
|
|
||||||
// Determine what to add and remove
|
|
||||||
for _, perm := range allPermissions {
|
|
||||||
hasNow := currentPermIDs[perm.ID]
|
|
||||||
shouldHave := selectedPermIDs[perm.ID]
|
|
||||||
|
|
||||||
if shouldHave && !hasNow {
|
|
||||||
// Add permission
|
|
||||||
err := db.AddPermissionToRole(ctx, tx, roleID, perm.ID)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.AddPermissionToRole")
|
|
||||||
}
|
|
||||||
addedPerms = append(addedPerms, string(perm.Name))
|
|
||||||
} else if !shouldHave && hasNow {
|
|
||||||
// Remove permission
|
|
||||||
err := db.RemovePermissionFromRole(ctx, tx, roleID, perm.ID)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.RemovePermissionFromRole")
|
|
||||||
}
|
|
||||||
removedPerms = append(removedPerms, string(perm.Name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log the permission changes
|
|
||||||
if len(addedPerms) > 0 || len(removedPerms) > 0 {
|
|
||||||
details := map[string]any{
|
|
||||||
"role_name": string(role.Name),
|
|
||||||
}
|
|
||||||
if len(addedPerms) > 0 {
|
|
||||||
details["added_permissions"] = addedPerms
|
|
||||||
}
|
|
||||||
if len(removedPerms) > 0 {
|
|
||||||
details["removed_permissions"] = removedPerms
|
|
||||||
}
|
|
||||||
err = audit.LogSuccess(ctx, tx, user, "update", "role_permissions", roleID, details, r)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "audit.LogSuccess")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload roles
|
// Reload roles
|
||||||
|
|||||||
@@ -12,20 +12,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AdminUsersPage renders the full admin dashboard page with users section
|
// AdminUsersPage renders the full admin dashboard page with users section
|
||||||
func AdminUsersPage(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminUsersPage(s *hws.Server, conn *db.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var pageOpts *db.PageOpts
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if r.Method == "GET" {
|
if !ok {
|
||||||
pageOpts = pageOptsFromQuery(s, w, r)
|
|
||||||
} else {
|
|
||||||
pageOpts = pageOptsFromForm(s, w, r)
|
|
||||||
}
|
|
||||||
if pageOpts == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var users *db.List[db.User]
|
var users *db.List[db.User]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
users, err = db.GetUsersWithRoles(ctx, tx, pageOpts)
|
users, err = db.GetUsersWithRoles(ctx, tx, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Grant admin role
|
// Grant admin role
|
||||||
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID)
|
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.AssignRole")
|
return errors.Wrap(err, "db.AssignRole")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
func Callback(
|
func Callback(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
store *store.Store,
|
store *store.Store,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
@@ -70,7 +70,7 @@ func Callback(
|
|||||||
switch data {
|
switch data {
|
||||||
case "login":
|
case "login":
|
||||||
var redirect func()
|
var redirect func()
|
||||||
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throw.InternalServiceError(s, w, r, "OAuth login failed", err)
|
throw.InternalServiceError(s, w, r, "OAuth login failed", err)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
// Returns 200 OK if unique, 409 Conflict if not unique
|
// Returns 200 OK if unique, 409 Conflict if not unique
|
||||||
func IsUnique(
|
func IsUnique(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
model any,
|
model any,
|
||||||
field string,
|
field string,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
@@ -31,7 +31,7 @@ func IsUnique(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
unique := false
|
unique := false
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
unique, err = db.IsUnique(ctx, tx, model, field, value)
|
unique, err = db.IsUnique(ctx, tx, model, field, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.IsUnique")
|
return false, errors.Wrap(err, "db.IsUnique")
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ import (
|
|||||||
|
|
||||||
func LeaguesList(
|
func LeaguesList(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var leagues []*db.League
|
var leagues []*db.League
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
leagues, err = db.GetLeagues(ctx, tx)
|
leagues, err = db.GetLeagues(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
@@ -18,20 +17,15 @@ import (
|
|||||||
|
|
||||||
func NewLeague(
|
func NewLeague(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "GET" {
|
renderSafely(leaguesview.NewPage(), s, r, w)
|
||||||
renderSafely(leaguesview.NewPage(), s, r, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLeagueSubmit(
|
func NewLeagueSubmit(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
@@ -53,7 +47,7 @@ func NewLeagueSubmit(
|
|||||||
nameUnique := false
|
nameUnique := false
|
||||||
shortNameUnique := false
|
shortNameUnique := false
|
||||||
var league *db.League
|
var league *db.League
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name)
|
nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -66,14 +60,9 @@ func NewLeagueSubmit(
|
|||||||
if !nameUnique || !shortNameUnique {
|
if !nameUnique || !shortNameUnique {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
league = &db.League{
|
league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAudit(r, nil))
|
||||||
Name: name,
|
|
||||||
ShortName: shortname,
|
|
||||||
Description: description,
|
|
||||||
}
|
|
||||||
err = db.Insert(tx, league).WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
return false, errors.Wrap(err, "db.NewLeague")
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}); !ok {
|
}); !ok {
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/cookies"
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
func Login(
|
func Login(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
st *store.Store,
|
st *store.Store,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
func Logout(
|
func Logout(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
@@ -27,7 +27,7 @@ func Logout(
|
|||||||
w.Header().Set("HX-Redirect", "/")
|
w.Header().Set("HX-Redirect", "/")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
|
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata.
|
|
||||||
// It renders a Bad Request error page on fail
|
|
||||||
// PageOpts will be nil on fail
|
|
||||||
func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
|
||||||
getter, ok := validation.ParseFormOrError(s, w, r)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return getPageOpts(s, w, r, getter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail
|
|
||||||
// PageOpts will be nil on fail
|
|
||||||
func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
|
||||||
return getPageOpts(s, w, r, validation.NewQueryGetter(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts {
|
|
||||||
page := g.Int("page").Optional().Min(1).Value
|
|
||||||
perPage := g.Int("per_page").Optional().Min(1).Max(100).Value
|
|
||||||
order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value
|
|
||||||
orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value
|
|
||||||
valid := g.ValidateAndError(s, w, r)
|
|
||||||
if !valid {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
pageOpts := &db.PageOpts{
|
|
||||||
Page: page,
|
|
||||||
PerPage: perPage,
|
|
||||||
Order: bun.Order(order),
|
|
||||||
OrderBy: orderBy,
|
|
||||||
}
|
|
||||||
return pageOpts
|
|
||||||
}
|
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
func Register(
|
func Register(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
store *store.Store,
|
store *store.Store,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
@@ -55,7 +55,7 @@ func Register(
|
|||||||
username := r.FormValue("username")
|
username := r.FormValue("username")
|
||||||
unique := false
|
unique := false
|
||||||
var user *db.User
|
var user *db.User
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
|
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.IsUsernameUnique")
|
return false, errors.Wrap(err, "db.IsUsernameUnique")
|
||||||
@@ -63,7 +63,7 @@ func Register(
|
|||||||
if !unique {
|
if !unique {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser)
|
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAudit(r, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.CreateUser")
|
return false, errors.Wrap(err, "db.CreateUser")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,14 +14,14 @@ import (
|
|||||||
|
|
||||||
func SeasonPage(
|
func SeasonPage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
var leaguesWithTeams []db.LeagueWithTeams
|
var leaguesWithTeams []db.LeagueWithTeams
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
@@ -19,13 +18,13 @@ import (
|
|||||||
|
|
||||||
func SeasonEditPage(
|
func SeasonEditPage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
var allLeagues []*db.League
|
var allLeagues []*db.League
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,8 +48,7 @@ func SeasonEditPage(
|
|||||||
|
|
||||||
func SeasonEditSubmit(
|
func SeasonEditSubmit(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
@@ -77,7 +75,7 @@ func SeasonEditSubmit(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -86,12 +84,9 @@ func SeasonEditSubmit(
|
|||||||
if season == nil {
|
if season == nil {
|
||||||
return false, errors.New("season does not exist")
|
return false, errors.New("season does not exist")
|
||||||
}
|
}
|
||||||
season.Update(version, start, end, finalsStart, finalsEnd)
|
err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil))
|
||||||
err = db.Update(tx, season).WherePK().
|
|
||||||
Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date").
|
|
||||||
WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.Update")
|
return false, errors.Wrap(err, "season.Update")
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}); !ok {
|
}); !ok {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
@@ -16,8 +15,7 @@ import (
|
|||||||
|
|
||||||
func SeasonLeagueAddTeam(
|
func SeasonLeagueAddTeam(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
@@ -36,73 +34,12 @@ func SeasonLeagueAddTeam(
|
|||||||
var league *db.League
|
var league *db.League
|
||||||
var team *db.Team
|
var team *db.Team
|
||||||
|
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil))
|
||||||
// Get season
|
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetSeason")
|
return false, errors.Wrap(err, "db.NewTeamParticipation")
|
||||||
}
|
}
|
||||||
if season == nil {
|
|
||||||
notify.Warn(s, w, r, "Not Found", "Season not found.", nil)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get league
|
|
||||||
league, err = db.GetLeague(ctx, tx, leagueStr)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.GetLeague")
|
|
||||||
}
|
|
||||||
if league == nil {
|
|
||||||
notify.Warn(s, w, r, "Not Found", "League not found.", nil)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !season.HasLeague(league.ID) {
|
|
||||||
notify.Warn(s, w, r, "Invalid League", "This league is not associated with this season.", nil)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get team
|
|
||||||
team, err = db.GetTeam(ctx, tx, teamID)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.GetTeam")
|
|
||||||
}
|
|
||||||
if team == nil {
|
|
||||||
notify.Warn(s, w, r, "Not Found", "Team not found.", nil)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if team is already in this season (in any league)
|
|
||||||
var tpCount int
|
|
||||||
tpCount, err = tx.NewSelect().
|
|
||||||
Model((*db.TeamParticipation)(nil)).
|
|
||||||
Where("season_id = ? AND team_id = ?", season.ID, team.ID).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
if tpCount > 0 {
|
|
||||||
notify.Warn(s, w, r, "Already In Season", fmt.Sprintf(
|
|
||||||
"Team '%s' is already participating in this season.",
|
|
||||||
team.Name,
|
|
||||||
), nil)
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add team to league
|
|
||||||
participation := &db.TeamParticipation{
|
|
||||||
SeasonID: season.ID,
|
|
||||||
LeagueID: league.ID,
|
|
||||||
TeamID: team.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = db.Insert(tx, participation).WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}); !ok {
|
}); !ok {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
func SeasonLeaguePage(
|
func SeasonLeaguePage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
@@ -25,7 +25,7 @@ func SeasonLeaguePage(
|
|||||||
var teams []*db.Team
|
var teams []*db.Team
|
||||||
var allTeams []*db.Team
|
var allTeams []*db.Team
|
||||||
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
|
season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/seasonsview"
|
"git.haelnorr.com/h/oslstats/internal/view/seasonsview"
|
||||||
@@ -16,8 +15,7 @@ import (
|
|||||||
|
|
||||||
func SeasonAddLeague(
|
func SeasonAddLeague(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
@@ -25,32 +23,10 @@ func SeasonAddLeague(
|
|||||||
|
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
var allLeagues []*db.League
|
var allLeagues []*db.League
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil))
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetSeason")
|
return false, errors.Wrap(err, "db.NewSeasonLeague")
|
||||||
}
|
|
||||||
if season == nil {
|
|
||||||
return false, errors.New("season not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
league, err := db.GetLeague(ctx, tx, leagueStr)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.GetLeague")
|
|
||||||
}
|
|
||||||
if league == nil {
|
|
||||||
return false, errors.New("league not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the many-to-many relationship
|
|
||||||
seasonLeague := &db.SeasonLeague{
|
|
||||||
SeasonID: season.ID,
|
|
||||||
LeagueID: league.ID,
|
|
||||||
}
|
|
||||||
err = db.Insert(tx, seasonLeague).WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload season with updated leagues
|
// Reload season with updated leagues
|
||||||
@@ -76,8 +52,7 @@ func SeasonAddLeague(
|
|||||||
|
|
||||||
func SeasonRemoveLeague(
|
func SeasonRemoveLeague(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
@@ -85,7 +60,7 @@ func SeasonRemoveLeague(
|
|||||||
|
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
var allLeagues []*db.League
|
var allLeagues []*db.League
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,22 +69,9 @@ func SeasonRemoveLeague(
|
|||||||
if season == nil {
|
if season == nil {
|
||||||
return false, errors.New("season not found")
|
return false, errors.New("season not found")
|
||||||
}
|
}
|
||||||
|
err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil))
|
||||||
league, err := db.GetLeague(ctx, tx, leagueStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetLeague")
|
return false, errors.Wrap(err, "season.RemoveLeague")
|
||||||
}
|
|
||||||
if league == nil {
|
|
||||||
return false, errors.New("league not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the many-to-many relationship
|
|
||||||
err = db.DeleteItem[db.SeasonLeague](tx).
|
|
||||||
Where("season_id = ? AND league_id = ?", season.ID, league.ID).
|
|
||||||
WithAudit(r, audit.Callback()).
|
|
||||||
Delete(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.DeleteItem")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload season with updated leagues
|
// Reload season with updated leagues
|
||||||
|
|||||||
@@ -11,18 +11,18 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SeasonsPage renders the full page with the seasons list, for use with GET requests
|
// SeasonsPage renders the season list. On GET it returns the full page, otherwise it just returns the list
|
||||||
func SeasonsPage(
|
func SeasonsPage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromQuery(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var seasons *db.List[db.Season]
|
var seasons *db.List[db.Season]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -32,31 +32,10 @@ func SeasonsPage(
|
|||||||
}); !ok {
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
renderSafely(seasonsview.ListPage(seasons), s, r, w)
|
if r.Method == "GET" {
|
||||||
})
|
renderSafely(seasonsview.ListPage(seasons), s, r, w)
|
||||||
}
|
} else {
|
||||||
|
renderSafely(seasonsview.SeasonsList(seasons), s, r, w)
|
||||||
// SeasonsList renders just the seasons list, for use with POST requests and HTMX
|
}
|
||||||
func SeasonsList(
|
|
||||||
s *hws.Server,
|
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
|
||||||
if pageOpts == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var seasons *db.List[db.Season]
|
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
|
||||||
var err error
|
|
||||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.ListSeasons")
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}); !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
renderSafely(seasonsview.SeasonsList(seasons), s, r, w)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
@@ -18,20 +17,15 @@ import (
|
|||||||
|
|
||||||
func NewSeason(
|
func NewSeason(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "GET" {
|
renderSafely(seasonsview.NewPage(), s, r, w)
|
||||||
renderSafely(seasonsview.NewPage(), s, r, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSeasonSubmit(
|
func NewSeasonSubmit(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
@@ -58,7 +52,7 @@ func NewSeasonSubmit(
|
|||||||
nameUnique := false
|
nameUnique := false
|
||||||
shortNameUnique := false
|
shortNameUnique := false
|
||||||
var season *db.Season
|
var season *db.Season
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
|
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,10 +65,9 @@ func NewSeasonSubmit(
|
|||||||
if !nameUnique || !shortNameUnique {
|
if !nameUnique || !shortNameUnique {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
season = db.NewSeason(name, version, shortname, start)
|
season, err = db.NewSeason(ctx, tx, name, version, shortname, start, db.NewAudit(r, nil))
|
||||||
err = db.Insert(tx, season).WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
return false, errors.Wrap(err, "db.NewSeason")
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}); !ok {
|
}); !ok {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
// and also validates that they are different from each other
|
// and also validates that they are different from each other
|
||||||
func IsTeamShortNamesUnique(
|
func IsTeamShortNamesUnique(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
getter, err := validation.ParseForm(r)
|
getter, err := validation.ParseForm(r)
|
||||||
@@ -38,7 +38,7 @@ func IsTeamShortNamesUnique(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var isUnique bool
|
var isUnique bool
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
isUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName)
|
isUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.TeamShortNamesUnique")
|
return false, errors.Wrap(err, "db.TeamShortNamesUnique")
|
||||||
|
|||||||
@@ -14,15 +14,15 @@ import (
|
|||||||
// TeamsPage renders the full page with the teams list, for use with GET requests
|
// TeamsPage renders the full page with the teams list, for use with GET requests
|
||||||
func TeamsPage(
|
func TeamsPage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromQuery(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var teams *db.List[db.Team]
|
var teams *db.List[db.Team]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
teams, err = db.ListTeams(ctx, tx, pageOpts)
|
teams, err = db.ListTeams(ctx, tx, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -39,15 +39,15 @@ func TeamsPage(
|
|||||||
// TeamsList renders just the teams list, for use with POST requests and HTMX
|
// TeamsList renders just the teams list, for use with POST requests and HTMX
|
||||||
func TeamsList(
|
func TeamsList(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageOpts := pageOptsFromForm(s, w, r)
|
pageOpts, ok := db.GetPageOpts(s, w, r)
|
||||||
if pageOpts == nil {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var teams *db.List[db.Team]
|
var teams *db.List[db.Team]
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
teams, err = db.ListTeams(ctx, tx, pageOpts)
|
teams, err = db.ListTeams(ctx, tx, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
@@ -18,7 +17,6 @@ import (
|
|||||||
|
|
||||||
func NewTeamPage(
|
func NewTeamPage(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
renderSafely(teamsview.NewPage(), s, r, w)
|
renderSafely(teamsview.NewPage(), s, r, w)
|
||||||
@@ -27,8 +25,7 @@ func NewTeamPage(
|
|||||||
|
|
||||||
func NewTeamSubmit(
|
func NewTeamSubmit(
|
||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
@@ -38,10 +35,10 @@ func NewTeamSubmit(
|
|||||||
name := getter.String("name").
|
name := getter.String("name").
|
||||||
TrimSpace().Required().
|
TrimSpace().Required().
|
||||||
MaxLength(25).MinLength(3).Value
|
MaxLength(25).MinLength(3).Value
|
||||||
shortname := getter.String("short_name").
|
shortName := getter.String("short_name").
|
||||||
TrimSpace().Required().
|
TrimSpace().Required().
|
||||||
MaxLength(3).MinLength(3).Value
|
MaxLength(3).MinLength(3).Value
|
||||||
altShortname := getter.String("alt_short_name").
|
altShortName := getter.String("alt_short_name").
|
||||||
TrimSpace().Required().
|
TrimSpace().Required().
|
||||||
MaxLength(3).MinLength(3).Value
|
MaxLength(3).MinLength(3).Value
|
||||||
color := getter.String("color").
|
color := getter.String("color").
|
||||||
@@ -51,22 +48,21 @@ func NewTeamSubmit(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check that short names are different
|
// Check that short names are different
|
||||||
if shortname == altShortname {
|
if shortName == altShortName {
|
||||||
notify.Warn(s, w, r, "Invalid Short Names", "Short name and alternative short name must be different.", nil)
|
notify.Warn(s, w, r, "Invalid Short Names", "Short name and alternative short name must be different.", nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nameUnique := false
|
nameUnique := false
|
||||||
shortNameComboUnique := false
|
shortNameComboUnique := false
|
||||||
var team *db.Team
|
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
|
||||||
var err error
|
var err error
|
||||||
nameUnique, err = db.IsUnique(ctx, tx, (*db.Team)(nil), "name", name)
|
nameUnique, err = db.IsUnique(ctx, tx, (*db.Team)(nil), "name", name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.IsTeamNameUnique")
|
return false, errors.Wrap(err, "db.IsTeamNameUnique")
|
||||||
}
|
}
|
||||||
|
|
||||||
shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortname, altShortname)
|
shortNameComboUnique, err = db.TeamShortNamesUnique(ctx, tx, shortName, altShortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.TeamShortNamesUnique")
|
return false, errors.Wrap(err, "db.TeamShortNamesUnique")
|
||||||
}
|
}
|
||||||
@@ -74,15 +70,9 @@ func NewTeamSubmit(
|
|||||||
if !nameUnique || !shortNameComboUnique {
|
if !nameUnique || !shortNameComboUnique {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
team = &db.Team{
|
_, err = db.NewTeam(ctx, tx, name, shortName, altShortName, color, db.NewAudit(r, nil))
|
||||||
Name: name,
|
|
||||||
ShortName: shortname,
|
|
||||||
AltShortName: altShortname,
|
|
||||||
Color: color,
|
|
||||||
}
|
|
||||||
err = db.Insert(tx, team).WithAudit(r, audit.Callback()).Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.Insert")
|
return false, errors.Wrap(err, "db.NewTeam")
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}); !ok {
|
}); !ok {
|
||||||
|
|||||||
@@ -39,12 +39,12 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
|
|||||||
|
|
||||||
var roles_ []*db.Role
|
var roles_ []*db.Role
|
||||||
var perms []*db.Permission
|
var perms []*db.Permission
|
||||||
if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error {
|
if err := c.conn.WithTxFailSilently(r.Context(), func(ctx context.Context, tx bun.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if previewRole != nil {
|
if previewRole != nil {
|
||||||
// In preview mode: use the preview role instead of user's roles
|
// In preview mode: use the preview role instead of user's roles
|
||||||
role, err := db.GetRoleWithPermissions(ctx, tx, previewRole.ID)
|
role, err := db.GetRoleByID(ctx, tx, previewRole.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.GetRoleWithPermissions")
|
return errors.Wrap(err, "db.GetRoleWithPermissions")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,11 +13,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Checker struct {
|
type Checker struct {
|
||||||
conn *bun.DB
|
conn *db.DB
|
||||||
s *hws.Server
|
s *hws.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) {
|
func NewChecker(conn *db.DB, s *hws.Server) (*Checker, error) {
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, errors.New("conn cannot be nil")
|
return nil, errors.New("conn cannot be nil")
|
||||||
}
|
}
|
||||||
@@ -56,7 +56,7 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi
|
|||||||
|
|
||||||
// Not in preview mode: fallback to database for actual user permissions
|
// Not in preview mode: fallback to database for actual user permissions
|
||||||
var has bool
|
var has bool
|
||||||
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
|
if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
has, err = user.HasPermission(ctx, tx, permission)
|
has, err = user.HasPermission(ctx, tx, permission)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,7 +94,7 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol
|
|||||||
|
|
||||||
// Not in preview mode: fallback to database for actual user roles
|
// Not in preview mode: fallback to database for actual user roles
|
||||||
var has bool
|
var has bool
|
||||||
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
|
if err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
has, err = user.HasRole(ctx, tx, role)
|
has, err = user.HasRole(ctx, tx, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
// LoadPreviewRoleMiddleware loads the preview role from the session cookie if present
|
// LoadPreviewRoleMiddleware loads the preview role from the session cookie if present
|
||||||
// and adds it to the request context. This must run after authentication but before
|
// and adds it to the request context. This must run after authentication but before
|
||||||
// the RBAC cache middleware.
|
// the RBAC cache middleware.
|
||||||
func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) http.Handler {
|
func LoadPreviewRoleMiddleware(s *hws.Server, conn *db.DB) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Check if there's a preview role in the cookie
|
// Check if there's a preview role in the cookie
|
||||||
@@ -26,10 +26,25 @@ func LoadPreviewRoleMiddleware(s *hws.Server, conn *bun.DB) func(http.Handler) h
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user := db.CurrentUser(r.Context())
|
||||||
|
if user == nil {
|
||||||
|
// User not logged in,
|
||||||
|
ClearPreviewRoleCookie(w)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Load the preview role from the database
|
// Load the preview role from the database
|
||||||
var previewRole *db.Role
|
var previewRole *db.Role
|
||||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
isAdmin, err := user.IsAdmin(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "user.IsAdmin")
|
||||||
|
}
|
||||||
|
if !isAdmin {
|
||||||
|
ClearPreviewRoleCookie(w)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
previewRole, err = db.GetRoleByID(ctx, tx, roleID)
|
previewRole, err = db.GetRoleByID(ctx, tx, roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "db.GetRoleByID")
|
return false, errors.Wrap(err, "db.GetRoleByID")
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ func (c *Checker) RequireActualAdmin(s *hws.Server) func(http.Handler) http.Hand
|
|||||||
|
|
||||||
// Check user's ACTUAL role in database, bypassing preview mode
|
// Check user's ACTUAL role in database, bypassing preview mode
|
||||||
var hasAdmin bool
|
var hasAdmin bool
|
||||||
if ok := db.WithReadTx(s, w, r, c.conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
if ok := c.conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
var err error
|
var err error
|
||||||
hasAdmin, err = user.HasRole(ctx, tx, roles.Admin)
|
hasAdmin, err = user.HasRole(ctx, tx, roles.Admin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
func setupAuth(
|
func setupAuth(
|
||||||
cfg *hwsauth.Config,
|
cfg *hwsauth.Config,
|
||||||
logger *hlog.Logger,
|
logger *hlog.Logger,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
server *hws.Server,
|
server *hws.Server,
|
||||||
ignoredPaths []string,
|
ignoredPaths []string,
|
||||||
) (*hwsauth.Authenticator[*db.User, bun.Tx], error) {
|
) (*hwsauth.Authenticator[*db.User, bun.Tx], error) {
|
||||||
@@ -30,7 +30,7 @@ func setupAuth(
|
|||||||
beginTx,
|
beginTx,
|
||||||
logger,
|
logger,
|
||||||
handlers.ErrorPage,
|
handlers.ErrorPage,
|
||||||
conn.DB,
|
conn.DB.DB,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -27,7 +27,7 @@ func addMiddleware(
|
|||||||
perms *rbac.Checker,
|
perms *rbac.Checker,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
store *store.Store,
|
store *store.Store,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
) error {
|
) error {
|
||||||
err := server.AddMiddleware(
|
err := server.AddMiddleware(
|
||||||
auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
|
auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
@@ -22,12 +21,11 @@ func addRoutes(
|
|||||||
s *hws.Server,
|
s *hws.Server,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
conn *bun.DB,
|
conn *db.DB,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
store *store.Store,
|
store *store.Store,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
perms *rbac.Checker,
|
perms *rbac.Checker,
|
||||||
audit *auditlog.Logger,
|
|
||||||
) error {
|
) error {
|
||||||
// Create the routes
|
// Create the routes
|
||||||
baseRoutes := []hws.Route{
|
baseRoutes := []hws.Route{
|
||||||
@@ -69,23 +67,18 @@ func addRoutes(
|
|||||||
seasonRoutes := []hws.Route{
|
seasonRoutes := []hws.Route{
|
||||||
{
|
{
|
||||||
Path: "/seasons",
|
Path: "/seasons",
|
||||||
Method: hws.MethodGET,
|
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||||
Handler: handlers.SeasonsPage(s, conn),
|
Handler: handlers.SeasonsPage(s, conn),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Path: "/seasons",
|
|
||||||
Method: hws.MethodPOST,
|
|
||||||
Handler: handlers.SeasonsList(s, conn),
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Path: "/seasons/new",
|
Path: "/seasons/new",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s, conn)),
|
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeason(s)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons/new",
|
Path: "/seasons/new",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.SeasonsCreate)(handlers.NewSeasonSubmit(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}",
|
Path: "/seasons/{season_short_name}",
|
||||||
@@ -100,7 +93,7 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}/edit",
|
Path: "/seasons/{season_short_name}/edit",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.SeasonsUpdate)(handlers.SeasonEditSubmit(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
|
Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
|
||||||
@@ -110,17 +103,17 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}/leagues/add/{league_short_name}",
|
Path: "/seasons/{season_short_name}/leagues/add/{league_short_name}",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.SeasonsAddLeague)(handlers.SeasonAddLeague(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
|
Path: "/seasons/{season_short_name}/leagues/{league_short_name}",
|
||||||
Method: hws.MethodDELETE,
|
Method: hws.MethodDELETE,
|
||||||
Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.SeasonsRemoveLeague)(handlers.SeasonRemoveLeague(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons/{season_short_name}/leagues/{league_short_name}/teams/add",
|
Path: "/seasons/{season_short_name}/leagues/{league_short_name}/teams/add",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.TeamsAddToLeague)(handlers.SeasonLeagueAddTeam(s, conn)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,12 +126,12 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/leagues/new",
|
Path: "/leagues/new",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s, conn)),
|
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeague(s)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/leagues/new",
|
Path: "/leagues/new",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.LeaguesCreate)(handlers.NewLeagueSubmit(s, conn)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,12 +149,12 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/teams/new",
|
Path: "/teams/new",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s, conn)),
|
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamPage(s)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/teams/new",
|
Path: "/teams/new",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn, audit)),
|
Handler: perms.RequirePermission(s, permissions.TeamsCreate)(handlers.NewTeamSubmit(s, conn)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,21 +227,11 @@ func addRoutes(
|
|||||||
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminRoles(s, conn)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminRoles(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Path: "/admin/permissions",
|
|
||||||
Method: hws.MethodGET,
|
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsPage(s, conn)),
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Path: "/admin/audit",
|
Path: "/admin/audit",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminAuditLogsPage(s, conn)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminAuditLogsPage(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Path: "/admin/permissions",
|
|
||||||
Method: hws.MethodPOST,
|
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminPermissionsList(s, conn)),
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Path: "/admin/audit",
|
Path: "/admin/audit",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
@@ -263,7 +246,7 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/admin/roles/create",
|
Path: "/admin/roles/create",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn, audit)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminRoleCreate(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/admin/roles/{id}/manage",
|
Path: "/admin/roles/{id}/manage",
|
||||||
@@ -273,7 +256,7 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/admin/roles/{id}",
|
Path: "/admin/roles/{id}",
|
||||||
Method: hws.MethodDELETE,
|
Method: hws.MethodDELETE,
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn, audit)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminRoleDelete(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/admin/roles/{id}/delete-confirm",
|
Path: "/admin/roles/{id}/delete-confirm",
|
||||||
@@ -288,12 +271,12 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/admin/roles/{id}/permissions",
|
Path: "/admin/roles/{id}/permissions",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn, audit)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminRolePermissionsUpdate(s, conn)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/admin/roles/{id}/preview-start",
|
Path: "/admin/roles/{id}/preview-start",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg)),
|
Handler: perms.RequireAdmin(s)(handlers.AdminPreviewRoleStart(s, conn, cfg.HWSAuth.SSL)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/admin/roles/preview-stop",
|
Path: "/admin/roles/preview-stop",
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
package main
|
// Package server provides setup utilities for the HTTP server
|
||||||
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@@ -7,21 +8,20 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/auditlog"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||||
"git.haelnorr.com/h/oslstats/internal/rbac"
|
"git.haelnorr.com/h/oslstats/internal/rbac"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupHTTPServer(
|
func Setup(
|
||||||
staticFS *fs.FS,
|
staticFS *fs.FS,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
logger *hlog.Logger,
|
logger *hlog.Logger,
|
||||||
bun *bun.DB,
|
conn *db.DB,
|
||||||
store *store.Store,
|
store *store.Store,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
) (server *hws.Server, err error) {
|
) (server *hws.Server, err error) {
|
||||||
@@ -41,7 +41,7 @@ func setupHTTPServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auth, err := setupAuth(
|
auth, err := setupAuth(
|
||||||
cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
|
cfg.HWSAuth, logger, conn, httpServer, ignoredPaths)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "setupAuth")
|
return nil, errors.Wrap(err, "setupAuth")
|
||||||
}
|
}
|
||||||
@@ -62,20 +62,17 @@ func setupHTTPServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize permissions checker
|
// Initialize permissions checker
|
||||||
perms, err := rbac.NewChecker(bun, httpServer)
|
perms, err := rbac.NewChecker(conn, httpServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "rbac.NewChecker")
|
return nil, errors.Wrap(err, "rbac.NewChecker")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize audit logger
|
err = addRoutes(httpServer, &fs, cfg, conn, auth, store, discordAPI, perms)
|
||||||
audit := auditlog.NewLogger(bun)
|
|
||||||
|
|
||||||
err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI, perms, audit)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "addRoutes")
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, bun)
|
err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "addMiddleware")
|
return nil, errors.Wrap(err, "addMiddleware")
|
||||||
}
|
}
|
||||||
@@ -136,38 +136,38 @@ templ profileDropdown(user *db.User, items []ProfileItem) {
|
|||||||
x-on:click.away="isActive = false"
|
x-on:click.away="isActive = false"
|
||||||
x-on:keydown.escape.window="isActive = false"
|
x-on:keydown.escape.window="isActive = false"
|
||||||
>
|
>
|
||||||
<!-- Preview Mode Stop Buttons -->
|
<!-- Preview Mode Stop Buttons -->
|
||||||
if previewRole != nil {
|
if previewRole != nil {
|
||||||
<div class="p-2 bg-yellow/10 border-b border-yellow/30 space-y-2">
|
<div class="p-2 bg-yellow/10 border-b border-yellow/30 space-y-2">
|
||||||
<p class="text-xs text-yellow/80 px-2 font-semibold">
|
<p class="text-xs text-yellow/80 px-2 font-semibold">
|
||||||
Viewing as: { previewRole.DisplayName }
|
Viewing as: { previewRole.DisplayName }
|
||||||
</p>
|
</p>
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
<form method="POST" action="/admin/roles/preview-stop?stay=true" class="flex-1">
|
<form method="POST" action="/admin/roles/preview-stop?stay=true" class="flex-1">
|
||||||
<button
|
<button
|
||||||
type="submit"
|
type="submit"
|
||||||
class="w-full rounded-lg px-3 py-2
|
class="w-full rounded-lg px-3 py-2
|
||||||
text-sm text-mantle bg-green font-semibold hover:bg-teal hover:cursor-pointer transition"
|
text-sm text-mantle bg-green font-semibold hover:bg-teal hover:cursor-pointer transition"
|
||||||
role="menuitem"
|
role="menuitem"
|
||||||
@click="isActive=false"
|
@click="isActive=false"
|
||||||
>
|
>
|
||||||
Stop Preview
|
Stop Preview
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
<form method="POST" action="/admin/roles/preview-stop" class="flex-1">
|
<form method="POST" action="/admin/roles/preview-stop" class="flex-1">
|
||||||
<button
|
<button
|
||||||
type="submit"
|
type="submit"
|
||||||
class="w-full rounded-lg px-3 py-2
|
class="w-full rounded-lg px-3 py-2
|
||||||
text-sm text-mantle bg-blue font-semibold hover:bg-sky hover:cursor-pointer transition"
|
text-sm text-mantle bg-blue font-semibold hover:bg-sky hover:cursor-pointer transition"
|
||||||
role="menuitem"
|
role="menuitem"
|
||||||
@click="isActive=false"
|
@click="isActive=false"
|
||||||
>
|
>
|
||||||
Return to Admin
|
Return to Admin
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
}
|
||||||
}
|
|
||||||
<!-- Profile links -->
|
<!-- Profile links -->
|
||||||
<div class="p-2">
|
<div class="p-2">
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
|
|||||||
Reference in New Issue
Block a user