rbac system first stage

This commit is contained in:
2026-02-03 21:37:06 +11:00
parent 24bbc5337b
commit d2b1a252ea
38 changed files with 1966 additions and 114 deletions

View File

@@ -9,9 +9,11 @@ import (
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store"
)
@@ -58,12 +60,21 @@ func setupHTTPServer(
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
}
err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI)
// Initialize permissions checker
perms, err := rbac.NewChecker(bun, server)
if err != nil {
return nil, errors.Wrap(err, "rbac.NewChecker")
}
// Initialize audit logger
audit := auditlog.NewLogger(bun)
err = addRoutes(httpServer, &fs, cfg, bun, auth, store, discordAPI, perms, audit)
if err != nil {
return nil, errors.Wrap(err, "addRoutes")
}
err = addMiddleware(httpServer, auth, cfg)
err = addMiddleware(httpServer, auth, cfg, perms)
if err != nil {
return nil, errors.Wrap(err, "addMiddleware")
}

View File

@@ -24,6 +24,21 @@ func main() {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
os.Exit(1)
}
// Handle utility flags
if flags.EnvDoc || flags.ShowEnv {
if err = loader.PrintEnvVarsStdout(flags.ShowEnv); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to print env doc"))
}
return
}
if flags.GenEnv != "" {
if err = loader.GenerateEnvFile(flags.GenEnv, true); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to generate env file"))
}
return
}
//
// Setup the logger
logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
if err != nil {
@@ -31,17 +46,6 @@ func main() {
os.Exit(1)
}
// Handle utility flags
if flags.EnvDoc || flags.ShowEnv {
loader.PrintEnvVarsStdout(flags.ShowEnv)
return
}
if flags.GenEnv != "" {
loader.GenerateEnvFile(flags.GenEnv, true)
return
}
// Handle migration file creation (doesn't need DB connection)
if flags.MigrateCreate != "" {
if err := createMigration(flags.MigrateCreate); err != nil {
@@ -55,24 +59,17 @@ func main() {
flags.MigrateStatus || flags.MigrateDryRun ||
flags.ResetDB {
// Setup database connection
conn, close, err := setupBun(ctx, cfg)
if err != nil {
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "setupBun"))).Msg("Error setting up database")
}
defer close()
// Route to appropriate command
if flags.MigrateUp {
err = runMigrations(ctx, conn, cfg, "up")
err = runMigrations(ctx, cfg, "up")
} else if flags.MigrateRollback {
err = runMigrations(ctx, conn, cfg, "rollback")
err = runMigrations(ctx, cfg, "rollback")
} else if flags.MigrateStatus {
err = runMigrations(ctx, conn, cfg, "status")
err = runMigrations(ctx, cfg, "status")
} else if flags.MigrateDryRun {
err = runMigrations(ctx, conn, cfg, "dry-run")
err = runMigrations(ctx, cfg, "dry-run")
} else if flags.ResetDB {
err = resetDatabase(ctx, conn)
err = resetDatabase(ctx, cfg)
}
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/pkg/contexts"
"github.com/pkg/errors"
@@ -19,10 +20,11 @@ func addMiddleware(
server *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
cfg *config.Config,
perms *rbac.Checker,
) error {
err := server.AddMiddleware(
auth.Authenticate(),
perms.LoadPermissionsMiddleware(),
devMode(cfg),
)
if err != nil {

View File

@@ -20,7 +20,13 @@ import (
)
// runMigrations executes database migrations
func runMigrations(ctx context.Context, conn *bun.DB, cfg *config.Config, command string) error {
func runMigrations(ctx context.Context, cfg *config.Config, command string) error {
conn, close, err := setupBun(ctx, cfg)
if err != nil {
return errors.Wrap(err, "setupBun")
}
defer close()
migrator := migrate.NewMigrator(conn, migrations.Migrations)
// Initialize migration tables
@@ -306,7 +312,7 @@ func init() {
`
// Write file
if err := os.WriteFile(filename, []byte(template), 0644); err != nil {
if err := os.WriteFile(filename, []byte(template), 0o644); err != nil {
return errors.Wrap(err, "write migration file")
}
@@ -319,7 +325,7 @@ func init() {
}
// resetDatabase drops and recreates all tables (destructive)
func resetDatabase(ctx context.Context, conn *bun.DB) error {
func resetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("⚠️ WARNING: This will DELETE ALL DATA in the database!")
fmt.Print("Type 'yes' to continue: ")
@@ -334,6 +340,11 @@ func resetDatabase(ctx context.Context, conn *bun.DB) error {
fmt.Println("❌ Reset cancelled")
return nil
}
conn, close, err := setupBun(ctx, cfg)
if err != nil {
return errors.Wrap(err, "setupBun")
}
defer close()
models := []any{
(*db.User)(nil),

View File

@@ -0,0 +1,272 @@
package migrations
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, dbConn *bun.DB) error {
// Create roles table using raw SQL to avoid m2m relationship issues
// Bun tries to resolve relationships when creating tables from models
// TODO: use proper m2m table instead of raw sql
_, err := dbConn.ExecContext(ctx, `
CREATE TABLE roles (
id SERIAL PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
display_name VARCHAR(100) NOT NULL,
description TEXT,
is_system BOOLEAN DEFAULT FALSE,
created_at BIGINT NOT NULL,
updated_at BIGINT NOT NULL
)
`)
if err != nil {
return err
}
// Create permissions table
_, err = dbConn.NewCreateTable().
Model((*db.Permission)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for permissions
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.Permission)(nil)).
Index("idx_permissions_resource").
Column("resource").
Exec(ctx)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.Permission)(nil)).
Index("idx_permissions_action").
Column("action").
Exec(ctx)
if err != nil {
return err
}
// Create role_permissions join table (Bun doesn't auto-create m2m tables)
// TODO: use proper m2m table instead of raw sql
_, err = dbConn.ExecContext(ctx, `
CREATE TABLE role_permissions (
id SERIAL PRIMARY KEY,
role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE CASCADE,
permission_id INTEGER NOT NULL REFERENCES permissions(id) ON DELETE CASCADE,
created_at BIGINT NOT NULL,
UNIQUE(role_id, permission_id)
)
`)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
`)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
`)
if err != nil {
return err
}
// Create user_roles table
_, err = dbConn.NewCreateTable().
Model((*db.UserRole)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for user_roles
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.UserRole)(nil)).
Index("idx_user_roles_user").
Column("user_id").
Exec(ctx)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.UserRole)(nil)).
Index("idx_user_roles_role").
Column("role_id").
Exec(ctx)
if err != nil {
return err
}
// Create audit_log table
_, err = dbConn.NewCreateTable().
Model((*db.AuditLog)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for audit_log
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_user").
Column("user_id").
Exec(ctx)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_action").
Column("action").
Exec(ctx)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_resource").
Column("resource_type", "resource_id").
Exec(ctx)
if err != nil {
return err
}
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_created").
Column("created_at").
Exec(ctx)
if err != nil {
return err
}
// Seed system roles
now := time.Now().Unix()
adminRole := &db.Role{
Name: "admin",
DisplayName: "Administrator",
Description: "Full system access with all permissions",
IsSystem: true,
CreatedAt: now, // TODO: this should be defaulted in table
UpdatedAt: now, // TODO: this should be defaulted in table
}
_, err = dbConn.NewInsert().
Model(adminRole).
Exec(ctx)
if err != nil {
return err
}
userRole := &db.Role{
Name: "user",
DisplayName: "User",
Description: "Standard user with basic permissions",
IsSystem: true,
CreatedAt: now, // TODO: this should be defaulted in table
UpdatedAt: now, // TODO: this should be defaulted in table
}
_, err = dbConn.NewInsert().
Model(userRole).
Exec(ctx)
if err != nil {
return err
}
// Seed system permissions
// TODO: timestamps for created should be defaulted in table
permissionsData := []*db.Permission{
{Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now},
{Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now},
{Name: "seasons.update", DisplayName: "Update Seasons", Description: "Update existing seasons", Resource: "seasons", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "seasons.delete", DisplayName: "Delete Seasons", Description: "Delete seasons", Resource: "seasons", Action: "delete", IsSystem: true, CreatedAt: now},
{Name: "users.update", DisplayName: "Update Users", Description: "Update user information", Resource: "users", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "users.ban", DisplayName: "Ban Users", Description: "Ban users from the system", Resource: "users", Action: "ban", IsSystem: true, CreatedAt: now},
{Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now},
}
_, err = dbConn.NewInsert().
Model(&permissionsData).
Exec(ctx)
if err != nil {
return err
}
// Grant wildcard permission to admin role using Bun
// First, get the IDs
var wildcardPerm db.Permission
err = dbConn.NewSelect().
Model(&wildcardPerm).
Where("name = ?", "*").
Scan(ctx)
if err != nil {
return err
}
// Insert role_permission mapping
// TODO: use proper m2m table, and default now in table settings
_, err = dbConn.ExecContext(ctx, `
INSERT INTO role_permissions (role_id, permission_id, created_at)
VALUES ($1, $2, $3)
`, adminRole.ID, wildcardPerm.ID, now)
if err != nil {
return err
}
return nil
},
// DOWN migration
func(ctx context.Context, dbConn *bun.DB) error {
// Drop tables in reverse order
// Use raw SQL to avoid relationship resolution issues
// TODO: surely we can use proper bun methods?
tables := []string{
"audit_log",
"user_roles",
"role_permissions",
"permissions",
"roles",
}
for _, table := range tables {
_, err := dbConn.ExecContext(ctx, "DROP TABLE IF EXISTS "+table+" CASCADE")
if err != nil {
return err
}
}
return nil
},
)
}

View File

@@ -1,3 +1,4 @@
// Package migrations defines the database migrations to apply when using the migrate tags
package migrations
import (

View File

@@ -8,10 +8,12 @@ import (
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/auditlog"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store"
)
@@ -23,6 +25,8 @@ func addRoutes(
auth *hwsauth.Authenticator[*db.User, bun.Tx],
store *store.Store,
discordAPI *discord.APIClient,
perms *rbac.Checker,
audit *auditlog.Logger,
) error {
// Create the routes
pageroutes := []hws.Route{
@@ -115,8 +119,24 @@ func addRoutes(
},
}
// Admin routes
adminRoutes := []hws.Route{
{
// TODO: on page load, redirect to /admin/users
Path: "/admin",
Method: hws.MethodGET,
Handler: perms.RequireAdmin(s)(handlers.AdminDashboard(s, conn)),
},
{
Path: "/admin/users",
Method: hws.MethodPOST,
Handler: perms.RequireAdmin(s)(handlers.AdminUsersList(s, conn)),
},
}
routes := append(pageroutes, htmxRoutes...)
routes = append(routes, wsRoutes...)
routes = append(routes, adminRoutes...)
// Register the routes with the server
err := s.AddRoutes(routes...)

4
go.mod
View File

@@ -6,8 +6,8 @@ require (
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/ezconf v0.1.1
git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.4.4
git.haelnorr.com/h/golib/hwsauth v0.5.4
git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/hwsauth v0.5.5
git.haelnorr.com/h/golib/notify v0.1.0
github.com/a-h/templ v0.3.977
github.com/coder/websocket v1.8.14

8
go.sum
View File

@@ -6,10 +6,10 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI
git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.4.4 h1:tV9UjZ4q96UlOdJKsC7b3kDV+bpQYqKVPQuaV1n3U3k=
git.haelnorr.com/h/golib/hws v0.4.4/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/hwsauth v0.5.4 h1:nuaiVpJHHXgKVRPoQSE/v3CJHSkivViK5h3SVhEcbbM=
git.haelnorr.com/h/golib/hwsauth v0.5.4/go.mod h1:eIjRPeGycvxRWERkxCoRVMEEhHuUdiPDvjpzzZOhQ0w=
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo=
git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=

166
internal/auditlog/logger.go Normal file
View File

@@ -0,0 +1,166 @@
// Package auditlog provides a system for logging events that require permissions to the audit log
package auditlog
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Logger struct {
conn *bun.DB
}
func NewLogger(conn *bun.DB) *Logger {
return &Logger{conn: conn}
}
// LogSuccess logs a successful permission-protected action
func (l *Logger) LogSuccess(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any, // Can be int, string, or nil
details map[string]any,
r *http.Request,
) error {
return l.log(ctx, tx, user, action, resourceType, resourceID, details, "success", nil, r)
}
// LogError logs a failed action due to an error
func (l *Logger) LogError(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any,
err error,
r *http.Request,
) error {
errMsg := err.Error()
return l.log(ctx, tx, user, action, resourceType, resourceID, nil, "error", &errMsg, r)
}
func (l *Logger) log(
ctx context.Context,
tx bun.Tx,
user *db.User,
action string,
resourceType string,
resourceID any,
details map[string]any,
result string,
errorMessage *string,
r *http.Request,
) error {
if user == nil {
return errors.New("user cannot be nil for audit logging")
}
// Convert resourceID to string
var resourceIDStr *string
if resourceID != nil {
idStr := fmt.Sprintf("%v", resourceID)
resourceIDStr = &idStr
}
// Marshal details to JSON
var detailsJSON json.RawMessage
if details != nil {
jsonBytes, err := json.Marshal(details)
if err != nil {
return errors.Wrap(err, "json.Marshal details")
}
detailsJSON = jsonBytes
}
// Extract IP and User-Agent from request
ipAddress := r.RemoteAddr
userAgent := r.UserAgent()
log := &db.AuditLog{
UserID: user.ID,
Action: action,
ResourceType: resourceType,
ResourceID: resourceIDStr,
Details: detailsJSON,
IPAddress: ipAddress,
UserAgent: userAgent,
Result: result,
ErrorMessage: errorMessage,
CreatedAt: time.Now().Unix(),
}
return db.CreateAuditLog(ctx, tx, log)
}
// GetRecentLogs retrieves recent audit logs with pagination
// TODO: change this to user db.PageOpts
func (l *Logger) GetRecentLogs(ctx context.Context, limit, offset int) ([]*db.AuditLog, int, error) {
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return nil, 0, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
logs, total, err := db.GetAuditLogs(ctx, tx, limit, offset, nil)
if err != nil {
return nil, 0, err
}
_ = tx.Commit() // read only transaction
return logs, total, nil
}
// GetLogsByUser retrieves audit logs for a specific user
// TODO: change this to user db.PageOpts
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, limit, offset int) ([]*db.AuditLog, int, error) {
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return nil, 0, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
logs, total, err := db.GetAuditLogsByUser(ctx, tx, userID, limit, offset)
if err != nil {
return nil, 0, err
}
_ = tx.Commit() // read only transaction
return logs, total, nil
}
// CleanupOldLogs deletes audit logs older than the specified number of days
func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error) {
if daysToKeep <= 0 {
return 0, errors.New("daysToKeep must be positive")
}
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return 0, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
if err != nil {
return 0, err
}
err = tx.Commit()
if err != nil {
return 0, errors.Wrap(err, "tx.Commit")
}
return count, nil
}

View File

@@ -1,3 +1,4 @@
// Package config provides the environment based configuration for the program
package config
import (
@@ -7,6 +8,7 @@ import (
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/pkg/oauth"
"github.com/joho/godotenv"
"github.com/pkg/errors"
@@ -19,10 +21,11 @@ type Config struct {
HLOG *hlog.Config
Discord *discord.Config
OAuth *oauth.Config
RBAC *rbac.Config
Flags *Flags
}
// Load the application configuration and get a pointer to the Config object
// GetConfig loads the application configuration and returns a pointer to the Config object
// If doconly is specified, only the loader will be returned
func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
err := godotenv.Load(flags.EnvFile)
@@ -31,14 +34,18 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
}
loader := ezconf.New()
loader.RegisterIntegrations(
err = loader.RegisterIntegrations(
hlog.NewEZConfIntegration(),
hws.NewEZConfIntegration(),
hwsauth.NewEZConfIntegration(),
db.NewEZConfIntegration(),
discord.NewEZConfIntegration(),
oauth.NewEZConfIntegration(),
rbac.NewEZConfIntegration(),
)
if err != nil {
return nil, nil, errors.Wrap(err, "loader.RegisterIntegrations")
}
if err := loader.ParseEnvVars(); err != nil {
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
}
@@ -81,6 +88,11 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
return nil, nil, errors.New("OAuth Config not loaded")
}
rbaccfg, ok := loader.GetConfig("rbac")
if !ok {
return nil, nil, errors.New("RBAC Config not loaded")
}
config := &Config{
DB: dbcfg.(*db.Config),
HWS: hwscfg.(*hws.Config),
@@ -88,6 +100,7 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
HLOG: hlogcfg.(*hlog.Config),
Discord: discordcfg.(*discord.Config),
OAuth: oauthcfg.(*oauth.Config),
RBAC: rbaccfg.(*rbac.Config),
Flags: flags,
}

143
internal/db/auditlog.go Normal file
View File

@@ -0,0 +1,143 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type AuditLog struct {
bun.BaseModel `bun:"table:audit_log,alias:al"`
ID int `bun:"id,pk,autoincrement"`
UserID int `bun:"user_id,notnull"`
Action string `bun:"action,notnull"`
ResourceType string `bun:"resource_type,notnull"`
ResourceID *string `bun:"resource_id"`
Details json.RawMessage `bun:"details,type:jsonb"`
IPAddress string `bun:"ip_address"`
UserAgent string `bun:"user_agent"`
Result string `bun:"result,notnull"` // success, denied, error
ErrorMessage *string `bun:"error_message"`
CreatedAt int64 `bun:"created_at,notnull"`
// Relations
User *User `bun:"rel:belongs-to,join:user_id=id"`
}
// TODO: add AuditLogs to match list style with PageOpts
// CreateAuditLog creates a new audit log entry
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
if log == nil {
return errors.New("log cannot be nil")
}
_, err := tx.NewInsert().
Model(log).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
}
return nil
}
type AuditLogFilters struct {
UserID *int
Action *string
ResourceType *string
Result *string
}
// GetAuditLogs retrieves audit logs with optional filters and pagination
// TODO: change this to use db.PageOpts
func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *AuditLogFilters) ([]*AuditLog, int, error) {
query := tx.NewSelect().
Model((*AuditLog)(nil)).
Relation("User").
Order("created_at DESC")
// Apply filters if provided
if filters != nil {
if filters.UserID != nil {
query = query.Where("al.user_id = ?", *filters.UserID)
}
if filters.Action != nil {
query = query.Where("al.action = ?", *filters.Action)
}
if filters.ResourceType != nil {
query = query.Where("al.resource_type = ?", *filters.ResourceType)
}
if filters.Result != nil {
query = query.Where("al.result = ?", *filters.Result)
}
}
// Get total count
total, err := query.Count(ctx)
if err != nil {
return nil, 0, errors.Wrap(err, "query.Count")
}
// Get paginated results
var logs []*AuditLog
err = query.
Limit(limit).
Offset(offset).
Scan(ctx, &logs)
if err != nil && err != sql.ErrNoRows {
return nil, 0, errors.Wrap(err, "query.Scan")
}
return logs, total, nil
}
// GetAuditLogsByUser retrieves audit logs for a specific user
// TODO: change this to use db.PageOpts
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, limit, offset int) ([]*AuditLog, int, error) {
if userID <= 0 {
return nil, 0, errors.New("userID must be positive")
}
filters := &AuditLogFilters{
UserID: &userID,
}
return GetAuditLogs(ctx, tx, limit, offset, filters)
}
// GetAuditLogsByAction retrieves audit logs for a specific action
// TODO: change this to use db.PageOpts
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, limit, offset int) ([]*AuditLog, int, error) {
if action == "" {
return nil, 0, errors.New("action cannot be empty")
}
filters := &AuditLogFilters{
Action: &action,
}
return GetAuditLogs(ctx, tx, limit, offset, filters)
}
// CleanupOldAuditLogs deletes audit logs older than the specified timestamp
func CleanupOldAuditLogs(ctx context.Context, tx bun.Tx, olderThan int64) (int, error) {
result, err := tx.NewDelete().
Model((*AuditLog)(nil)).
Where("created_at < ?", olderThan).
Exec(ctx)
if err != nil {
return 0, errors.Wrap(err, "tx.NewDelete")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.Wrap(err, "result.RowsAffected")
}
return int(rowsAffected), nil
}

View File

@@ -22,14 +22,14 @@ type DiscordToken struct {
// UpdateDiscordToken adds the provided discord token to the database.
// If the user already has a token stored, it will replace that token instead.
func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
if token == nil {
return errors.New("token cannot be nil")
}
expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
discordToken := &DiscordToken{
DiscordID: user.DiscordID,
DiscordID: u.DiscordID,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
ExpiresAt: expiresAt,
@@ -44,7 +44,6 @@ func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *disc
Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
}
@@ -53,14 +52,14 @@ func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *disc
// DeleteDiscordTokens deletes a users discord OAuth tokens from the database.
// It returns the DiscordToken so that it can be revoked via the discord API
func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := user.GetDiscordToken(ctx, tx)
func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := u.GetDiscordToken(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "user.GetDiscordToken")
}
_, err = tx.NewDelete().
Model((*DiscordToken)(nil)).
Where("discord_id = ?", user.DiscordID).
Where("discord_id = ?", u.DiscordID).
Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewDelete")
@@ -69,11 +68,11 @@ func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordT
}
// GetDiscordToken retrieves the users discord token from the database
func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token := new(DiscordToken)
err := tx.NewSelect().
Model(token).
Where("discord_id = ?", user.DiscordID).
Where("discord_id = ?", u.DiscordID).
Limit(1).
Scan(ctx)
if err != nil {

View File

@@ -48,7 +48,7 @@ func (p *PageOpts) GetPageRange(total int, maxButtons int) []int {
// If total pages is less than max buttons, show all pages
if totalPages <= maxButtons {
pages := make([]int, totalPages)
for i := 0; i < totalPages; i++ {
for i := range totalPages {
pages[i] = i + 1
}
return pages

154
internal/db/permission.go Normal file
View File

@@ -0,0 +1,154 @@
package db
import (
"context"
"database/sql"
"git.haelnorr.com/h/oslstats/internal/permissions"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Permission struct {
bun.BaseModel `bun:"table:permissions,alias:p"`
ID int `bun:"id,pk,autoincrement"`
Name permissions.Permission `bun:"name,unique,notnull"`
DisplayName string `bun:"display_name,notnull"`
Description string `bun:"description"`
Resource string `bun:"resource,notnull"`
Action string `bun:"action,notnull"`
IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"`
}
// GetPermissionByName queries the database for a permission matching the given name
// Returns nil, nil if no permission is found
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
}
// GetPermissionByID queries the database for a permission matching the given ID
// Returns nil, nil if no permission is found
func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) {
if id <= 0 {
return nil, errors.New("id must be positive")
}
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
}
// GetPermissionsByResource queries for all permissions for a given resource
func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ([]*Permission, error) {
if resource == "" {
return nil, errors.New("resource cannot be empty")
}
var perms []*Permission
err := tx.NewSelect().
Model(&perms).
Where("resource = ?", resource).
Order("action ASC").
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
}
// GetPermissionsByIDs queries for permissions matching the given IDs
func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permission, error) {
if len(ids) == 0 {
return []*Permission{}, nil
}
var perms []*Permission
err := tx.NewSelect().
Model(&perms).
Where("id IN (?)", bun.In(ids)).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
}
// ListAllPermissions returns all permissions
func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
var perms []*Permission
err := tx.NewSelect().
Model(&perms).
Order("resource ASC", "action ASC").
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
}
// CreatePermission creates a new permission
func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
if perm == nil {
return errors.New("permission cannot be nil")
}
_, err := tx.NewInsert().
Model(perm).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
}
return nil
}
// DeletePermission deletes a permission (checks IsSystem protection)
func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 {
return errors.New("id must be positive")
}
// Check if permission is system permission
perm, err := GetPermissionByID(ctx, tx, id)
if err != nil {
return errors.Wrap(err, "GetPermissionByID")
}
if perm == nil {
return errors.New("permission not found")
}
if perm.IsSystem {
return errors.New("cannot delete system permission")
}
_, err = tx.NewDelete().
Model((*Permission)(nil)).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewDelete")
}
return nil
}

204
internal/db/role.go Normal file
View File

@@ -0,0 +1,204 @@
package db
import (
"context"
"database/sql"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Role struct {
bun.BaseModel `bun:"table:roles,alias:r"`
ID int `bun:"id,pk,autoincrement"`
Name roles.Role `bun:"name,unique,notnull"`
DisplayName string `bun:"display_name,notnull"`
Description string `bun:"description"`
IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"`
UpdatedAt int64 `bun:"updated_at,notnull"`
// Relations (loaded on demand)
Permissions []*Permission `bun:"m2m:role_permissions,join:Role=Permission"`
}
// GetRoleByName queries the database for a role matching the given name
// Returns nil, nil if no role is found
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
}
// GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
if id <= 0 {
return nil, errors.New("id must be positive")
}
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
}
// GetRoleWithPermissions loads a role and all its permissions
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
if id <= 0 {
return nil, errors.New("id must be positive")
}
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("id = ?", id).
Relation("Permissions").
Limit(1).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
}
// ListAllRoles returns all roles
func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
var roles []*Role
err := tx.NewSelect().
Model(&roles).
Order("name ASC").
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return roles, nil
}
// CreateRole creates a new role
func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
if role == nil {
return errors.New("role cannot be nil")
}
_, err := tx.NewInsert().
Model(role).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
}
return nil
}
// UpdateRole updates an existing role
func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
if role == nil {
return errors.New("role cannot be nil")
}
if role.ID <= 0 {
return errors.New("role id must be positive")
}
_, err := tx.NewUpdate().
Model(role).
WherePK().
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewUpdate")
}
return nil
}
// DeleteRole deletes a role (checks IsSystem protection)
func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 {
return errors.New("id must be positive")
}
// Check if role is system role
role, err := GetRoleByID(ctx, tx, id)
if err != nil {
return errors.Wrap(err, "GetRoleByID")
}
if role == nil {
return errors.New("role not found")
}
if role.IsSystem {
return errors.New("cannot delete system role")
}
_, err = tx.NewDelete().
Model((*Role)(nil)).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewDelete")
}
return nil
}
// AddPermissionToRole grants a permission to a role
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int, createdAt int64) error {
if roleID <= 0 {
return errors.New("roleID must be positive")
}
if permissionID <= 0 {
return errors.New("permissionID must be positive")
}
// TODO: use proper m2m table
// also make createdAt automatic in table so not required as input here
_, err := tx.ExecContext(ctx, `
INSERT INTO role_permissions (role_id, permission_id, created_at)
VALUES ($1, $2, $3)
ON CONFLICT (role_id, permission_id) DO NOTHING
`, roleID, permissionID, createdAt)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
return nil
}
// RemovePermissionFromRole revokes a permission from a role
func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
if roleID <= 0 {
return errors.New("roleID must be positive")
}
if permissionID <= 0 {
return errors.New("permissionID must be positive")
}
// TODO: use proper m2m table
_, err := tx.ExecContext(ctx, `
DELETE FROM role_permissions
WHERE role_id = $1 AND permission_id = $2
`, roleID, permissionID)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
return nil
}

View File

@@ -78,7 +78,7 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonLis
total, err := tx.NewSelect().
Model(&seasons).
Count(ctx)
if err != nil && err != sql.ErrNoRows {
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
sl := &SeasonList{

View File

@@ -2,10 +2,13 @@ package db
import (
"context"
"database/sql"
"fmt"
"time"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
"github.com/uptrace/bun"
@@ -22,8 +25,14 @@ type User struct {
DiscordID string `bun:"discord_id,unique"`
}
func (user *User) GetID() int {
return user.ID
type Users struct {
Users []*User
Total int
PageOpts PageOpts
}
func (u *User) GetID() int {
return u.ID
}
// CreateUser creates a new user with the given username and password
@@ -114,3 +123,110 @@ func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, er
}
return count == 0, nil
}
// GetRoles loads and returns all roles for this user
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
if u == nil {
return nil, errors.New("user cannot be nil")
}
return GetUserRoles(ctx, tx, u.ID)
}
// GetPermissions loads and returns all permissions for this user (via roles)
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
if u == nil {
return nil, errors.New("user cannot be nil")
}
// TODO: use proper m2m tables and relations instead of join
var permissions []*Permission
err := tx.NewSelect().
Model(&permissions).
Join("JOIN role_permissions AS rp ON rp.permission_id = p.id").
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
Where("ur.user_id = ?", u.ID).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return permissions, nil
}
// HasPermission checks if user has a specific permission (including wildcard check)
func (u *User) HasPermission(ctx context.Context, tx bun.Tx, permissionName permissions.Permission) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
if permissionName == "" {
return false, errors.New("permissionName cannot be empty")
}
perms, err := u.GetPermissions(ctx, tx)
if err != nil {
return false, err
}
for _, p := range perms {
if p.Name == permissionName || p.Name == permissions.Wildcard {
return true, nil
}
}
return false, nil
}
// HasRole checks if user has a specific role
func (u *User) HasRole(ctx context.Context, tx bun.Tx, roleName roles.Role) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
return HasRole(ctx, tx, u.ID, roleName)
}
// IsAdmin is a convenience method to check if user has admin role
func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
return u.HasRole(ctx, tx, "admin")
}
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
if pageOpts == nil {
pageOpts = &PageOpts{}
}
if pageOpts.Page == 0 {
pageOpts.Page = 1
}
if pageOpts.PerPage == 0 {
pageOpts.PerPage = 50
}
if pageOpts.Order == "" {
pageOpts.Order = bun.OrderAsc
}
if pageOpts.OrderBy == "" {
pageOpts.OrderBy = "id"
}
users := []*User{}
err := tx.NewSelect().
Model(users).
OrderBy(pageOpts.OrderBy, pageOpts.Order).
Limit(pageOpts.PerPage).
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
total, err := tx.NewSelect().
Model(users).
Count(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
list := &Users{
Users: users,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
}

114
internal/db/userrole.go Normal file
View File

@@ -0,0 +1,114 @@
package db
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type UserRole struct {
bun.BaseModel `bun:"table:user_roles,alias:ur"`
ID int `bun:"id,pk,autoincrement"`
UserID int `bun:"user_id,notnull"`
RoleID int `bun:"role_id,notnull"`
GrantedBy *int `bun:"granted_by"`
GrantedAt int64 `bun:"granted_at,notnull"` // TODO: default now
ExpiresAt *int64 `bun:"expires_at"`
// Relations
User *User `bun:"rel:belongs-to,join:user_id=id"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
}
// GetUserRoles loads all roles for a given user
func GetUserRoles(ctx context.Context, tx bun.Tx, userID int) ([]*Role, error) {
if userID <= 0 {
return nil, errors.New("userID must be positive")
}
var roles []*Role
err := tx.NewSelect().
Model(&roles).
// TODO: why are we joining? can we do relation?
Join("JOIN user_roles AS ur ON ur.role_id = r.id").
Where("ur.user_id = ?", userID).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return roles, nil
}
// AssignRole grants a role to a user
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *int) error {
if userID <= 0 {
return errors.New("userID must be positive")
}
if roleID <= 0 {
return errors.New("roleID must be positive")
}
now := time.Now().Unix()
// TODO: use proper m2m table instead of raw SQL
_, err := tx.ExecContext(ctx, `
INSERT INTO user_roles (user_id, role_id, granted_by, granted_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, role_id) DO NOTHING
`, userID, roleID, grantedBy, now)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
return nil
}
// RevokeRole removes a role from a user
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
if userID <= 0 {
return errors.New("userID must be positive")
}
if roleID <= 0 {
return errors.New("roleID must be positive")
}
// TODO: use proper m2m table instead of raw sql
_, err := tx.ExecContext(ctx, `
DELETE FROM user_roles
WHERE user_id = $1 AND role_id = $2
`, userID, roleID)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
return nil
}
// HasRole checks if a user has a specific role
func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (bool, error) {
if userID <= 0 {
return false, errors.New("userID must be positive")
}
if roleName == "" {
return false, errors.New("roleName cannot be empty")
}
// TODO: use proper m2m table instead of TableExpr and Join?
count, err := tx.NewSelect().
TableExpr("user_roles AS ur").
Join("JOIN roles AS r ON r.id = ur.role_id").
Where("ur.user_id = ?", userID).
Where("r.name = ?", roleName).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count > 0, nil
}

View File

@@ -0,0 +1,36 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/view/page"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer func() { _ = tx.Rollback() }()
users, err := db.GetUsers(ctx, tx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers"))
return
}
_ = tx.Commit()
renderSafely(page.AdminDashboard(users), s, r, w)
})
}

View File

@@ -0,0 +1,44 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/view/component/admin"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminUsersList shows all users
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx"))
return
}
defer func() { _ = tx.Rollback() }()
// Get all users
pageOpts, err := pageOptsFromForm(r)
if err != nil {
throwBadRequest(s, w, r, "invalid form data", err)
return
}
users, err := db.GetUsers(ctx, tx, pageOpts)
if err != nil {
throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers"))
return
}
_ = tx.Commit()
renderSafely(admin.UserList(users), s, r, w)
})
}

View File

@@ -0,0 +1,56 @@
package handlers
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// shouldGrantAdmin checks if user's Discord ID is in admin list
func shouldGrantAdmin(user *db.User, cfg *rbac.Config) bool {
if cfg == nil || user == nil {
return false
}
if user.DiscordID == cfg.AdminDiscordID {
return true
}
return false
}
// ensureUserHasAdminRole grants admin role if not already granted
func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error {
if user == nil {
return errors.New("user cannot be nil")
}
// Check if user already has admin role
hasAdmin, err := user.HasRole(ctx, tx, roles.Admin)
if err != nil {
return errors.Wrap(err, "user.HasRole")
}
if hasAdmin {
return nil // Already admin
}
// Get admin role
adminRole, err := db.GetRoleByName(ctx, tx, roles.Admin)
if err != nil {
return errors.Wrap(err, "db.GetRoleByName")
}
if adminRole == nil {
return errors.New("admin role not found in database")
}
// Grant admin role (nil grantedBy = system granted)
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)
if err != nil {
return errors.Wrap(err, "db.AssignRole")
}
return nil
}

View File

@@ -193,6 +193,15 @@ func login(
if err != nil {
return nil, errors.Wrap(err, "user.UpdateDiscordToken")
}
// Check if user should be granted admin role (environment-based)
if shouldGrantAdmin(user, cfg.RBAC) {
err := ensureUserHasAdminRole(ctx, tx, user)
if err != nil {
return nil, errors.Wrap(err, "ensureUserHasAdminRole")
}
}
err := auth.Login(w, r, user, true)
if err != nil {
return nil, errors.Wrap(err, "auth.Login")

View File

@@ -20,16 +20,13 @@ func throwError(
err error,
level hws.ErrorLevel,
) {
err = s.ThrowError(w, r, hws.HWSError{
s.ThrowError(w, r, hws.HWSError{
StatusCode: statusCode,
Message: msg,
Error: err,
Level: level,
RenderErrorPage: true, // throw* family always renders error pages
})
if err != nil {
s.ThrowFatal(w, err)
}
}
// throwInternalServiceError handles 500 errors (server failures)

View File

@@ -0,0 +1,68 @@
package handlers
import (
"net/http"
"strconv"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) {
var pageNum, perPage int
var order bun.Order
var orderBy string
var err error
if pageStr := r.FormValue("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid page number")
}
}
if perPageStr := r.FormValue("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid per_page number")
}
}
order = bun.Order(r.FormValue("order"))
orderBy = r.FormValue("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
}
return pageOpts, nil
}
func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) {
var pageNum, perPage int
var order bun.Order
var orderBy string
var err error
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid page number")
}
}
if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid per_page number")
}
}
order = bun.Order(r.URL.Query().Get("order"))
orderBy = r.URL.Query().Get("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
}
return pageOpts, nil
}

View File

@@ -2,9 +2,7 @@ package handlers
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"git.haelnorr.com/h/golib/hws"
@@ -27,30 +25,10 @@ func SeasonsPage(
return
}
defer tx.Rollback()
var pageNum, perPage int
var order bun.Order
var orderBy string
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
throwBadRequest(s, w, r, "Invalid page number", err)
return
}
}
if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
throwBadRequest(s, w, r, "Invalid per_page number", err)
return
}
}
order = bun.Order(r.URL.Query().Get("order"))
orderBy = r.URL.Query().Get("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
pageOpts, err := pageOptsFromQuery(r)
if err != nil {
throwBadRequest(s, w, r, "invalid query", err)
return
}
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
@@ -76,36 +54,11 @@ func SeasonsList(
return
}
// Extract pagination/sort params from form
var pageNum, perPage int
var order bun.Order
var orderBy string
var err error
if pageStr := r.FormValue("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
throwBadRequest(s, w, r, "Invalid page number", err)
return
}
pageOpts, err := pageOptsFromForm(r)
if err != nil {
throwBadRequest(s, w, r, "invalid form data", err)
return
}
if perPageStr := r.FormValue("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
throwBadRequest(s, w, r, "Invalid per_page number", err)
return
}
}
order = bun.Order(r.FormValue("order"))
orderBy = r.FormValue("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
}
fmt.Println(pageOpts)
// Database query
tx, err := conn.BeginTx(ctx, nil)

View File

@@ -0,0 +1,23 @@
// Package permissions provides constants for RBAC
package permissions
type Permission string
func (p Permission) String() string {
return string(p)
}
const (
// Wildcard - grants all permissions
Wildcard Permission = "*"
// Seasons permissions
SeasonsCreate Permission = "seasons.create"
SeasonsUpdate Permission = "seasons.update"
SeasonsDelete Permission = "seasons.delete"
// Users permissions
UsersUpdate Permission = "users.update"
UsersBan Permission = "users.ban"
UsersManageRoles Permission = "users.manage_roles"
)

View File

@@ -0,0 +1,95 @@
package rbac
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/pkg/contexts"
"github.com/pkg/errors"
)
// LoadPermissionsMiddleware loads user permissions into context after authentication
// MUST run AFTER auth.Authenticate() middleware
func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := db.CurrentUser(r.Context())
if user == nil {
// No authenticated user - continue without permissions
next.ServeHTTP(w, r)
return
}
// Start transaction for loading permissions
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := c.conn.BeginTx(ctx, nil)
if err != nil {
// Log but don't block - permission checks will fail gracefully
c.s.LogError(hws.HWSError{
Message: "Failed to start database transaction",
Error: errors.Wrap(err, "c.conn.BeginTx"),
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
defer func() { _ = tx.Rollback() }()
// Load user's roles_ and permissions
roles_, err := user.GetRoles(ctx, tx)
if err != nil {
c.s.LogError(hws.HWSError{
Message: "Failed to get user roles",
Error: errors.Wrap(err, "user.GetRoles"),
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
perms, err := user.GetPermissions(ctx, tx)
if err != nil {
c.s.LogError(hws.HWSError{
Message: "Failed to get user permissions",
Error: errors.Wrap(err, "user.GetPermissions"),
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
_ = tx.Commit() // read only transaction
// Build permission cache
cache := &contexts.PermissionCache{
Permissions: make(map[permissions.Permission]bool),
Roles: make(map[roles.Role]bool),
}
// Check for wildcard permission
hasWildcard := false
for _, perm := range perms {
cache.Permissions[perm.Name] = true
if perm.Name == permissions.Wildcard {
hasWildcard = true
}
}
cache.HasWildcard = hasWildcard
for _, role := range roles_ {
cache.Roles[role.Name] = true
}
// Add cache to context (type-safe)
ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

108
internal/rbac/checker.go Normal file
View File

@@ -0,0 +1,108 @@
package rbac
import (
"context"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/pkg/contexts"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Checker struct {
conn *bun.DB
s *hws.Server
}
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) {
if conn == nil || s == nil {
return nil, errors.New("arguments cannot be nil")
}
return &Checker{conn: conn, s: s}, nil
}
// UserHasPermission checks if user has a specific permission (uses cache)
func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permission permissions.Permission) (bool, error) {
if user == nil {
return false, nil
}
// Try cache first
cache := contexts.Permissions(ctx)
if cache != nil {
if cache.HasWildcard {
return true, nil
}
if has, exists := cache.Permissions[permission]; exists {
return has, nil
}
}
// Fallback to database
tx, err := c.conn.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
has, err := user.HasPermission(ctx, tx, permission)
if err != nil {
return false, err
}
return has, nil
}
// UserHasRole checks if user has a specific role (uses cache)
func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Role) (bool, error) {
if user == nil {
return false, nil
}
cache := contexts.Permissions(ctx)
if cache != nil {
if has, exists := cache.Roles[role]; exists {
return has, nil
}
}
// Fallback to database
tx, err := c.conn.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
return user.HasRole(ctx, tx, role)
}
// UserHasAnyPermission checks if user has ANY of the given permissions
func (c *Checker) UserHasAnyPermission(ctx context.Context, user *db.User, permissions ...permissions.Permission) (bool, error) {
for _, perm := range permissions {
has, err := c.UserHasPermission(ctx, user, perm)
if err != nil {
return false, err
}
if has {
return true, nil
}
}
return false, nil
}
// UserHasAllPermissions checks if user has ALL of the given permissions
func (c *Checker) UserHasAllPermissions(ctx context.Context, user *db.User, permissions ...permissions.Permission) (bool, error) {
for _, perm := range permissions {
has, err := c.UserHasPermission(ctx, user, perm)
if err != nil {
return false, err
}
if !has {
return false, nil
}
}
return true, nil
}

22
internal/rbac/config.go Normal file
View File

@@ -0,0 +1,22 @@
// Package rbac provides Role-Based Access Control functionality
package rbac
import (
"errors"
"git.haelnorr.com/h/golib/env"
)
type Config struct {
AdminDiscordID string // ENV ADMIN_DISCORD_ID: Discord ID to grant admin role on first login (required)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
AdminDiscordID: env.String("ADMIN_DISCORD_ID", ""),
}
if cfg.AdminDiscordID == "" {
return nil, errors.New("env var not set: ADMIN_DISCORD_ID")
}
return cfg, nil
}

41
internal/rbac/ezconf.go Normal file
View File

@@ -0,0 +1,41 @@
package rbac
import (
"runtime"
"strings"
)
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct {
configFunc func() (any, error)
name string
}
// PackagePath returns the path to the config package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return e.configFunc()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return strings.ToLower(e.name)
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return e.name
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "RBAC", configFunc: ConfigFromEnv}
}

View File

@@ -0,0 +1,101 @@
package rbac
import (
"net/http"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
)
// RequirePermission creates middleware that requires a specific permission
func (c *Checker) RequirePermission(server *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := db.CurrentUser(r.Context())
if user == nil {
// Not logged in - redirect to login with page_from
cookies.SetPageFrom(w, r, r.URL.Path)
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
has, err := c.UserHasPermission(r.Context(), user, permission)
if err != nil {
// Log error and return 500
server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "Permission check failed",
Error: err,
Level: hws.ErrorERROR,
RenderErrorPage: true,
})
return
}
if !has {
// User lacks permission - return 403
server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusForbidden,
Message: "You don't have permission to access this resource",
Error: nil,
Level: hws.ErrorDEBUG,
RenderErrorPage: true,
})
return
}
next.ServeHTTP(w, r)
})
}
}
// RequireRole creates middleware that requires a specific role
func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := db.CurrentUser(r.Context())
if user == nil {
// Not logged in - redirect to login
cookies.SetPageFrom(w, r, r.URL.Path)
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
has, err := c.UserHasRole(r.Context(), user, role)
if err != nil {
// Log error and return 500
hwserr := hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "Role check failed",
Error: err,
Level: hws.ErrorERROR,
RenderErrorPage: true,
}
server.ThrowError(w, r, hwserr)
return
}
if !has {
// User lacks role - return 403
server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusForbidden,
Message: "You don't have the required role to access this resource",
Error: nil,
Level: hws.ErrorDEBUG,
RenderErrorPage: true,
})
return
}
next.ServeHTTP(w, r)
})
}
}
// RequireAdmin is a convenience middleware for admin-only routes
func (c *Checker) RequireAdmin(server *hws.Server) func(http.Handler) http.Handler {
return c.RequireRole(server, roles.Admin)
}

View File

@@ -0,0 +1,13 @@
// Package roles provides constants for the RBAC
package roles
type Role string
func (r Role) String() string {
return string(r)
}
const (
Admin Role = "admin"
User Role = "user"
)

View File

@@ -0,0 +1,6 @@
package admin
import "git.haelnorr.com/h/oslstats/internal/db"
templ UserList(users *db.Users) {
}

View File

@@ -1,6 +1,10 @@
package nav
import "git.haelnorr.com/h/oslstats/internal/db"
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/pkg/contexts"
)
type ProfileItem struct {
name string // Label to display
@@ -8,8 +12,8 @@ type ProfileItem struct {
}
// Return the list of profile links
func getProfileItems() []ProfileItem {
return []ProfileItem{
func getProfileItems(ctx context.Context) []ProfileItem {
items := []ProfileItem{
{
name: "Profile",
href: "/profile",
@@ -19,12 +23,23 @@ func getProfileItems() []ProfileItem {
href: "/account",
},
}
// Add admin link if user has admin role
cache := contexts.Permissions(ctx)
if cache != nil && cache.Roles["admin"] {
items = append(items, ProfileItem{
name: "Admin Panel",
href: "/admin",
})
}
return items
}
// Returns the right portion of the navbar
templ navRight() {
{{ user := db.CurrentUser(ctx) }}
{{ items := getProfileItems() }}
{{ items := getProfileItems(ctx) }}
<div class="flex items-center gap-2">
<div class="sm:flex sm:gap-2">
if user != nil {

View File

@@ -0,0 +1,8 @@
package layout
templ AdminDashboard() {
@Global("Admin")
<div>
{ children... }
</div>
}

View File

@@ -0,0 +1,10 @@
package page
import "git.haelnorr.com/h/oslstats/internal/view/layout"
import "git.haelnorr.com/h/oslstats/internal/view/component/admin"
import "git.haelnorr.com/h/oslstats/internal/db"
templ AdminDashboard(users *db.Users) {
@layout.AdminDashboard()
@admin.UserList(users)
}

View File

@@ -1,6 +1,12 @@
// Package contexts provides utilities for loading and extracting structs from contexts
package contexts
import "context"
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
)
type Key string
@@ -8,7 +14,10 @@ func (c Key) String() string {
return "oslstats context key " + string(c)
}
var DevModeKey Key = Key("devmode")
var (
DevModeKey Key = Key("devmode")
PermissionCacheKey Key = Key("permissions")
)
func DevMode(ctx context.Context) DevInfo {
devmode, ok := ctx.Value(DevModeKey).(DevInfo)
@@ -22,3 +31,18 @@ type DevInfo struct {
WebsocketBase string
HTMXLog bool
}
// Permissions retrieves the permission cache from context (type-safe)
func Permissions(ctx context.Context) *PermissionCache {
cache, ok := ctx.Value(PermissionCacheKey).(*PermissionCache)
if !ok {
return nil
}
return cache
}
type PermissionCache struct {
Permissions map[permissions.Permission]bool
Roles map[roles.Role]bool
HasWildcard bool
}