From d2b1a252eaf2372f84d26097e967d1641693adc9 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 3 Feb 2026 21:37:06 +1100 Subject: [PATCH] rbac system first stage --- cmd/oslstats/httpserver.go | 15 +- cmd/oslstats/main.go | 43 ++- cmd/oslstats/middleware.go | 4 +- cmd/oslstats/migrate.go | 17 +- .../20260202231414_add_rbac_system.go | 272 ++++++++++++++++++ cmd/oslstats/migrations/migrations.go | 1 + cmd/oslstats/routes.go | 20 ++ go.mod | 4 +- go.sum | 8 +- internal/auditlog/logger.go | 166 +++++++++++ internal/config/config.go | 17 +- internal/db/auditlog.go | 143 +++++++++ internal/db/discordtokens.go | 15 +- internal/db/paginate.go | 2 +- internal/db/permission.go | 154 ++++++++++ internal/db/role.go | 204 +++++++++++++ internal/db/season.go | 2 +- internal/db/user.go | 120 +++++++- internal/db/userrole.go | 114 ++++++++ internal/handlers/admin_dashboard.go | 36 +++ internal/handlers/admin_users.go | 44 +++ internal/handlers/auth_helpers.go | 56 ++++ internal/handlers/callback.go | 9 + internal/handlers/errors.go | 5 +- internal/handlers/page_opt_helpers.go | 68 +++++ internal/handlers/seasons.go | 63 +--- internal/permissions/constants.go | 23 ++ internal/rbac/cache_middleware.go | 95 ++++++ internal/rbac/checker.go | 108 +++++++ internal/rbac/config.go | 22 ++ internal/rbac/ezconf.go | 41 +++ internal/rbac/protection_middleware.go | 101 +++++++ internal/roles/constants.go | 13 + internal/view/component/admin/user_list.templ | 6 + internal/view/component/nav/navbarright.templ | 23 +- internal/view/layout/admin_dashboard.templ | 8 + internal/view/page/admin_dashboard.templ | 10 + pkg/contexts/keys.go | 28 +- 38 files changed, 1966 insertions(+), 114 deletions(-) create mode 100644 cmd/oslstats/migrations/20260202231414_add_rbac_system.go create mode 100644 internal/auditlog/logger.go create mode 100644 internal/db/auditlog.go create mode 100644 internal/db/permission.go create mode 100644 internal/db/role.go create mode 100644 internal/db/userrole.go create mode 100644 internal/handlers/admin_dashboard.go create mode 100644 internal/handlers/admin_users.go create mode 100644 internal/handlers/auth_helpers.go create mode 100644 internal/handlers/page_opt_helpers.go create mode 100644 internal/permissions/constants.go create mode 100644 internal/rbac/cache_middleware.go create mode 100644 internal/rbac/checker.go create mode 100644 internal/rbac/config.go create mode 100644 internal/rbac/ezconf.go create mode 100644 internal/rbac/protection_middleware.go create mode 100644 internal/roles/constants.go create mode 100644 internal/view/component/admin/user_list.templ create mode 100644 internal/view/layout/admin_dashboard.templ create mode 100644 internal/view/page/admin_dashboard.templ diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index d1dbf24..7bfea84 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -9,9 +9,11 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" + "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/store" ) @@ -58,12 +60,21 @@ func setupHTTPServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI) + // Initialize permissions checker + perms, err := rbac.NewChecker(bun, server) + if err != nil { + return nil, errors.Wrap(err, "rbac.NewChecker") + } + + // Initialize audit logger + audit := auditlog.NewLogger(bun) + + err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI, perms, audit) if err != nil { return nil, errors.Wrap(err, "addRoutes") } - err = addMiddleware(httpServer, auth, cfg) + err = addMiddleware(httpServer, auth, cfg, perms) if err != nil { return nil, errors.Wrap(err, "addMiddleware") } diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index 1cdb585..cdef7fa 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -24,6 +24,21 @@ func main() { fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config")) os.Exit(1) } + // Handle utility flags + if flags.EnvDoc || flags.ShowEnv { + if err = loader.PrintEnvVarsStdout(flags.ShowEnv); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to print env doc")) + } + return + } + + if flags.GenEnv != "" { + if err = loader.GenerateEnvFile(flags.GenEnv, true); err != nil { + fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to generate env file")) + } + return + } + // // Setup the logger logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout) if err != nil { @@ -31,17 +46,6 @@ func main() { os.Exit(1) } - // Handle utility flags - if flags.EnvDoc || flags.ShowEnv { - loader.PrintEnvVarsStdout(flags.ShowEnv) - return - } - - if flags.GenEnv != "" { - loader.GenerateEnvFile(flags.GenEnv, true) - return - } - // Handle migration file creation (doesn't need DB connection) if flags.MigrateCreate != "" { if err := createMigration(flags.MigrateCreate); err != nil { @@ -55,24 +59,17 @@ func main() { flags.MigrateStatus || flags.MigrateDryRun || flags.ResetDB { - // Setup database connection - conn, close, err := setupBun(ctx, cfg) - if err != nil { - logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "setupBun"))).Msg("Error setting up database") - } - defer close() - // Route to appropriate command if flags.MigrateUp { - err = runMigrations(ctx, conn, cfg, "up") + err = runMigrations(ctx, cfg, "up") } else if flags.MigrateRollback { - err = runMigrations(ctx, conn, cfg, "rollback") + err = runMigrations(ctx, cfg, "rollback") } else if flags.MigrateStatus { - err = runMigrations(ctx, conn, cfg, "status") + err = runMigrations(ctx, cfg, "status") } else if flags.MigrateDryRun { - err = runMigrations(ctx, conn, cfg, "dry-run") + err = runMigrations(ctx, cfg, "dry-run") } else if flags.ResetDB { - err = resetDatabase(ctx, conn) + err = resetDatabase(ctx, cfg) } if err != nil { diff --git a/cmd/oslstats/middleware.go b/cmd/oslstats/middleware.go index 8dd5cd8..81d92c6 100644 --- a/cmd/oslstats/middleware.go +++ b/cmd/oslstats/middleware.go @@ -9,6 +9,7 @@ import ( "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/pkg/contexts" "github.com/pkg/errors" @@ -19,10 +20,11 @@ func addMiddleware( server *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], cfg *config.Config, + perms *rbac.Checker, ) error { - err := server.AddMiddleware( auth.Authenticate(), + perms.LoadPermissionsMiddleware(), devMode(cfg), ) if err != nil { diff --git a/cmd/oslstats/migrate.go b/cmd/oslstats/migrate.go index 3cc05ce..81c8a3d 100644 --- a/cmd/oslstats/migrate.go +++ b/cmd/oslstats/migrate.go @@ -20,7 +20,13 @@ import ( ) // runMigrations executes database migrations -func runMigrations(ctx context.Context, conn *bun.DB, cfg *config.Config, command string) error { +func runMigrations(ctx context.Context, cfg *config.Config, command string) error { + conn, close, err := setupBun(ctx, cfg) + if err != nil { + return errors.Wrap(err, "setupBun") + } + defer close() + migrator := migrate.NewMigrator(conn, migrations.Migrations) // Initialize migration tables @@ -306,7 +312,7 @@ func init() { ` // Write file - if err := os.WriteFile(filename, []byte(template), 0644); err != nil { + if err := os.WriteFile(filename, []byte(template), 0o644); err != nil { return errors.Wrap(err, "write migration file") } @@ -319,7 +325,7 @@ func init() { } // resetDatabase drops and recreates all tables (destructive) -func resetDatabase(ctx context.Context, conn *bun.DB) error { +func resetDatabase(ctx context.Context, cfg *config.Config) error { fmt.Println("⚠️ WARNING: This will DELETE ALL DATA in the database!") fmt.Print("Type 'yes' to continue: ") @@ -334,6 +340,11 @@ func resetDatabase(ctx context.Context, conn *bun.DB) error { fmt.Println("❌ Reset cancelled") return nil } + conn, close, err := setupBun(ctx, cfg) + if err != nil { + return errors.Wrap(err, "setupBun") + } + defer close() models := []any{ (*db.User)(nil), diff --git a/cmd/oslstats/migrations/20260202231414_add_rbac_system.go b/cmd/oslstats/migrations/20260202231414_add_rbac_system.go new file mode 100644 index 0000000..72c2bf5 --- /dev/null +++ b/cmd/oslstats/migrations/20260202231414_add_rbac_system.go @@ -0,0 +1,272 @@ +package migrations + +import ( + "context" + "time" + + "git.haelnorr.com/h/oslstats/internal/db" + "github.com/uptrace/bun" +) + +func init() { + Migrations.MustRegister( + // UP migration + func(ctx context.Context, dbConn *bun.DB) error { + // Create roles table using raw SQL to avoid m2m relationship issues + // Bun tries to resolve relationships when creating tables from models + // TODO: use proper m2m table instead of raw sql + _, err := dbConn.ExecContext(ctx, ` + CREATE TABLE roles ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) UNIQUE NOT NULL, + display_name VARCHAR(100) NOT NULL, + description TEXT, + is_system BOOLEAN DEFAULT FALSE, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL + ) + `) + if err != nil { + return err + } + + // Create permissions table + _, err = dbConn.NewCreateTable(). + Model((*db.Permission)(nil)). + Exec(ctx) + if err != nil { + return err + } + + // Create indexes for permissions + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.Permission)(nil)). + Index("idx_permissions_resource"). + Column("resource"). + Exec(ctx) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.Permission)(nil)). + Index("idx_permissions_action"). + Column("action"). + Exec(ctx) + if err != nil { + return err + } + + // Create role_permissions join table (Bun doesn't auto-create m2m tables) + // TODO: use proper m2m table instead of raw sql + _, err = dbConn.ExecContext(ctx, ` + CREATE TABLE role_permissions ( + id SERIAL PRIMARY KEY, + role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE CASCADE, + permission_id INTEGER NOT NULL REFERENCES permissions(id) ON DELETE CASCADE, + created_at BIGINT NOT NULL, + UNIQUE(role_id, permission_id) + ) + `) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.ExecContext(ctx, ` + CREATE INDEX idx_role_permissions_role ON role_permissions(role_id) + `) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.ExecContext(ctx, ` + CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id) + `) + if err != nil { + return err + } + + // Create user_roles table + _, err = dbConn.NewCreateTable(). + Model((*db.UserRole)(nil)). + Exec(ctx) + if err != nil { + return err + } + + // Create indexes for user_roles + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.UserRole)(nil)). + Index("idx_user_roles_user"). + Column("user_id"). + Exec(ctx) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.UserRole)(nil)). + Index("idx_user_roles_role"). + Column("role_id"). + Exec(ctx) + if err != nil { + return err + } + + // Create audit_log table + _, err = dbConn.NewCreateTable(). + Model((*db.AuditLog)(nil)). + Exec(ctx) + if err != nil { + return err + } + + // Create indexes for audit_log + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.AuditLog)(nil)). + Index("idx_audit_log_user"). + Column("user_id"). + Exec(ctx) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.AuditLog)(nil)). + Index("idx_audit_log_action"). + Column("action"). + Exec(ctx) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.AuditLog)(nil)). + Index("idx_audit_log_resource"). + Column("resource_type", "resource_id"). + Exec(ctx) + if err != nil { + return err + } + + // TODO: why do we need this? + _, err = dbConn.NewCreateIndex(). + Model((*db.AuditLog)(nil)). + Index("idx_audit_log_created"). + Column("created_at"). + Exec(ctx) + if err != nil { + return err + } + + // Seed system roles + now := time.Now().Unix() + + adminRole := &db.Role{ + Name: "admin", + DisplayName: "Administrator", + Description: "Full system access with all permissions", + IsSystem: true, + CreatedAt: now, // TODO: this should be defaulted in table + UpdatedAt: now, // TODO: this should be defaulted in table + } + + _, err = dbConn.NewInsert(). + Model(adminRole). + Exec(ctx) + if err != nil { + return err + } + + userRole := &db.Role{ + Name: "user", + DisplayName: "User", + Description: "Standard user with basic permissions", + IsSystem: true, + CreatedAt: now, // TODO: this should be defaulted in table + UpdatedAt: now, // TODO: this should be defaulted in table + + } + + _, err = dbConn.NewInsert(). + Model(userRole). + Exec(ctx) + if err != nil { + return err + } + + // Seed system permissions + // TODO: timestamps for created should be defaulted in table + permissionsData := []*db.Permission{ + {Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now}, + {Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now}, + {Name: "seasons.update", DisplayName: "Update Seasons", Description: "Update existing seasons", Resource: "seasons", Action: "update", IsSystem: true, CreatedAt: now}, + {Name: "seasons.delete", DisplayName: "Delete Seasons", Description: "Delete seasons", Resource: "seasons", Action: "delete", IsSystem: true, CreatedAt: now}, + {Name: "users.update", DisplayName: "Update Users", Description: "Update user information", Resource: "users", Action: "update", IsSystem: true, CreatedAt: now}, + {Name: "users.ban", DisplayName: "Ban Users", Description: "Ban users from the system", Resource: "users", Action: "ban", IsSystem: true, CreatedAt: now}, + {Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now}, + } + + _, err = dbConn.NewInsert(). + Model(&permissionsData). + Exec(ctx) + if err != nil { + return err + } + + // Grant wildcard permission to admin role using Bun + // First, get the IDs + var wildcardPerm db.Permission + err = dbConn.NewSelect(). + Model(&wildcardPerm). + Where("name = ?", "*"). + Scan(ctx) + if err != nil { + return err + } + + // Insert role_permission mapping + // TODO: use proper m2m table, and default now in table settings + _, err = dbConn.ExecContext(ctx, ` + INSERT INTO role_permissions (role_id, permission_id, created_at) + VALUES ($1, $2, $3) + `, adminRole.ID, wildcardPerm.ID, now) + if err != nil { + return err + } + + return nil + }, + // DOWN migration + func(ctx context.Context, dbConn *bun.DB) error { + // Drop tables in reverse order + // Use raw SQL to avoid relationship resolution issues + // TODO: surely we can use proper bun methods? + tables := []string{ + "audit_log", + "user_roles", + "role_permissions", + "permissions", + "roles", + } + + for _, table := range tables { + _, err := dbConn.ExecContext(ctx, "DROP TABLE IF EXISTS "+table+" CASCADE") + if err != nil { + return err + } + } + + return nil + }, + ) +} diff --git a/cmd/oslstats/migrations/migrations.go b/cmd/oslstats/migrations/migrations.go index 940c08f..661fe58 100644 --- a/cmd/oslstats/migrations/migrations.go +++ b/cmd/oslstats/migrations/migrations.go @@ -1,3 +1,4 @@ +// Package migrations defines the database migrations to apply when using the migrate tags package migrations import ( diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index a37c837..b663d46 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -8,10 +8,12 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" + "git.haelnorr.com/h/oslstats/internal/auditlog" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/store" ) @@ -23,6 +25,8 @@ func addRoutes( auth *hwsauth.Authenticator[*db.User, bun.Tx], store *store.Store, discordAPI *discord.APIClient, + perms *rbac.Checker, + audit *auditlog.Logger, ) error { // Create the routes pageroutes := []hws.Route{ @@ -115,8 +119,24 @@ func addRoutes( }, } + // Admin routes + adminRoutes := []hws.Route{ + { + // TODO: on page load, redirect to /admin/users + Path: "/admin", + Method: hws.MethodGET, + Handler: perms.RequireAdmin(s)(handlers.AdminDashboard(s, conn)), + }, + { + Path: "/admin/users", + Method: hws.MethodPOST, + Handler: perms.RequireAdmin(s)(handlers.AdminUsersList(s, conn)), + }, + } + routes := append(pageroutes, htmxRoutes...) routes = append(routes, wsRoutes...) + routes = append(routes, adminRoutes...) // Register the routes with the server err := s.AddRoutes(routes...) diff --git a/go.mod b/go.mod index ec498d5..91d57dd 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.4.4 - git.haelnorr.com/h/golib/hwsauth v0.5.4 + git.haelnorr.com/h/golib/hws v0.5.0 + git.haelnorr.com/h/golib/hwsauth v0.5.5 git.haelnorr.com/h/golib/notify v0.1.0 github.com/a-h/templ v0.3.977 github.com/coder/websocket v1.8.14 diff --git a/go.sum b/go.sum index cd54517..3721fe7 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,10 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.4.4 h1:tV9UjZ4q96UlOdJKsC7b3kDV+bpQYqKVPQuaV1n3U3k= -git.haelnorr.com/h/golib/hws v0.4.4/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM= -git.haelnorr.com/h/golib/hwsauth v0.5.4 h1:nuaiVpJHHXgKVRPoQSE/v3CJHSkivViK5h3SVhEcbbM= -git.haelnorr.com/h/golib/hwsauth v0.5.4/go.mod h1:eIjRPeGycvxRWERkxCoRVMEEhHuUdiPDvjpzzZOhQ0w= +git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c= +git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM= +git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo= +git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go new file mode 100644 index 0000000..e3e806f --- /dev/null +++ b/internal/auditlog/logger.go @@ -0,0 +1,166 @@ +// Package auditlog provides a system for logging events that require permissions to the audit log +package auditlog + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "git.haelnorr.com/h/oslstats/internal/db" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type Logger struct { + conn *bun.DB +} + +func NewLogger(conn *bun.DB) *Logger { + return &Logger{conn: conn} +} + +// LogSuccess logs a successful permission-protected action +func (l *Logger) LogSuccess( + ctx context.Context, + tx bun.Tx, + user *db.User, + action string, + resourceType string, + resourceID any, // Can be int, string, or nil + details map[string]any, + r *http.Request, +) error { + return l.log(ctx, tx, user, action, resourceType, resourceID, details, "success", nil, r) +} + +// LogError logs a failed action due to an error +func (l *Logger) LogError( + ctx context.Context, + tx bun.Tx, + user *db.User, + action string, + resourceType string, + resourceID any, + err error, + r *http.Request, +) error { + errMsg := err.Error() + return l.log(ctx, tx, user, action, resourceType, resourceID, nil, "error", &errMsg, r) +} + +func (l *Logger) log( + ctx context.Context, + tx bun.Tx, + user *db.User, + action string, + resourceType string, + resourceID any, + details map[string]any, + result string, + errorMessage *string, + r *http.Request, +) error { + if user == nil { + return errors.New("user cannot be nil for audit logging") + } + + // Convert resourceID to string + var resourceIDStr *string + if resourceID != nil { + idStr := fmt.Sprintf("%v", resourceID) + resourceIDStr = &idStr + } + + // Marshal details to JSON + var detailsJSON json.RawMessage + if details != nil { + jsonBytes, err := json.Marshal(details) + if err != nil { + return errors.Wrap(err, "json.Marshal details") + } + detailsJSON = jsonBytes + } + + // Extract IP and User-Agent from request + ipAddress := r.RemoteAddr + userAgent := r.UserAgent() + + log := &db.AuditLog{ + UserID: user.ID, + Action: action, + ResourceType: resourceType, + ResourceID: resourceIDStr, + Details: detailsJSON, + IPAddress: ipAddress, + UserAgent: userAgent, + Result: result, + ErrorMessage: errorMessage, + CreatedAt: time.Now().Unix(), + } + + return db.CreateAuditLog(ctx, tx, log) +} + +// GetRecentLogs retrieves recent audit logs with pagination +// TODO: change this to user db.PageOpts +func (l *Logger) GetRecentLogs(ctx context.Context, limit, offset int) ([]*db.AuditLog, int, error) { + tx, err := l.conn.BeginTx(ctx, nil) + if err != nil { + return nil, 0, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + + logs, total, err := db.GetAuditLogs(ctx, tx, limit, offset, nil) + if err != nil { + return nil, 0, err + } + + _ = tx.Commit() // read only transaction + return logs, total, nil +} + +// GetLogsByUser retrieves audit logs for a specific user +// TODO: change this to user db.PageOpts +func (l *Logger) GetLogsByUser(ctx context.Context, userID int, limit, offset int) ([]*db.AuditLog, int, error) { + tx, err := l.conn.BeginTx(ctx, nil) + if err != nil { + return nil, 0, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + + logs, total, err := db.GetAuditLogsByUser(ctx, tx, userID, limit, offset) + if err != nil { + return nil, 0, err + } + + _ = tx.Commit() // read only transaction + return logs, total, nil +} + +// CleanupOldLogs deletes audit logs older than the specified number of days +func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error) { + if daysToKeep <= 0 { + return 0, errors.New("daysToKeep must be positive") + } + + cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix() + + tx, err := l.conn.BeginTx(ctx, nil) + if err != nil { + return 0, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + + count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime) + if err != nil { + return 0, err + } + + err = tx.Commit() + if err != nil { + return 0, errors.Wrap(err, "tx.Commit") + } + return count, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 112752a..3cd6094 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,3 +1,4 @@ +// Package config provides the environment based configuration for the program package config import ( @@ -7,6 +8,7 @@ import ( "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/pkg/oauth" "github.com/joho/godotenv" "github.com/pkg/errors" @@ -19,10 +21,11 @@ type Config struct { HLOG *hlog.Config Discord *discord.Config OAuth *oauth.Config + RBAC *rbac.Config Flags *Flags } -// Load the application configuration and get a pointer to the Config object +// GetConfig loads the application configuration and returns a pointer to the Config object // If doconly is specified, only the loader will be returned func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { err := godotenv.Load(flags.EnvFile) @@ -31,14 +34,18 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { } loader := ezconf.New() - loader.RegisterIntegrations( + err = loader.RegisterIntegrations( hlog.NewEZConfIntegration(), hws.NewEZConfIntegration(), hwsauth.NewEZConfIntegration(), db.NewEZConfIntegration(), discord.NewEZConfIntegration(), oauth.NewEZConfIntegration(), + rbac.NewEZConfIntegration(), ) + if err != nil { + return nil, nil, errors.Wrap(err, "loader.RegisterIntegrations") + } if err := loader.ParseEnvVars(); err != nil { return nil, nil, errors.Wrap(err, "loader.ParseEnvVars") } @@ -81,6 +88,11 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { return nil, nil, errors.New("OAuth Config not loaded") } + rbaccfg, ok := loader.GetConfig("rbac") + if !ok { + return nil, nil, errors.New("RBAC Config not loaded") + } + config := &Config{ DB: dbcfg.(*db.Config), HWS: hwscfg.(*hws.Config), @@ -88,6 +100,7 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { HLOG: hlogcfg.(*hlog.Config), Discord: discordcfg.(*discord.Config), OAuth: oauthcfg.(*oauth.Config), + RBAC: rbaccfg.(*rbac.Config), Flags: flags, } diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go new file mode 100644 index 0000000..f5c0d83 --- /dev/null +++ b/internal/db/auditlog.go @@ -0,0 +1,143 @@ +package db + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type AuditLog struct { + bun.BaseModel `bun:"table:audit_log,alias:al"` + + ID int `bun:"id,pk,autoincrement"` + UserID int `bun:"user_id,notnull"` + Action string `bun:"action,notnull"` + ResourceType string `bun:"resource_type,notnull"` + ResourceID *string `bun:"resource_id"` + Details json.RawMessage `bun:"details,type:jsonb"` + IPAddress string `bun:"ip_address"` + UserAgent string `bun:"user_agent"` + Result string `bun:"result,notnull"` // success, denied, error + ErrorMessage *string `bun:"error_message"` + CreatedAt int64 `bun:"created_at,notnull"` + + // Relations + User *User `bun:"rel:belongs-to,join:user_id=id"` +} + +// TODO: add AuditLogs to match list style with PageOpts + +// CreateAuditLog creates a new audit log entry +func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { + if log == nil { + return errors.New("log cannot be nil") + } + + _, err := tx.NewInsert(). + Model(log). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + + return nil +} + +type AuditLogFilters struct { + UserID *int + Action *string + ResourceType *string + Result *string +} + +// GetAuditLogs retrieves audit logs with optional filters and pagination +// TODO: change this to use db.PageOpts +func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *AuditLogFilters) ([]*AuditLog, int, error) { + query := tx.NewSelect(). + Model((*AuditLog)(nil)). + Relation("User"). + Order("created_at DESC") + + // Apply filters if provided + if filters != nil { + if filters.UserID != nil { + query = query.Where("al.user_id = ?", *filters.UserID) + } + if filters.Action != nil { + query = query.Where("al.action = ?", *filters.Action) + } + if filters.ResourceType != nil { + query = query.Where("al.resource_type = ?", *filters.ResourceType) + } + if filters.Result != nil { + query = query.Where("al.result = ?", *filters.Result) + } + } + + // Get total count + total, err := query.Count(ctx) + if err != nil { + return nil, 0, errors.Wrap(err, "query.Count") + } + + // Get paginated results + var logs []*AuditLog + err = query. + Limit(limit). + Offset(offset). + Scan(ctx, &logs) + if err != nil && err != sql.ErrNoRows { + return nil, 0, errors.Wrap(err, "query.Scan") + } + + return logs, total, nil +} + +// GetAuditLogsByUser retrieves audit logs for a specific user +// TODO: change this to use db.PageOpts +func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, limit, offset int) ([]*AuditLog, int, error) { + if userID <= 0 { + return nil, 0, errors.New("userID must be positive") + } + + filters := &AuditLogFilters{ + UserID: &userID, + } + + return GetAuditLogs(ctx, tx, limit, offset, filters) +} + +// GetAuditLogsByAction retrieves audit logs for a specific action +// TODO: change this to use db.PageOpts +func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, limit, offset int) ([]*AuditLog, int, error) { + if action == "" { + return nil, 0, errors.New("action cannot be empty") + } + + filters := &AuditLogFilters{ + Action: &action, + } + + return GetAuditLogs(ctx, tx, limit, offset, filters) +} + +// CleanupOldAuditLogs deletes audit logs older than the specified timestamp +func CleanupOldAuditLogs(ctx context.Context, tx bun.Tx, olderThan int64) (int, error) { + result, err := tx.NewDelete(). + Model((*AuditLog)(nil)). + Where("created_at < ?", olderThan). + Exec(ctx) + if err != nil { + return 0, errors.Wrap(err, "tx.NewDelete") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return 0, errors.Wrap(err, "result.RowsAffected") + } + + return int(rowsAffected), nil +} diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index 7cc09d9..bf70d5b 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -22,14 +22,14 @@ type DiscordToken struct { // UpdateDiscordToken adds the provided discord token to the database. // If the user already has a token stored, it will replace that token instead. -func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { +func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { if token == nil { return errors.New("token cannot be nil") } expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() discordToken := &DiscordToken{ - DiscordID: user.DiscordID, + DiscordID: u.DiscordID, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresAt: expiresAt, @@ -44,7 +44,6 @@ func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *disc Set("refresh_token = EXCLUDED.refresh_token"). Set("expires_at = EXCLUDED.expires_at"). Exec(ctx) - if err != nil { return errors.Wrap(err, "tx.NewInsert") } @@ -53,14 +52,14 @@ func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *disc // DeleteDiscordTokens deletes a users discord OAuth tokens from the database. // It returns the DiscordToken so that it can be revoked via the discord API -func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { - token, err := user.GetDiscordToken(ctx, tx) +func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token, err := u.GetDiscordToken(ctx, tx) if err != nil { return nil, errors.Wrap(err, "user.GetDiscordToken") } _, err = tx.NewDelete(). Model((*DiscordToken)(nil)). - Where("discord_id = ?", user.DiscordID). + Where("discord_id = ?", u.DiscordID). Exec(ctx) if err != nil { return nil, errors.Wrap(err, "tx.NewDelete") @@ -69,11 +68,11 @@ func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordT } // GetDiscordToken retrieves the users discord token from the database -func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { +func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { token := new(DiscordToken) err := tx.NewSelect(). Model(token). - Where("discord_id = ?", user.DiscordID). + Where("discord_id = ?", u.DiscordID). Limit(1). Scan(ctx) if err != nil { diff --git a/internal/db/paginate.go b/internal/db/paginate.go index e4ad389..366cb33 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -48,7 +48,7 @@ func (p *PageOpts) GetPageRange(total int, maxButtons int) []int { // If total pages is less than max buttons, show all pages if totalPages <= maxButtons { pages := make([]int, totalPages) - for i := 0; i < totalPages; i++ { + for i := range totalPages { pages[i] = i + 1 } return pages diff --git a/internal/db/permission.go b/internal/db/permission.go new file mode 100644 index 0000000..64aacef --- /dev/null +++ b/internal/db/permission.go @@ -0,0 +1,154 @@ +package db + +import ( + "context" + "database/sql" + + "git.haelnorr.com/h/oslstats/internal/permissions" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type Permission struct { + bun.BaseModel `bun:"table:permissions,alias:p"` + + ID int `bun:"id,pk,autoincrement"` + Name permissions.Permission `bun:"name,unique,notnull"` + DisplayName string `bun:"display_name,notnull"` + Description string `bun:"description"` + Resource string `bun:"resource,notnull"` + Action string `bun:"action,notnull"` + IsSystem bool `bun:"is_system,default:false"` + CreatedAt int64 `bun:"created_at,notnull"` +} + +// GetPermissionByName queries the database for a permission matching the given name +// Returns nil, nil if no permission is found +func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) { + if name == "" { + return nil, errors.New("name cannot be empty") + } + + perm := new(Permission) + err := tx.NewSelect(). + Model(perm). + Where("name = ?", name). + Limit(1). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return perm, nil +} + +// GetPermissionByID queries the database for a permission matching the given ID +// Returns nil, nil if no permission is found +func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) { + if id <= 0 { + return nil, errors.New("id must be positive") + } + + perm := new(Permission) + err := tx.NewSelect(). + Model(perm). + Where("id = ?", id). + Limit(1). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return perm, nil +} + +// GetPermissionsByResource queries for all permissions for a given resource +func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ([]*Permission, error) { + if resource == "" { + return nil, errors.New("resource cannot be empty") + } + + var perms []*Permission + err := tx.NewSelect(). + Model(&perms). + Where("resource = ?", resource). + Order("action ASC"). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return perms, nil +} + +// GetPermissionsByIDs queries for permissions matching the given IDs +func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permission, error) { + if len(ids) == 0 { + return []*Permission{}, nil + } + + var perms []*Permission + err := tx.NewSelect(). + Model(&perms). + Where("id IN (?)", bun.In(ids)). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return perms, nil +} + +// ListAllPermissions returns all permissions +func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { + var perms []*Permission + err := tx.NewSelect(). + Model(&perms). + Order("resource ASC", "action ASC"). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return perms, nil +} + +// CreatePermission creates a new permission +func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error { + if perm == nil { + return errors.New("permission cannot be nil") + } + + _, err := tx.NewInsert(). + Model(perm). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + + return nil +} + +// DeletePermission deletes a permission (checks IsSystem protection) +func DeletePermission(ctx context.Context, tx bun.Tx, id int) error { + if id <= 0 { + return errors.New("id must be positive") + } + + // Check if permission is system permission + perm, err := GetPermissionByID(ctx, tx, id) + if err != nil { + return errors.Wrap(err, "GetPermissionByID") + } + if perm == nil { + return errors.New("permission not found") + } + if perm.IsSystem { + return errors.New("cannot delete system permission") + } + + _, err = tx.NewDelete(). + Model((*Permission)(nil)). + Where("id = ?", id). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewDelete") + } + + return nil +} diff --git a/internal/db/role.go b/internal/db/role.go new file mode 100644 index 0000000..dbf87bd --- /dev/null +++ b/internal/db/role.go @@ -0,0 +1,204 @@ +package db + +import ( + "context" + "database/sql" + + "git.haelnorr.com/h/oslstats/internal/roles" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type Role struct { + bun.BaseModel `bun:"table:roles,alias:r"` + + ID int `bun:"id,pk,autoincrement"` + Name roles.Role `bun:"name,unique,notnull"` + DisplayName string `bun:"display_name,notnull"` + Description string `bun:"description"` + IsSystem bool `bun:"is_system,default:false"` + CreatedAt int64 `bun:"created_at,notnull"` + UpdatedAt int64 `bun:"updated_at,notnull"` + + // Relations (loaded on demand) + Permissions []*Permission `bun:"m2m:role_permissions,join:Role=Permission"` +} + +// GetRoleByName queries the database for a role matching the given name +// Returns nil, nil if no role is found +func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) { + if name == "" { + return nil, errors.New("name cannot be empty") + } + + role := new(Role) + err := tx.NewSelect(). + Model(role). + Where("name = ?", name). + Limit(1). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return role, nil +} + +// GetRoleByID queries the database for a role matching the given ID +// Returns nil, nil if no role is found +func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { + if id <= 0 { + return nil, errors.New("id must be positive") + } + + role := new(Role) + err := tx.NewSelect(). + Model(role). + Where("id = ?", id). + Limit(1). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return role, nil +} + +// GetRoleWithPermissions loads a role and all its permissions +func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { + if id <= 0 { + return nil, errors.New("id must be positive") + } + + role := new(Role) + err := tx.NewSelect(). + Model(role). + Where("id = ?", id). + Relation("Permissions"). + Limit(1). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return role, nil +} + +// ListAllRoles returns all roles +func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { + var roles []*Role + err := tx.NewSelect(). + Model(&roles). + Order("name ASC"). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return roles, nil +} + +// CreateRole creates a new role +func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { + if role == nil { + return errors.New("role cannot be nil") + } + + _, err := tx.NewInsert(). + Model(role). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + + return nil +} + +// UpdateRole updates an existing role +func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { + if role == nil { + return errors.New("role cannot be nil") + } + if role.ID <= 0 { + return errors.New("role id must be positive") + } + + _, err := tx.NewUpdate(). + Model(role). + WherePK(). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewUpdate") + } + + return nil +} + +// DeleteRole deletes a role (checks IsSystem protection) +func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { + if id <= 0 { + return errors.New("id must be positive") + } + + // Check if role is system role + role, err := GetRoleByID(ctx, tx, id) + if err != nil { + return errors.Wrap(err, "GetRoleByID") + } + if role == nil { + return errors.New("role not found") + } + if role.IsSystem { + return errors.New("cannot delete system role") + } + + _, err = tx.NewDelete(). + Model((*Role)(nil)). + Where("id = ?", id). + Exec(ctx) + if err != nil { + return errors.Wrap(err, "tx.NewDelete") + } + + return nil +} + +// AddPermissionToRole grants a permission to a role +func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int, createdAt int64) error { + if roleID <= 0 { + return errors.New("roleID must be positive") + } + if permissionID <= 0 { + return errors.New("permissionID must be positive") + } + + // TODO: use proper m2m table + // also make createdAt automatic in table so not required as input here + _, err := tx.ExecContext(ctx, ` + INSERT INTO role_permissions (role_id, permission_id, created_at) + VALUES ($1, $2, $3) + ON CONFLICT (role_id, permission_id) DO NOTHING + `, roleID, permissionID, createdAt) + if err != nil { + return errors.Wrap(err, "tx.ExecContext") + } + + return nil +} + +// RemovePermissionFromRole revokes a permission from a role +func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error { + if roleID <= 0 { + return errors.New("roleID must be positive") + } + if permissionID <= 0 { + return errors.New("permissionID must be positive") + } + + // TODO: use proper m2m table + _, err := tx.ExecContext(ctx, ` + DELETE FROM role_permissions + WHERE role_id = $1 AND permission_id = $2 + `, roleID, permissionID) + if err != nil { + return errors.Wrap(err, "tx.ExecContext") + } + + return nil +} diff --git a/internal/db/season.go b/internal/db/season.go index 261a68d..49021a4 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -78,7 +78,7 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonLis total, err := tx.NewSelect(). Model(&seasons). Count(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { return nil, errors.Wrap(err, "tx.NewSelect") } sl := &SeasonList{ diff --git a/internal/db/user.go b/internal/db/user.go index 5af8092..bdcdeeb 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,10 +2,13 @@ package db import ( "context" + "database/sql" "fmt" "time" "git.haelnorr.com/h/golib/hwsauth" + "git.haelnorr.com/h/oslstats/internal/permissions" + "git.haelnorr.com/h/oslstats/internal/roles" "github.com/bwmarrin/discordgo" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -22,8 +25,14 @@ type User struct { DiscordID string `bun:"discord_id,unique"` } -func (user *User) GetID() int { - return user.ID +type Users struct { + Users []*User + Total int + PageOpts PageOpts +} + +func (u *User) GetID() int { + return u.ID } // CreateUser creates a new user with the given username and password @@ -114,3 +123,110 @@ func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, er } return count == 0, nil } + +// GetRoles loads and returns all roles for this user +func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { + if u == nil { + return nil, errors.New("user cannot be nil") + } + return GetUserRoles(ctx, tx, u.ID) +} + +// GetPermissions loads and returns all permissions for this user (via roles) +func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { + if u == nil { + return nil, errors.New("user cannot be nil") + } + + // TODO: use proper m2m tables and relations instead of join + var permissions []*Permission + err := tx.NewSelect(). + Model(&permissions). + Join("JOIN role_permissions AS rp ON rp.permission_id = p.id"). + Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). + Where("ur.user_id = ?", u.ID). + Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return permissions, nil +} + +// HasPermission checks if user has a specific permission (including wildcard check) +func (u *User) HasPermission(ctx context.Context, tx bun.Tx, permissionName permissions.Permission) (bool, error) { + if u == nil { + return false, errors.New("user cannot be nil") + } + if permissionName == "" { + return false, errors.New("permissionName cannot be empty") + } + + perms, err := u.GetPermissions(ctx, tx) + if err != nil { + return false, err + } + + for _, p := range perms { + if p.Name == permissionName || p.Name == permissions.Wildcard { + return true, nil + } + } + return false, nil +} + +// HasRole checks if user has a specific role +func (u *User) HasRole(ctx context.Context, tx bun.Tx, roleName roles.Role) (bool, error) { + if u == nil { + return false, errors.New("user cannot be nil") + } + return HasRole(ctx, tx, u.ID, roleName) +} + +// IsAdmin is a convenience method to check if user has admin role +func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) { + if u == nil { + return false, errors.New("user cannot be nil") + } + return u.HasRole(ctx, tx, "admin") +} + +func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { + if pageOpts == nil { + pageOpts = &PageOpts{} + } + if pageOpts.Page == 0 { + pageOpts.Page = 1 + } + if pageOpts.PerPage == 0 { + pageOpts.PerPage = 50 + } + if pageOpts.Order == "" { + pageOpts.Order = bun.OrderAsc + } + if pageOpts.OrderBy == "" { + pageOpts.OrderBy = "id" + } + users := []*User{} + err := tx.NewSelect(). + Model(users). + OrderBy(pageOpts.OrderBy, pageOpts.Order). + Limit(pageOpts.PerPage). + Offset(pageOpts.PerPage * (pageOpts.Page - 1)). + Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "tx.NewSelect") + } + total, err := tx.NewSelect(). + Model(users). + Count(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + list := &Users{ + Users: users, + Total: total, + PageOpts: *pageOpts, + } + return list, nil +} diff --git a/internal/db/userrole.go b/internal/db/userrole.go new file mode 100644 index 0000000..ba2708f --- /dev/null +++ b/internal/db/userrole.go @@ -0,0 +1,114 @@ +package db + +import ( + "context" + "time" + + "git.haelnorr.com/h/oslstats/internal/roles" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type UserRole struct { + bun.BaseModel `bun:"table:user_roles,alias:ur"` + + ID int `bun:"id,pk,autoincrement"` + UserID int `bun:"user_id,notnull"` + RoleID int `bun:"role_id,notnull"` + GrantedBy *int `bun:"granted_by"` + GrantedAt int64 `bun:"granted_at,notnull"` // TODO: default now + ExpiresAt *int64 `bun:"expires_at"` + + // Relations + User *User `bun:"rel:belongs-to,join:user_id=id"` + Role *Role `bun:"rel:belongs-to,join:role_id=id"` +} + +// GetUserRoles loads all roles for a given user +func GetUserRoles(ctx context.Context, tx bun.Tx, userID int) ([]*Role, error) { + if userID <= 0 { + return nil, errors.New("userID must be positive") + } + + var roles []*Role + err := tx.NewSelect(). + Model(&roles). + // TODO: why are we joining? can we do relation? + Join("JOIN user_roles AS ur ON ur.role_id = r.id"). + Where("ur.user_id = ?", userID). + Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). + Scan(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return roles, nil +} + +// AssignRole grants a role to a user +func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *int) error { + if userID <= 0 { + return errors.New("userID must be positive") + } + if roleID <= 0 { + return errors.New("roleID must be positive") + } + + now := time.Now().Unix() + + // TODO: use proper m2m table instead of raw SQL + _, err := tx.ExecContext(ctx, ` + INSERT INTO user_roles (user_id, role_id, granted_by, granted_at) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, role_id) DO NOTHING + `, userID, roleID, grantedBy, now) + if err != nil { + return errors.Wrap(err, "tx.ExecContext") + } + + return nil +} + +// RevokeRole removes a role from a user +func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { + if userID <= 0 { + return errors.New("userID must be positive") + } + if roleID <= 0 { + return errors.New("roleID must be positive") + } + + // TODO: use proper m2m table instead of raw sql + _, err := tx.ExecContext(ctx, ` + DELETE FROM user_roles + WHERE user_id = $1 AND role_id = $2 + `, userID, roleID) + if err != nil { + return errors.Wrap(err, "tx.ExecContext") + } + + return nil +} + +// HasRole checks if a user has a specific role +func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (bool, error) { + if userID <= 0 { + return false, errors.New("userID must be positive") + } + if roleName == "" { + return false, errors.New("roleName cannot be empty") + } + + // TODO: use proper m2m table instead of TableExpr and Join? + count, err := tx.NewSelect(). + TableExpr("user_roles AS ur"). + Join("JOIN roles AS r ON r.id = ur.role_id"). + Where("ur.user_id = ?", userID). + Where("r.name = ?", roleName). + Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). + Count(ctx) + if err != nil { + return false, errors.Wrap(err, "tx.NewSelect") + } + + return count > 0, nil +} diff --git a/internal/handlers/admin_dashboard.go b/internal/handlers/admin_dashboard.go new file mode 100644 index 0000000..2f16041 --- /dev/null +++ b/internal/handlers/admin_dashboard.go @@ -0,0 +1,36 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/view/page" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + return + } + defer func() { _ = tx.Rollback() }() + + users, err := db.GetUsers(ctx, tx, nil) + if err != nil { + throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers")) + return + } + _ = tx.Commit() + + renderSafely(page.AdminDashboard(users), s, r, w) + }) +} diff --git a/internal/handlers/admin_users.go b/internal/handlers/admin_users.go new file mode 100644 index 0000000..c15eb2f --- /dev/null +++ b/internal/handlers/admin_users.go @@ -0,0 +1,44 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/view/component/admin" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +// AdminUsersList shows all users +func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx")) + return + } + defer func() { _ = tx.Rollback() }() + + // Get all users + pageOpts, err := pageOptsFromForm(r) + if err != nil { + throwBadRequest(s, w, r, "invalid form data", err) + return + } + users, err := db.GetUsers(ctx, tx, pageOpts) + if err != nil { + throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers")) + return + } + + _ = tx.Commit() + + renderSafely(admin.UserList(users), s, r, w) + }) +} diff --git a/internal/handlers/auth_helpers.go b/internal/handlers/auth_helpers.go new file mode 100644 index 0000000..6a29242 --- /dev/null +++ b/internal/handlers/auth_helpers.go @@ -0,0 +1,56 @@ +package handlers + +import ( + "context" + + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/rbac" + "git.haelnorr.com/h/oslstats/internal/roles" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +// shouldGrantAdmin checks if user's Discord ID is in admin list +func shouldGrantAdmin(user *db.User, cfg *rbac.Config) bool { + if cfg == nil || user == nil { + return false + } + if user.DiscordID == cfg.AdminDiscordID { + return true + } + return false +} + +// ensureUserHasAdminRole grants admin role if not already granted +func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error { + if user == nil { + return errors.New("user cannot be nil") + } + + // Check if user already has admin role + hasAdmin, err := user.HasRole(ctx, tx, roles.Admin) + if err != nil { + return errors.Wrap(err, "user.HasRole") + } + + if hasAdmin { + return nil // Already admin + } + + // Get admin role + adminRole, err := db.GetRoleByName(ctx, tx, roles.Admin) + if err != nil { + return errors.Wrap(err, "db.GetRoleByName") + } + if adminRole == nil { + return errors.New("admin role not found in database") + } + + // Grant admin role (nil grantedBy = system granted) + err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil) + if err != nil { + return errors.Wrap(err, "db.AssignRole") + } + + return nil +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index 12082a6..0e9b3d9 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -193,6 +193,15 @@ func login( if err != nil { return nil, errors.Wrap(err, "user.UpdateDiscordToken") } + + // Check if user should be granted admin role (environment-based) + if shouldGrantAdmin(user, cfg.RBAC) { + err := ensureUserHasAdminRole(ctx, tx, user) + if err != nil { + return nil, errors.Wrap(err, "ensureUserHasAdminRole") + } + } + err := auth.Login(w, r, user, true) if err != nil { return nil, errors.Wrap(err, "auth.Login") diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go index a647607..26765e6 100644 --- a/internal/handlers/errors.go +++ b/internal/handlers/errors.go @@ -20,16 +20,13 @@ func throwError( err error, level hws.ErrorLevel, ) { - err = s.ThrowError(w, r, hws.HWSError{ + s.ThrowError(w, r, hws.HWSError{ StatusCode: statusCode, Message: msg, Error: err, Level: level, RenderErrorPage: true, // throw* family always renders error pages }) - if err != nil { - s.ThrowFatal(w, err) - } } // throwInternalServiceError handles 500 errors (server failures) diff --git a/internal/handlers/page_opt_helpers.go b/internal/handlers/page_opt_helpers.go new file mode 100644 index 0000000..19ce03f --- /dev/null +++ b/internal/handlers/page_opt_helpers.go @@ -0,0 +1,68 @@ +package handlers + +import ( + "net/http" + "strconv" + + "git.haelnorr.com/h/oslstats/internal/db" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) { + var pageNum, perPage int + var order bun.Order + var orderBy string + var err error + + if pageStr := r.FormValue("page"); pageStr != "" { + pageNum, err = strconv.Atoi(pageStr) + if err != nil { + return nil, errors.Wrap(err, "invalid page number") + } + } + if perPageStr := r.FormValue("per_page"); perPageStr != "" { + perPage, err = strconv.Atoi(perPageStr) + if err != nil { + return nil, errors.Wrap(err, "invalid per_page number") + } + } + order = bun.Order(r.FormValue("order")) + orderBy = r.FormValue("order_by") + + pageOpts := &db.PageOpts{ + Page: pageNum, + PerPage: perPage, + Order: order, + OrderBy: orderBy, + } + return pageOpts, nil +} + +func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) { + var pageNum, perPage int + var order bun.Order + var orderBy string + var err error + if pageStr := r.URL.Query().Get("page"); pageStr != "" { + pageNum, err = strconv.Atoi(pageStr) + if err != nil { + return nil, errors.Wrap(err, "invalid page number") + } + } + if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" { + perPage, err = strconv.Atoi(perPageStr) + if err != nil { + return nil, errors.Wrap(err, "invalid per_page number") + } + } + order = bun.Order(r.URL.Query().Get("order")) + orderBy = r.URL.Query().Get("order_by") + pageOpts := &db.PageOpts{ + Page: pageNum, + PerPage: perPage, + Order: order, + OrderBy: orderBy, + } + return pageOpts, nil +} diff --git a/internal/handlers/seasons.go b/internal/handlers/seasons.go index 0f4fcdd..f572103 100644 --- a/internal/handlers/seasons.go +++ b/internal/handlers/seasons.go @@ -2,9 +2,7 @@ package handlers import ( "context" - "fmt" "net/http" - "strconv" "time" "git.haelnorr.com/h/golib/hws" @@ -27,30 +25,10 @@ func SeasonsPage( return } defer tx.Rollback() - var pageNum, perPage int - var order bun.Order - var orderBy string - if pageStr := r.URL.Query().Get("page"); pageStr != "" { - pageNum, err = strconv.Atoi(pageStr) - if err != nil { - throwBadRequest(s, w, r, "Invalid page number", err) - return - } - } - if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" { - perPage, err = strconv.Atoi(perPageStr) - if err != nil { - throwBadRequest(s, w, r, "Invalid per_page number", err) - return - } - } - order = bun.Order(r.URL.Query().Get("order")) - orderBy = r.URL.Query().Get("order_by") - pageOpts := &db.PageOpts{ - Page: pageNum, - PerPage: perPage, - Order: order, - OrderBy: orderBy, + pageOpts, err := pageOptsFromQuery(r) + if err != nil { + throwBadRequest(s, w, r, "invalid query", err) + return } seasons, err := db.ListSeasons(ctx, tx, pageOpts) if err != nil { @@ -76,36 +54,11 @@ func SeasonsList( return } - // Extract pagination/sort params from form - var pageNum, perPage int - var order bun.Order - var orderBy string - var err error - - if pageStr := r.FormValue("page"); pageStr != "" { - pageNum, err = strconv.Atoi(pageStr) - if err != nil { - throwBadRequest(s, w, r, "Invalid page number", err) - return - } + pageOpts, err := pageOptsFromForm(r) + if err != nil { + throwBadRequest(s, w, r, "invalid form data", err) + return } - if perPageStr := r.FormValue("per_page"); perPageStr != "" { - perPage, err = strconv.Atoi(perPageStr) - if err != nil { - throwBadRequest(s, w, r, "Invalid per_page number", err) - return - } - } - order = bun.Order(r.FormValue("order")) - orderBy = r.FormValue("order_by") - - pageOpts := &db.PageOpts{ - Page: pageNum, - PerPage: perPage, - Order: order, - OrderBy: orderBy, - } - fmt.Println(pageOpts) // Database query tx, err := conn.BeginTx(ctx, nil) diff --git a/internal/permissions/constants.go b/internal/permissions/constants.go new file mode 100644 index 0000000..8c474a0 --- /dev/null +++ b/internal/permissions/constants.go @@ -0,0 +1,23 @@ +// Package permissions provides constants for RBAC +package permissions + +type Permission string + +func (p Permission) String() string { + return string(p) +} + +const ( + // Wildcard - grants all permissions + Wildcard Permission = "*" + + // Seasons permissions + SeasonsCreate Permission = "seasons.create" + SeasonsUpdate Permission = "seasons.update" + SeasonsDelete Permission = "seasons.delete" + + // Users permissions + UsersUpdate Permission = "users.update" + UsersBan Permission = "users.ban" + UsersManageRoles Permission = "users.manage_roles" +) diff --git a/internal/rbac/cache_middleware.go b/internal/rbac/cache_middleware.go new file mode 100644 index 0000000..817f749 --- /dev/null +++ b/internal/rbac/cache_middleware.go @@ -0,0 +1,95 @@ +package rbac + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/permissions" + "git.haelnorr.com/h/oslstats/internal/roles" + "git.haelnorr.com/h/oslstats/pkg/contexts" + "github.com/pkg/errors" +) + +// LoadPermissionsMiddleware loads user permissions into context after authentication +// MUST run AFTER auth.Authenticate() middleware +func (c *Checker) LoadPermissionsMiddleware() hws.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := db.CurrentUser(r.Context()) + + if user == nil { + // No authenticated user - continue without permissions + next.ServeHTTP(w, r) + return + } + + // Start transaction for loading permissions + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + tx, err := c.conn.BeginTx(ctx, nil) + if err != nil { + // Log but don't block - permission checks will fail gracefully + c.s.LogError(hws.HWSError{ + Message: "Failed to start database transaction", + Error: errors.Wrap(err, "c.conn.BeginTx"), + Level: hws.ErrorERROR, + }) + next.ServeHTTP(w, r) + return + } + defer func() { _ = tx.Rollback() }() + + // Load user's roles_ and permissions + roles_, err := user.GetRoles(ctx, tx) + if err != nil { + c.s.LogError(hws.HWSError{ + Message: "Failed to get user roles", + Error: errors.Wrap(err, "user.GetRoles"), + Level: hws.ErrorERROR, + }) + next.ServeHTTP(w, r) + return + } + + perms, err := user.GetPermissions(ctx, tx) + if err != nil { + c.s.LogError(hws.HWSError{ + Message: "Failed to get user permissions", + Error: errors.Wrap(err, "user.GetPermissions"), + Level: hws.ErrorERROR, + }) + next.ServeHTTP(w, r) + return + } + + _ = tx.Commit() // read only transaction + + // Build permission cache + cache := &contexts.PermissionCache{ + Permissions: make(map[permissions.Permission]bool), + Roles: make(map[roles.Role]bool), + } + + // Check for wildcard permission + hasWildcard := false + for _, perm := range perms { + cache.Permissions[perm.Name] = true + if perm.Name == permissions.Wildcard { + hasWildcard = true + } + } + cache.HasWildcard = hasWildcard + + for _, role := range roles_ { + cache.Roles[role.Name] = true + } + + // Add cache to context (type-safe) + ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/rbac/checker.go b/internal/rbac/checker.go new file mode 100644 index 0000000..54f9632 --- /dev/null +++ b/internal/rbac/checker.go @@ -0,0 +1,108 @@ +package rbac + +import ( + "context" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/permissions" + "git.haelnorr.com/h/oslstats/internal/roles" + "git.haelnorr.com/h/oslstats/pkg/contexts" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type Checker struct { + conn *bun.DB + s *hws.Server +} + +func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) { + if conn == nil || s == nil { + return nil, errors.New("arguments cannot be nil") + } + return &Checker{conn: conn, s: s}, nil +} + +// UserHasPermission checks if user has a specific permission (uses cache) +func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permission permissions.Permission) (bool, error) { + if user == nil { + return false, nil + } + + // Try cache first + cache := contexts.Permissions(ctx) + if cache != nil { + if cache.HasWildcard { + return true, nil + } + if has, exists := cache.Permissions[permission]; exists { + return has, nil + } + } + + // Fallback to database + tx, err := c.conn.BeginTx(ctx, nil) + if err != nil { + return false, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + + has, err := user.HasPermission(ctx, tx, permission) + if err != nil { + return false, err + } + + return has, nil +} + +// UserHasRole checks if user has a specific role (uses cache) +func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Role) (bool, error) { + if user == nil { + return false, nil + } + + cache := contexts.Permissions(ctx) + if cache != nil { + if has, exists := cache.Roles[role]; exists { + return has, nil + } + } + + // Fallback to database + tx, err := c.conn.BeginTx(ctx, nil) + if err != nil { + return false, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + + return user.HasRole(ctx, tx, role) +} + +// UserHasAnyPermission checks if user has ANY of the given permissions +func (c *Checker) UserHasAnyPermission(ctx context.Context, user *db.User, permissions ...permissions.Permission) (bool, error) { + for _, perm := range permissions { + has, err := c.UserHasPermission(ctx, user, perm) + if err != nil { + return false, err + } + if has { + return true, nil + } + } + return false, nil +} + +// UserHasAllPermissions checks if user has ALL of the given permissions +func (c *Checker) UserHasAllPermissions(ctx context.Context, user *db.User, permissions ...permissions.Permission) (bool, error) { + for _, perm := range permissions { + has, err := c.UserHasPermission(ctx, user, perm) + if err != nil { + return false, err + } + if !has { + return false, nil + } + } + return true, nil +} diff --git a/internal/rbac/config.go b/internal/rbac/config.go new file mode 100644 index 0000000..b30843a --- /dev/null +++ b/internal/rbac/config.go @@ -0,0 +1,22 @@ +// Package rbac provides Role-Based Access Control functionality +package rbac + +import ( + "errors" + + "git.haelnorr.com/h/golib/env" +) + +type Config struct { + AdminDiscordID string // ENV ADMIN_DISCORD_ID: Discord ID to grant admin role on first login (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + AdminDiscordID: env.String("ADMIN_DISCORD_ID", ""), + } + if cfg.AdminDiscordID == "" { + return nil, errors.New("env var not set: ADMIN_DISCORD_ID") + } + return cfg, nil +} diff --git a/internal/rbac/ezconf.go b/internal/rbac/ezconf.go new file mode 100644 index 0000000..d477a1b --- /dev/null +++ b/internal/rbac/ezconf.go @@ -0,0 +1,41 @@ +package rbac + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "RBAC", configFunc: ConfigFromEnv} +} diff --git a/internal/rbac/protection_middleware.go b/internal/rbac/protection_middleware.go new file mode 100644 index 0000000..b6eaa88 --- /dev/null +++ b/internal/rbac/protection_middleware.go @@ -0,0 +1,101 @@ +package rbac + +import ( + "net/http" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/permissions" + "git.haelnorr.com/h/oslstats/internal/roles" +) + +// RequirePermission creates middleware that requires a specific permission +func (c *Checker) RequirePermission(server *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := db.CurrentUser(r.Context()) + if user == nil { + // Not logged in - redirect to login with page_from + cookies.SetPageFrom(w, r, r.URL.Path) + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + has, err := c.UserHasPermission(r.Context(), user, permission) + if err != nil { + // Log error and return 500 + server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusInternalServerError, + Message: "Permission check failed", + Error: err, + Level: hws.ErrorERROR, + RenderErrorPage: true, + }) + return + } + + if !has { + // User lacks permission - return 403 + server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusForbidden, + Message: "You don't have permission to access this resource", + Error: nil, + Level: hws.ErrorDEBUG, + RenderErrorPage: true, + }) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// RequireRole creates middleware that requires a specific role +func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := db.CurrentUser(r.Context()) + if user == nil { + // Not logged in - redirect to login + cookies.SetPageFrom(w, r, r.URL.Path) + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + has, err := c.UserHasRole(r.Context(), user, role) + if err != nil { + // Log error and return 500 + hwserr := hws.HWSError{ + StatusCode: http.StatusInternalServerError, + Message: "Role check failed", + Error: err, + Level: hws.ErrorERROR, + RenderErrorPage: true, + } + server.ThrowError(w, r, hwserr) + return + } + + if !has { + // User lacks role - return 403 + server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusForbidden, + Message: "You don't have the required role to access this resource", + Error: nil, + Level: hws.ErrorDEBUG, + RenderErrorPage: true, + }) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// RequireAdmin is a convenience middleware for admin-only routes +func (c *Checker) RequireAdmin(server *hws.Server) func(http.Handler) http.Handler { + return c.RequireRole(server, roles.Admin) +} diff --git a/internal/roles/constants.go b/internal/roles/constants.go new file mode 100644 index 0000000..5bcf8a6 --- /dev/null +++ b/internal/roles/constants.go @@ -0,0 +1,13 @@ +// Package roles provides constants for the RBAC +package roles + +type Role string + +func (r Role) String() string { + return string(r) +} + +const ( + Admin Role = "admin" + User Role = "user" +) diff --git a/internal/view/component/admin/user_list.templ b/internal/view/component/admin/user_list.templ new file mode 100644 index 0000000..bd3c63c --- /dev/null +++ b/internal/view/component/admin/user_list.templ @@ -0,0 +1,6 @@ +package admin + +import "git.haelnorr.com/h/oslstats/internal/db" + +templ UserList(users *db.Users) { +} diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ index fa27c16..58d8e26 100644 --- a/internal/view/component/nav/navbarright.templ +++ b/internal/view/component/nav/navbarright.templ @@ -1,6 +1,10 @@ package nav -import "git.haelnorr.com/h/oslstats/internal/db" +import ( + "context" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/pkg/contexts" +) type ProfileItem struct { name string // Label to display @@ -8,8 +12,8 @@ type ProfileItem struct { } // Return the list of profile links -func getProfileItems() []ProfileItem { - return []ProfileItem{ +func getProfileItems(ctx context.Context) []ProfileItem { + items := []ProfileItem{ { name: "Profile", href: "/profile", @@ -19,12 +23,23 @@ func getProfileItems() []ProfileItem { href: "/account", }, } + + // Add admin link if user has admin role + cache := contexts.Permissions(ctx) + if cache != nil && cache.Roles["admin"] { + items = append(items, ProfileItem{ + name: "Admin Panel", + href: "/admin", + }) + } + + return items } // Returns the right portion of the navbar templ navRight() { {{ user := db.CurrentUser(ctx) }} - {{ items := getProfileItems() }} + {{ items := getProfileItems(ctx) }}
if user != nil { diff --git a/internal/view/layout/admin_dashboard.templ b/internal/view/layout/admin_dashboard.templ new file mode 100644 index 0000000..f0e54a8 --- /dev/null +++ b/internal/view/layout/admin_dashboard.templ @@ -0,0 +1,8 @@ +package layout + +templ AdminDashboard() { + @Global("Admin") +
+ { children... } +
+} diff --git a/internal/view/page/admin_dashboard.templ b/internal/view/page/admin_dashboard.templ new file mode 100644 index 0000000..68189ba --- /dev/null +++ b/internal/view/page/admin_dashboard.templ @@ -0,0 +1,10 @@ +package page + +import "git.haelnorr.com/h/oslstats/internal/view/layout" +import "git.haelnorr.com/h/oslstats/internal/view/component/admin" +import "git.haelnorr.com/h/oslstats/internal/db" + +templ AdminDashboard(users *db.Users) { + @layout.AdminDashboard() + @admin.UserList(users) +} diff --git a/pkg/contexts/keys.go b/pkg/contexts/keys.go index 35cbc63..4ec539f 100644 --- a/pkg/contexts/keys.go +++ b/pkg/contexts/keys.go @@ -1,6 +1,12 @@ +// Package contexts provides utilities for loading and extracting structs from contexts package contexts -import "context" +import ( + "context" + + "git.haelnorr.com/h/oslstats/internal/permissions" + "git.haelnorr.com/h/oslstats/internal/roles" +) type Key string @@ -8,7 +14,10 @@ func (c Key) String() string { return "oslstats context key " + string(c) } -var DevModeKey Key = Key("devmode") +var ( + DevModeKey Key = Key("devmode") + PermissionCacheKey Key = Key("permissions") +) func DevMode(ctx context.Context) DevInfo { devmode, ok := ctx.Value(DevModeKey).(DevInfo) @@ -22,3 +31,18 @@ type DevInfo struct { WebsocketBase string HTMXLog bool } + +// Permissions retrieves the permission cache from context (type-safe) +func Permissions(ctx context.Context) *PermissionCache { + cache, ok := ctx.Value(PermissionCacheKey).(*PermissionCache) + if !ok { + return nil + } + return cache +} + +type PermissionCache struct { + Permissions map[permissions.Permission]bool + Roles map[roles.Role]bool + HasWildcard bool +}