diff --git a/Makefile b/Makefile index aeeab39..91e912e 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ migrate-create: # Reset database (DESTRUCTIVE - dev only!) reset-db: - @echo "⚠️ WARNING: This will DELETE ALL DATA!" + @echo "⚠️ WARNING - This will DELETE ALL DATA!" make build ./bin/${BINARY_NAME}${SUFFIX} --reset-db diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index 5ea9d10..bf5e796 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -1,20 +1,18 @@ package main import ( - "context" "database/sql" "fmt" "time" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" - "github.com/pkg/errors" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/driver/pgdriver" ) -func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func() error, err error) { +func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) { dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL) sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) @@ -26,30 +24,19 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func conn = bun.NewDB(sqldb, pgdialect.New()) close = sqldb.Close - - err = loadModels(ctx, conn) - if err != nil { - return nil, nil, errors.Wrap(err, "loadModels") - } - - return conn, close, nil + return conn, close } -func loadModels(ctx context.Context, conn *bun.DB) error { +func registerDBModels(conn *bun.DB) { models := []any{ + (*db.RolePermission)(nil), + (*db.UserRole)(nil), (*db.User)(nil), (*db.DiscordToken)(nil), + (*db.Season)(nil), + (*db.Role)(nil), + (*db.Permission)(nil), + (*db.AuditLog)(nil), } - - for _, model := range models { - _, err := conn.NewCreateTable(). - Model(model). - IfNotExists(). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "db.NewCreateTable") - } - } - - return nil + conn.RegisterModel(models...) } diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index 7bfea84..b0528b4 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -61,7 +61,7 @@ func setupHTTPServer( } // Initialize permissions checker - perms, err := rbac.NewChecker(bun, server) + perms, err := rbac.NewChecker(bun, httpServer) if err != nil { return nil, errors.Wrap(err, "rbac.NewChecker") } diff --git a/cmd/oslstats/migrate.go b/cmd/oslstats/migrate.go index 81c8a3d..1564d63 100644 --- a/cmd/oslstats/migrate.go +++ b/cmd/oslstats/migrate.go @@ -10,6 +10,8 @@ import ( "text/tabwriter" "time" + stderrors "errors" + "git.haelnorr.com/h/oslstats/cmd/oslstats/migrations" "git.haelnorr.com/h/oslstats/internal/backup" "git.haelnorr.com/h/oslstats/internal/config" @@ -21,11 +23,8 @@ import ( // runMigrations executes database migrations func runMigrations(ctx context.Context, cfg *config.Config, command string) error { - conn, close, err := setupBun(ctx, cfg) - if err != nil { - return errors.Wrap(err, "setupBun") - } - defer close() + conn, close := setupBun(cfg) + defer func() { _ = close() }() migrator := migrate.NewMigrator(conn, migrations.Migrations) @@ -36,7 +35,14 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string) erro switch command { case "up": - return migrateUp(ctx, migrator, conn, cfg) + err := migrateUp(ctx, migrator, conn, cfg) + if err != nil { + err2 := migrateRollback(ctx, migrator, conn, cfg) + if err2 != nil { + return stderrors.Join(errors.Wrap(err2, "error while rolling back after migration error"), err) + } + } + return err case "rollback": return migrateRollback(ctx, migrator, conn, cfg) case "status": @@ -171,8 +177,8 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error { fmt.Println("╚══════════════════════════════════════════════════════════╝") w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) - fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT") - fmt.Fprintln(w, "------\t---------\t-----\t-----------") + _, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT") + _, _ = fmt.Fprintln(w, "------\t---------\t-----\t-----------") appliedCount := 0 for _, m := range ms { @@ -189,10 +195,10 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error { } } - fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt) + _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt) } - w.Flush() + _ = w.Flush() fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n", appliedCount, len(ms)-appliedCount) @@ -299,12 +305,12 @@ func init() { Migrations.MustRegister( // UP migration func(ctx context.Context, dbConn *bun.DB) error { - // TODO: Add your migration code here + // Add your migration code here return nil }, // DOWN migration func(ctx context.Context, dbConn *bun.DB) error { - // TODO: Add your rollback code here + // Add your rollback code here return nil }, ) @@ -326,7 +332,7 @@ func init() { // resetDatabase drops and recreates all tables (destructive) func resetDatabase(ctx context.Context, cfg *config.Config) error { - fmt.Println("⚠️ WARNING: This will DELETE ALL DATA in the database!") + fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!") fmt.Print("Type 'yes' to continue: ") reader := bufio.NewReader(os.Stdin) @@ -340,11 +346,8 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error { fmt.Println("❌ Reset cancelled") return nil } - conn, close, err := setupBun(ctx, cfg) - if err != nil { - return errors.Wrap(err, "setupBun") - } - defer close() + conn, close := setupBun(cfg) + defer func() { _ = close() }() models := []any{ (*db.User)(nil), diff --git a/cmd/oslstats/migrations/20260202231414_add_rbac_system.go b/cmd/oslstats/migrations/20260202231414_add_rbac_system.go index 72c2bf5..73684a9 100644 --- a/cmd/oslstats/migrations/20260202231414_add_rbac_system.go +++ b/cmd/oslstats/migrations/20260202231414_add_rbac_system.go @@ -12,20 +12,11 @@ func init() { Migrations.MustRegister( // UP migration func(ctx context.Context, dbConn *bun.DB) error { - // Create roles table using raw SQL to avoid m2m relationship issues - // Bun tries to resolve relationships when creating tables from models - // TODO: use proper m2m table instead of raw sql - _, err := dbConn.ExecContext(ctx, ` - CREATE TABLE roles ( - id SERIAL PRIMARY KEY, - name VARCHAR(50) UNIQUE NOT NULL, - display_name VARCHAR(100) NOT NULL, - description TEXT, - is_system BOOLEAN DEFAULT FALSE, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL - ) - `) + dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil)) + // Create permissions table + _, err := dbConn.NewCreateTable(). + Model((*db.Role)(nil)). + Exec(ctx) if err != nil { return err } @@ -39,7 +30,6 @@ func init() { } // Create indexes for permissions - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.Permission)(nil)). Index("idx_permissions_resource"). @@ -49,7 +39,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.Permission)(nil)). Index("idx_permissions_action"). @@ -59,22 +48,13 @@ func init() { return err } - // Create role_permissions join table (Bun doesn't auto-create m2m tables) - // TODO: use proper m2m table instead of raw sql - _, err = dbConn.ExecContext(ctx, ` - CREATE TABLE role_permissions ( - id SERIAL PRIMARY KEY, - role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE CASCADE, - permission_id INTEGER NOT NULL REFERENCES permissions(id) ON DELETE CASCADE, - created_at BIGINT NOT NULL, - UNIQUE(role_id, permission_id) - ) - `) + _, err = dbConn.NewCreateTable(). + Model((*db.RolePermission)(nil)). + Exec(ctx) if err != nil { return err } - // TODO: why do we need this? _, err = dbConn.ExecContext(ctx, ` CREATE INDEX idx_role_permissions_role ON role_permissions(role_id) `) @@ -82,7 +62,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.ExecContext(ctx, ` CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id) `) @@ -99,7 +78,6 @@ func init() { } // Create indexes for user_roles - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.UserRole)(nil)). Index("idx_user_roles_user"). @@ -109,7 +87,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.UserRole)(nil)). Index("idx_user_roles_role"). @@ -128,7 +105,6 @@ func init() { } // Create indexes for audit_log - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_user"). @@ -138,7 +114,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_action"). @@ -148,7 +123,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_resource"). @@ -158,7 +132,6 @@ func init() { return err } - // TODO: why do we need this? _, err = dbConn.NewCreateIndex(). Model((*db.AuditLog)(nil)). Index("idx_audit_log_created"). @@ -176,12 +149,12 @@ func init() { DisplayName: "Administrator", Description: "Full system access with all permissions", IsSystem: true, - CreatedAt: now, // TODO: this should be defaulted in table - UpdatedAt: now, // TODO: this should be defaulted in table + CreatedAt: now, } _, err = dbConn.NewInsert(). Model(adminRole). + Returning("id"). Exec(ctx) if err != nil { return err @@ -192,9 +165,7 @@ func init() { DisplayName: "User", Description: "Standard user with basic permissions", IsSystem: true, - CreatedAt: now, // TODO: this should be defaulted in table - UpdatedAt: now, // TODO: this should be defaulted in table - + CreatedAt: now, } _, err = dbConn.NewInsert(). @@ -205,7 +176,6 @@ func init() { } // Seed system permissions - // TODO: timestamps for created should be defaulted in table permissionsData := []*db.Permission{ {Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now}, {Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now}, @@ -235,11 +205,14 @@ func init() { } // Insert role_permission mapping - // TODO: use proper m2m table, and default now in table settings - _, err = dbConn.ExecContext(ctx, ` - INSERT INTO role_permissions (role_id, permission_id, created_at) - VALUES ($1, $2, $3) - `, adminRole.ID, wildcardPerm.ID, now) + adminRolePerms := &db.RolePermission{ + RoleID: adminRole.ID, + PermissionID: wildcardPerm.ID, + } + _, err = dbConn.NewInsert(). + Model(adminRolePerms). + On("CONFLICT (role_id, permission_id) DO NOTHING"). + Exec(ctx) if err != nil { return err } @@ -250,7 +223,6 @@ func init() { func(ctx context.Context, dbConn *bun.DB) error { // Drop tables in reverse order // Use raw SQL to avoid relationship resolution issues - // TODO: surely we can use proper bun methods? tables := []string{ "audit_log", "user_roles", diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index b663d46..334b520 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -30,6 +30,11 @@ func addRoutes( ) error { // Create the routes pageroutes := []hws.Route{ + { + Path: "/permtest", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: handlers.PermTester(s, conn), + }, { Path: "/static/", Method: hws.MethodGET, @@ -63,8 +68,7 @@ func addRoutes( { Path: "/notification-tester", Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, - Handler: handlers.NotifyTester(s), - // TODO: add login protection + Handler: perms.RequireAdmin(s)(handlers.NotifyTester(s)), }, { Path: "/seasons", diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 4c0eef0..ee831df 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -25,10 +25,8 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error { // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - bun, closedb, err := setupBun(ctx, cfg) - if err != nil { - return errors.Wrap(err, "setupDBConn") - } + bun, closedb := setupBun(cfg) + registerDBModels(bun) // Setup embedded files logger.Debug().Msg("Getting embedded files") diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go index e3e806f..9a06e15 100644 --- a/internal/auditlog/logger.go +++ b/internal/auditlog/logger.go @@ -104,39 +104,37 @@ func (l *Logger) log( } // GetRecentLogs retrieves recent audit logs with pagination -// TODO: change this to user db.PageOpts -func (l *Logger) GetRecentLogs(ctx context.Context, limit, offset int) ([]*db.AuditLog, int, error) { +func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) { tx, err := l.conn.BeginTx(ctx, nil) if err != nil { - return nil, 0, errors.Wrap(err, "conn.BeginTx") + return nil, errors.Wrap(err, "conn.BeginTx") } defer func() { _ = tx.Rollback() }() - logs, total, err := db.GetAuditLogs(ctx, tx, limit, offset, nil) + logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil) if err != nil { - return nil, 0, err + return nil, err } _ = tx.Commit() // read only transaction - return logs, total, nil + return logs, nil } // GetLogsByUser retrieves audit logs for a specific user -// TODO: change this to user db.PageOpts -func (l *Logger) GetLogsByUser(ctx context.Context, userID int, limit, offset int) ([]*db.AuditLog, int, error) { +func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) { tx, err := l.conn.BeginTx(ctx, nil) if err != nil { - return nil, 0, errors.Wrap(err, "conn.BeginTx") + return nil, errors.Wrap(err, "conn.BeginTx") } defer func() { _ = tx.Rollback() }() - logs, total, err := db.GetAuditLogsByUser(ctx, tx, userID, limit, offset) + logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) if err != nil { - return nil, 0, err + return nil, err } _ = tx.Commit() // read only transaction - return logs, total, nil + return logs, nil } // CleanupOldLogs deletes audit logs older than the specified number of days diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go index f5c0d83..2fe7a93 100644 --- a/internal/db/auditlog.go +++ b/internal/db/auditlog.go @@ -28,7 +28,11 @@ type AuditLog struct { User *User `bun:"rel:belongs-to,join:user_id=id"` } -// TODO: add AuditLogs to match list style with PageOpts +type AuditLogs struct { + AuditLogs []*AuditLog + Total int + PageOpts PageOpts +} // CreateAuditLog creates a new audit log entry func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { @@ -54,12 +58,12 @@ type AuditLogFilters struct { } // GetAuditLogs retrieves audit logs with optional filters and pagination -// TODO: change this to use db.PageOpts -func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *AuditLogFilters) ([]*AuditLog, int, error) { +func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) { + pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at") query := tx.NewSelect(). Model((*AuditLog)(nil)). Relation("User"). - Order("created_at DESC") + OrderBy(pageOpts.OrderBy, pageOpts.Order) // Apply filters if provided if filters != nil { @@ -80,48 +84,52 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *Au // Get total count total, err := query.Count(ctx) if err != nil { - return nil, 0, errors.Wrap(err, "query.Count") + return nil, errors.Wrap(err, "query.Count") } // Get paginated results - var logs []*AuditLog + logs := new([]*AuditLog) err = query. - Limit(limit). - Offset(offset). + Offset(pageOpts.PerPage*(pageOpts.Page-1)). + Limit(pageOpts.PerPage). Scan(ctx, &logs) if err != nil && err != sql.ErrNoRows { - return nil, 0, errors.Wrap(err, "query.Scan") + return nil, errors.Wrap(err, "query.Scan") } - return logs, total, nil + list := &AuditLogs{ + AuditLogs: *logs, + Total: total, + PageOpts: *pageOpts, + } + + return list, nil } // GetAuditLogsByUser retrieves audit logs for a specific user -// TODO: change this to use db.PageOpts -func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, limit, offset int) ([]*AuditLog, int, error) { +func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) { if userID <= 0 { - return nil, 0, errors.New("userID must be positive") + return nil, errors.New("userID must be positive") } filters := &AuditLogFilters{ UserID: &userID, } - return GetAuditLogs(ctx, tx, limit, offset, filters) + return GetAuditLogs(ctx, tx, pageOpts, filters) } // GetAuditLogsByAction retrieves audit logs for a specific action -// TODO: change this to use db.PageOpts -func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, limit, offset int) ([]*AuditLog, int, error) { +func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) { if action == "" { - return nil, 0, errors.New("action cannot be empty") + return nil, errors.New("action cannot be empty") } filters := &AuditLogFilters{ Action: &action, } - return GetAuditLogs(ctx, tx, limit, offset, filters) + return GetAuditLogs(ctx, tx, pageOpts, filters) } // CleanupOldAuditLogs deletes audit logs older than the specified timestamp diff --git a/internal/db/paginate.go b/internal/db/paginate.go index 366cb33..27cc3a8 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -15,6 +15,25 @@ type OrderOpts struct { Label string } +func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts { + if p == nil { + p = new(PageOpts) + } + if p.Page == 0 { + p.Page = page + } + if p.PerPage == 0 { + p.PerPage = perpage + } + if p.Order == "" { + p.Order = order + } + if p.OrderBy == "" { + p.OrderBy = orderby + } + return p +} + // TotalPages calculates the total number of pages func (p *PageOpts) TotalPages(total int) int { if p.PerPage == 0 { diff --git a/internal/db/permission.go b/internal/db/permission.go index 64aacef..8443127 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -20,6 +20,8 @@ type Permission struct { Action string `bun:"action,notnull"` IsSystem bool `bun:"is_system,default:false"` CreatedAt int64 `bun:"created_at,notnull"` + + Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"` } // GetPermissionByName queries the database for a permission matching the given name diff --git a/internal/db/role.go b/internal/db/role.go index dbf87bd..ed60194 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "time" "git.haelnorr.com/h/oslstats/internal/roles" "github.com/pkg/errors" @@ -18,10 +19,18 @@ type Role struct { Description string `bun:"description"` IsSystem bool `bun:"is_system,default:false"` CreatedAt int64 `bun:"created_at,notnull"` - UpdatedAt int64 `bun:"updated_at,notnull"` + UpdatedAt *int64 `bun:"updated_at"` // Relations (loaded on demand) - Permissions []*Permission `bun:"m2m:role_permissions,join:Role=Permission"` + Users []User `bun:"m2m:user_roles,join:Role=User"` + Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"` +} + +type RolePermission struct { + RoleID int `bun:",pk"` + Role *Role `bun:"rel:belongs-to,join:role_id=id"` + PermissionID int `bun:",pk"` + Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"` } // GetRoleByName queries the database for a role matching the given name @@ -99,6 +108,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { if role == nil { return errors.New("role cannot be nil") } + role.CreatedAt = time.Now().Unix() _, err := tx.NewInsert(). Model(role). @@ -160,23 +170,23 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { } // AddPermissionToRole grants a permission to a role -func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int, createdAt int64) error { +func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error { if roleID <= 0 { return errors.New("roleID must be positive") } if permissionID <= 0 { return errors.New("permissionID must be positive") } - - // TODO: use proper m2m table - // also make createdAt automatic in table so not required as input here - _, err := tx.ExecContext(ctx, ` - INSERT INTO role_permissions (role_id, permission_id, created_at) - VALUES ($1, $2, $3) - ON CONFLICT (role_id, permission_id) DO NOTHING - `, roleID, permissionID, createdAt) + rolePerm := &RolePermission{ + RoleID: roleID, + PermissionID: permissionID, + } + _, err := tx.NewInsert(). + Model(rolePerm). + On("CONFLICT (role_id, permission_id) DO NOTHING"). + Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.ExecContext") + return errors.Wrap(err, "tx.NewInsert") } return nil @@ -191,13 +201,13 @@ func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permission return errors.New("permissionID must be positive") } - // TODO: use proper m2m table - _, err := tx.ExecContext(ctx, ` - DELETE FROM role_permissions - WHERE role_id = $1 AND permission_id = $2 - `, roleID, permissionID) + _, err := tx.NewDelete(). + Model((*RolePermission)(nil)). + Where("role_id = ?", roleID). + Where("permission_id = ?", permissionID). + Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.ExecContext") + return errors.Wrap(err, "tx.NewDelete") } return nil diff --git a/internal/db/season.go b/internal/db/season.go index 49021a4..69dd267 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -23,7 +23,7 @@ type Season struct { } type SeasonList struct { - Seasons []Season + Seasons []*Season Total int PageOpts PageOpts } @@ -50,24 +50,10 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim } func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { - if pageOpts == nil { - pageOpts = &PageOpts{} - } - if pageOpts.Page == 0 { - pageOpts.Page = 1 - } - if pageOpts.PerPage == 0 { - pageOpts.PerPage = 10 - } - if pageOpts.Order == "" { - pageOpts.Order = bun.OrderDesc - } - if pageOpts.OrderBy == "" { - pageOpts.OrderBy = "start_date" - } - seasons := []Season{} + pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date") + seasons := new([]*Season) err := tx.NewSelect(). - Model(&seasons). + Model(seasons). OrderBy(pageOpts.OrderBy, pageOpts.Order). Offset(pageOpts.PerPage * (pageOpts.Page - 1)). Limit(pageOpts.PerPage). @@ -76,13 +62,13 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonLis return nil, errors.Wrap(err, "tx.NewSelect") } total, err := tx.NewSelect(). - Model(&seasons). + Model(seasons). Count(ctx) if err != nil { return nil, errors.Wrap(err, "tx.NewSelect") } sl := &SeasonList{ - Seasons: seasons, + Seasons: *seasons, Total: total, PageOpts: *pageOpts, } diff --git a/internal/db/user.go b/internal/db/user.go index bdcdeeb..093af91 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -23,6 +23,8 @@ type User struct { Username string `bun:"username,unique"` // Username (unique) CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database DiscordID string `bun:"discord_id,unique"` + + Roles []*Role `bun:"m2m:user_roles,join:User=Role"` } type Users struct { @@ -124,28 +126,35 @@ func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, er return count == 0, nil } -// GetRoles loads and returns all roles for this user +// GetRoles loads all the roles for this user func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { if u == nil { return nil, errors.New("user cannot be nil") } - return GetUserRoles(ctx, tx, u.ID) + + err := tx.NewSelect(). + Model(u). + Relation("Roles"). + Where("id = ?", u.ID). + Scan(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return u.Roles, nil } -// GetPermissions loads and returns all permissions for this user (via roles) +// GetPermissions loads and returns all permissions for this user func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { if u == nil { return nil, errors.New("user cannot be nil") } - // TODO: use proper m2m tables and relations instead of join var permissions []*Permission err := tx.NewSelect(). Model(&permissions). - Join("JOIN role_permissions AS rp ON rp.permission_id = p.id"). + Join("JOIN role_permissions AS rp on rp.permission_id = p.id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). Where("ur.user_id = ?", u.ID). - Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). Scan(ctx) if err != nil && err != sql.ErrNoRows { return nil, errors.Wrap(err, "tx.NewSelect") @@ -192,22 +201,8 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) { } func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { - if pageOpts == nil { - pageOpts = &PageOpts{} - } - if pageOpts.Page == 0 { - pageOpts.Page = 1 - } - if pageOpts.PerPage == 0 { - pageOpts.PerPage = 50 - } - if pageOpts.Order == "" { - pageOpts.Order = bun.OrderAsc - } - if pageOpts.OrderBy == "" { - pageOpts.OrderBy = "id" - } - users := []*User{} + pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id") + users := new([]*User) err := tx.NewSelect(). Model(users). OrderBy(pageOpts.OrderBy, pageOpts.Order). @@ -224,7 +219,7 @@ func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error return nil, errors.Wrap(err, "tx.NewSelect") } list := &Users{ - Users: users, + Users: *users, Total: total, PageOpts: *pageOpts, } diff --git a/internal/db/userrole.go b/internal/db/userrole.go index ba2708f..da538c8 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -2,7 +2,6 @@ package db import ( "context" - "time" "git.haelnorr.com/h/oslstats/internal/roles" "github.com/pkg/errors" @@ -10,42 +9,14 @@ import ( ) type UserRole struct { - bun.BaseModel `bun:"table:user_roles,alias:ur"` - - ID int `bun:"id,pk,autoincrement"` - UserID int `bun:"user_id,notnull"` - RoleID int `bun:"role_id,notnull"` - GrantedBy *int `bun:"granted_by"` - GrantedAt int64 `bun:"granted_at,notnull"` // TODO: default now - ExpiresAt *int64 `bun:"expires_at"` - - // Relations - User *User `bun:"rel:belongs-to,join:user_id=id"` - Role *Role `bun:"rel:belongs-to,join:role_id=id"` -} - -// GetUserRoles loads all roles for a given user -func GetUserRoles(ctx context.Context, tx bun.Tx, userID int) ([]*Role, error) { - if userID <= 0 { - return nil, errors.New("userID must be positive") - } - - var roles []*Role - err := tx.NewSelect(). - Model(&roles). - // TODO: why are we joining? can we do relation? - Join("JOIN user_roles AS ur ON ur.role_id = r.id"). - Where("ur.user_id = ?", userID). - Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). - Scan(ctx) - if err != nil { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return roles, nil + UserID int `bun:",pk"` + User *User `bun:"rel:belongs-to,join:user_id=id"` + RoleID int `bun:",pk"` + Role *Role `bun:"rel:belongs-to,join:role_id=id"` } // AssignRole grants a role to a user -func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *int) error { +func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { if userID <= 0 { return errors.New("userID must be positive") } @@ -53,16 +24,16 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *i return errors.New("roleID must be positive") } - now := time.Now().Unix() - - // TODO: use proper m2m table instead of raw SQL - _, err := tx.ExecContext(ctx, ` - INSERT INTO user_roles (user_id, role_id, granted_by, granted_at) - VALUES ($1, $2, $3, $4) - ON CONFLICT (user_id, role_id) DO NOTHING - `, userID, roleID, grantedBy, now) + userRole := &UserRole{ + UserID: userID, + RoleID: roleID, + } + _, err := tx.NewInsert(). + Model(userRole). + On("CONFLICT (user_id, role_id) DO NOTHING"). + Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.ExecContext") + return errors.Wrap(err, "tx.NewInsert") } return nil @@ -77,13 +48,13 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { return errors.New("roleID must be positive") } - // TODO: use proper m2m table instead of raw sql - _, err := tx.ExecContext(ctx, ` - DELETE FROM user_roles - WHERE user_id = $1 AND role_id = $2 - `, userID, roleID) + _, err := tx.NewDelete(). + Model((*UserRole)(nil)). + Where("user_id = ?", userID). + Where("role_id = ?", roleID). + Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.ExecContext") + return errors.Wrap(err, "tx.NewDelete") } return nil @@ -97,18 +68,19 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b if roleName == "" { return false, errors.New("roleName cannot be empty") } - - // TODO: use proper m2m table instead of TableExpr and Join? - count, err := tx.NewSelect(). - TableExpr("user_roles AS ur"). - Join("JOIN roles AS r ON r.id = ur.role_id"). - Where("ur.user_id = ?", userID). - Where("r.name = ?", roleName). - Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()). - Count(ctx) + user := new(User) + err := tx.NewSelect(). + Model(user). + Relation("Roles"). + Where("u.id = ? ", userID). + Scan(ctx) if err != nil { return false, errors.Wrap(err, "tx.NewSelect") } - - return count > 0, nil + for _, role := range user.Roles { + if role.Name == roleName { + return true, nil + } + } + return false, nil } diff --git a/internal/handlers/auth_helpers.go b/internal/handlers/auth_helpers.go index 6a29242..7511bc1 100644 --- a/internal/handlers/auth_helpers.go +++ b/internal/handlers/auth_helpers.go @@ -46,8 +46,8 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error return errors.New("admin role not found in database") } - // Grant admin role (nil grantedBy = system granted) - err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil) + // Grant admin role + err = db.AssignRole(ctx, tx, user.ID, adminRole.ID) if err != nil { return errors.Wrap(err, "db.AssignRole") } diff --git a/internal/handlers/test.go b/internal/handlers/notifytest.go similarity index 100% rename from internal/handlers/test.go rename to internal/handlers/notifytest.go diff --git a/internal/handlers/permtest.go b/internal/handlers/permtest.go new file mode 100644 index 0000000..7f2d145 --- /dev/null +++ b/internal/handlers/permtest.go @@ -0,0 +1,24 @@ +package handlers + +import ( + "net/http" + "strconv" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/roles" + "github.com/uptrace/bun" +) + +func PermTester(s *hws.Server, conn *bun.DB) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := db.CurrentUser(r.Context()) + tx, _ := conn.BeginTx(r.Context(), nil) + isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin) + tx.Rollback() + if err != nil { + throwInternalServiceError(s, w, r, "Error", err) + } + _, _ = w.Write([]byte(strconv.FormatBool(isAdmin))) + }) +} diff --git a/internal/rbac/checker.go b/internal/rbac/checker.go index 5b88594..886cf89 100644 --- a/internal/rbac/checker.go +++ b/internal/rbac/checker.go @@ -18,8 +18,11 @@ type Checker struct { } func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) { - if conn == nil || s == nil { - return nil, errors.New("arguments cannot be nil") + if conn == nil { + return nil, errors.New("conn cannot be nil") + } + if s == nil { + return nil, errors.New("server cannot be nil") } return &Checker{conn: conn, s: s}, nil } diff --git a/internal/view/layout/global.templ b/internal/view/layout/global.templ index 5e08d78..8e829be 100644 --- a/internal/view/layout/global.templ +++ b/internal/view/layout/global.templ @@ -22,7 +22,7 @@ templ Global(title string) { { title } - + diff --git a/internal/view/page/seasons_list.templ b/internal/view/page/seasons_list.templ index b7b53c6..f300edf 100644 --- a/internal/view/page/seasons_list.templ +++ b/internal/view/page/seasons_list.templ @@ -84,7 +84,7 @@ templ SeasonsList(seasons *db.SeasonList) {

{ s.Name }

- @season.StatusBadge(&s, true, true) + @season.StatusBadge(s, true, true)
@@ -107,5 +107,5 @@ templ SeasonsList(seasons *db.SeasonList) { } func formatDate(t time.Time) string { - return t.Format("02/01/2006") // DD/MM/YYYY + return t.Format("02/01/2006") // DD/MM/YYYY } diff --git a/pkg/oauth/config.go b/pkg/oauth/config.go index c37ef21..09a21e5 100644 --- a/pkg/oauth/config.go +++ b/pkg/oauth/config.go @@ -1,3 +1,4 @@ +// Package oauth provides OAuth utilities for generating and checking secure state tokens package oauth import (