fixed relationship issues

This commit is contained in:
2026-02-05 00:10:10 +11:00
parent 20308fe35c
commit 4c31c24069
22 changed files with 236 additions and 254 deletions

View File

@@ -82,7 +82,7 @@ migrate-create:
# Reset database (DESTRUCTIVE - dev only!) # Reset database (DESTRUCTIVE - dev only!)
reset-db: reset-db:
@echo "⚠️ WARNING: This will DELETE ALL DATA!" @echo "⚠️ WARNING - This will DELETE ALL DATA!"
make build make build
./bin/${BINARY_NAME}${SUFFIX} --reset-db ./bin/${BINARY_NAME}${SUFFIX} --reset-db

View File

@@ -1,20 +1,18 @@
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "time"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver" "github.com/uptrace/bun/driver/pgdriver"
) )
func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func() error, err error) { func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL) cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
@@ -26,30 +24,19 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
conn = bun.NewDB(sqldb, pgdialect.New()) conn = bun.NewDB(sqldb, pgdialect.New())
close = sqldb.Close close = sqldb.Close
return conn, close
err = loadModels(ctx, conn)
if err != nil {
return nil, nil, errors.Wrap(err, "loadModels")
}
return conn, close, nil
} }
func loadModels(ctx context.Context, conn *bun.DB) error { func registerDBModels(conn *bun.DB) {
models := []any{ models := []any{
(*db.RolePermission)(nil),
(*db.UserRole)(nil),
(*db.User)(nil), (*db.User)(nil),
(*db.DiscordToken)(nil), (*db.DiscordToken)(nil),
(*db.Season)(nil),
(*db.Role)(nil),
(*db.Permission)(nil),
(*db.AuditLog)(nil),
} }
conn.RegisterModel(models...)
for _, model := range models {
_, err := conn.NewCreateTable().
Model(model).
IfNotExists().
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.NewCreateTable")
}
}
return nil
} }

View File

@@ -61,7 +61,7 @@ func setupHTTPServer(
} }
// Initialize permissions checker // Initialize permissions checker
perms, err := rbac.NewChecker(bun, server) perms, err := rbac.NewChecker(bun, httpServer)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "rbac.NewChecker") return nil, errors.Wrap(err, "rbac.NewChecker")
} }

View File

@@ -10,6 +10,8 @@ import (
"text/tabwriter" "text/tabwriter"
"time" "time"
stderrors "errors"
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations" "git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
"git.haelnorr.com/h/oslstats/internal/backup" "git.haelnorr.com/h/oslstats/internal/backup"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
@@ -21,11 +23,8 @@ import (
// runMigrations executes database migrations // runMigrations executes database migrations
func runMigrations(ctx context.Context, cfg *config.Config, command string) error { func runMigrations(ctx context.Context, cfg *config.Config, command string) error {
conn, close, err := setupBun(ctx, cfg) conn, close := setupBun(cfg)
if err != nil { defer func() { _ = close() }()
return errors.Wrap(err, "setupBun")
}
defer close()
migrator := migrate.NewMigrator(conn, migrations.Migrations) migrator := migrate.NewMigrator(conn, migrations.Migrations)
@@ -36,7 +35,14 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string) erro
switch command { switch command {
case "up": case "up":
return migrateUp(ctx, migrator, conn, cfg) err := migrateUp(ctx, migrator, conn, cfg)
if err != nil {
err2 := migrateRollback(ctx, migrator, conn, cfg)
if err2 != nil {
return stderrors.Join(errors.Wrap(err2, "error while rolling back after migration error"), err)
}
}
return err
case "rollback": case "rollback":
return migrateRollback(ctx, migrator, conn, cfg) return migrateRollback(ctx, migrator, conn, cfg)
case "status": case "status":
@@ -171,8 +177,8 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
fmt.Println("╚══════════════════════════════════════════════════════════╝") fmt.Println("╚══════════════════════════════════════════════════════════╝")
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT") _, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT")
fmt.Fprintln(w, "------\t---------\t-----\t-----------") _, _ = fmt.Fprintln(w, "------\t---------\t-----\t-----------")
appliedCount := 0 appliedCount := 0
for _, m := range ms { for _, m := range ms {
@@ -189,10 +195,10 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
} }
} }
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt) _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt)
} }
w.Flush() _ = w.Flush()
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n", fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
appliedCount, len(ms)-appliedCount) appliedCount, len(ms)-appliedCount)
@@ -299,12 +305,12 @@ func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, dbConn *bun.DB) error {
// TODO: Add your migration code here // Add your migration code here
return nil return nil
}, },
// DOWN migration // DOWN migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, dbConn *bun.DB) error {
// TODO: Add your rollback code here // Add your rollback code here
return nil return nil
}, },
) )
@@ -326,7 +332,7 @@ func init() {
// resetDatabase drops and recreates all tables (destructive) // resetDatabase drops and recreates all tables (destructive)
func resetDatabase(ctx context.Context, cfg *config.Config) error { func resetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("⚠️ WARNING: This will DELETE ALL DATA in the database!") fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
fmt.Print("Type 'yes' to continue: ") fmt.Print("Type 'yes' to continue: ")
reader := bufio.NewReader(os.Stdin) reader := bufio.NewReader(os.Stdin)
@@ -340,11 +346,8 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("❌ Reset cancelled") fmt.Println("❌ Reset cancelled")
return nil return nil
} }
conn, close, err := setupBun(ctx, cfg) conn, close := setupBun(cfg)
if err != nil { defer func() { _ = close() }()
return errors.Wrap(err, "setupBun")
}
defer close()
models := []any{ models := []any{
(*db.User)(nil), (*db.User)(nil),

View File

@@ -12,20 +12,11 @@ func init() {
Migrations.MustRegister( Migrations.MustRegister(
// UP migration // UP migration
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, dbConn *bun.DB) error {
// Create roles table using raw SQL to avoid m2m relationship issues dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
// Bun tries to resolve relationships when creating tables from models // Create permissions table
// TODO: use proper m2m table instead of raw sql _, err := dbConn.NewCreateTable().
_, err := dbConn.ExecContext(ctx, ` Model((*db.Role)(nil)).
CREATE TABLE roles ( Exec(ctx)
id SERIAL PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
display_name VARCHAR(100) NOT NULL,
description TEXT,
is_system BOOLEAN DEFAULT FALSE,
created_at BIGINT NOT NULL,
updated_at BIGINT NOT NULL
)
`)
if err != nil { if err != nil {
return err return err
} }
@@ -39,7 +30,6 @@ func init() {
} }
// Create indexes for permissions // Create indexes for permissions
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.Permission)(nil)). Model((*db.Permission)(nil)).
Index("idx_permissions_resource"). Index("idx_permissions_resource").
@@ -49,7 +39,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.Permission)(nil)). Model((*db.Permission)(nil)).
Index("idx_permissions_action"). Index("idx_permissions_action").
@@ -59,22 +48,13 @@ func init() {
return err return err
} }
// Create role_permissions join table (Bun doesn't auto-create m2m tables) _, err = dbConn.NewCreateTable().
// TODO: use proper m2m table instead of raw sql Model((*db.RolePermission)(nil)).
_, err = dbConn.ExecContext(ctx, ` Exec(ctx)
CREATE TABLE role_permissions (
id SERIAL PRIMARY KEY,
role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE CASCADE,
permission_id INTEGER NOT NULL REFERENCES permissions(id) ON DELETE CASCADE,
created_at BIGINT NOT NULL,
UNIQUE(role_id, permission_id)
)
`)
if err != nil { if err != nil {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.ExecContext(ctx, ` _, err = dbConn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id) CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
`) `)
@@ -82,7 +62,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.ExecContext(ctx, ` _, err = dbConn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id) CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
`) `)
@@ -99,7 +78,6 @@ func init() {
} }
// Create indexes for user_roles // Create indexes for user_roles
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.UserRole)(nil)). Model((*db.UserRole)(nil)).
Index("idx_user_roles_user"). Index("idx_user_roles_user").
@@ -109,7 +87,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.UserRole)(nil)). Model((*db.UserRole)(nil)).
Index("idx_user_roles_role"). Index("idx_user_roles_role").
@@ -128,7 +105,6 @@ func init() {
} }
// Create indexes for audit_log // Create indexes for audit_log
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_user"). Index("idx_audit_log_user").
@@ -138,7 +114,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_action"). Index("idx_audit_log_action").
@@ -148,7 +123,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_resource"). Index("idx_audit_log_resource").
@@ -158,7 +132,6 @@ func init() {
return err return err
} }
// TODO: why do we need this?
_, err = dbConn.NewCreateIndex(). _, err = dbConn.NewCreateIndex().
Model((*db.AuditLog)(nil)). Model((*db.AuditLog)(nil)).
Index("idx_audit_log_created"). Index("idx_audit_log_created").
@@ -176,12 +149,12 @@ func init() {
DisplayName: "Administrator", DisplayName: "Administrator",
Description: "Full system access with all permissions", Description: "Full system access with all permissions",
IsSystem: true, IsSystem: true,
CreatedAt: now, // TODO: this should be defaulted in table CreatedAt: now,
UpdatedAt: now, // TODO: this should be defaulted in table
} }
_, err = dbConn.NewInsert(). _, err = dbConn.NewInsert().
Model(adminRole). Model(adminRole).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return err return err
@@ -192,9 +165,7 @@ func init() {
DisplayName: "User", DisplayName: "User",
Description: "Standard user with basic permissions", Description: "Standard user with basic permissions",
IsSystem: true, IsSystem: true,
CreatedAt: now, // TODO: this should be defaulted in table CreatedAt: now,
UpdatedAt: now, // TODO: this should be defaulted in table
} }
_, err = dbConn.NewInsert(). _, err = dbConn.NewInsert().
@@ -205,7 +176,6 @@ func init() {
} }
// Seed system permissions // Seed system permissions
// TODO: timestamps for created should be defaulted in table
permissionsData := []*db.Permission{ permissionsData := []*db.Permission{
{Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now}, {Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now},
{Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now}, {Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now},
@@ -235,11 +205,14 @@ func init() {
} }
// Insert role_permission mapping // Insert role_permission mapping
// TODO: use proper m2m table, and default now in table settings adminRolePerms := &db.RolePermission{
_, err = dbConn.ExecContext(ctx, ` RoleID: adminRole.ID,
INSERT INTO role_permissions (role_id, permission_id, created_at) PermissionID: wildcardPerm.ID,
VALUES ($1, $2, $3) }
`, adminRole.ID, wildcardPerm.ID, now) _, err = dbConn.NewInsert().
Model(adminRolePerms).
On("CONFLICT (role_id, permission_id) DO NOTHING").
Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -250,7 +223,6 @@ func init() {
func(ctx context.Context, dbConn *bun.DB) error { func(ctx context.Context, dbConn *bun.DB) error {
// Drop tables in reverse order // Drop tables in reverse order
// Use raw SQL to avoid relationship resolution issues // Use raw SQL to avoid relationship resolution issues
// TODO: surely we can use proper bun methods?
tables := []string{ tables := []string{
"audit_log", "audit_log",
"user_roles", "user_roles",

View File

@@ -30,6 +30,11 @@ func addRoutes(
) error { ) error {
// Create the routes // Create the routes
pageroutes := []hws.Route{ pageroutes := []hws.Route{
{
Path: "/permtest",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: handlers.PermTester(s, conn),
},
{ {
Path: "/static/", Path: "/static/",
Method: hws.MethodGET, Method: hws.MethodGET,
@@ -63,8 +68,7 @@ func addRoutes(
{ {
Path: "/notification-tester", Path: "/notification-tester",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: handlers.NotifyTester(s), Handler: perms.RequireAdmin(s)(handlers.NotifyTester(s)),
// TODO: add login protection
}, },
{ {
Path: "/seasons", Path: "/seasons",

View File

@@ -25,10 +25,8 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
// Setup the database connection // Setup the database connection
logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database") logger.Debug().Msg("Connecting to database")
bun, closedb, err := setupBun(ctx, cfg) bun, closedb := setupBun(cfg)
if err != nil { registerDBModels(bun)
return errors.Wrap(err, "setupDBConn")
}
// Setup embedded files // Setup embedded files
logger.Debug().Msg("Getting embedded files") logger.Debug().Msg("Getting embedded files")

View File

@@ -104,39 +104,37 @@ func (l *Logger) log(
} }
// GetRecentLogs retrieves recent audit logs with pagination // GetRecentLogs retrieves recent audit logs with pagination
// TODO: change this to user db.PageOpts func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
func (l *Logger) GetRecentLogs(ctx context.Context, limit, offset int) ([]*db.AuditLog, int, error) {
tx, err := l.conn.BeginTx(ctx, nil) tx, err := l.conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, 0, errors.Wrap(err, "conn.BeginTx") return nil, errors.Wrap(err, "conn.BeginTx")
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
logs, total, err := db.GetAuditLogs(ctx, tx, limit, offset, nil) logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
_ = tx.Commit() // read only transaction _ = tx.Commit() // read only transaction
return logs, total, nil return logs, nil
} }
// GetLogsByUser retrieves audit logs for a specific user // GetLogsByUser retrieves audit logs for a specific user
// TODO: change this to user db.PageOpts func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, limit, offset int) ([]*db.AuditLog, int, error) {
tx, err := l.conn.BeginTx(ctx, nil) tx, err := l.conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, 0, errors.Wrap(err, "conn.BeginTx") return nil, errors.Wrap(err, "conn.BeginTx")
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
logs, total, err := db.GetAuditLogsByUser(ctx, tx, userID, limit, offset) logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
_ = tx.Commit() // read only transaction _ = tx.Commit() // read only transaction
return logs, total, nil return logs, nil
} }
// CleanupOldLogs deletes audit logs older than the specified number of days // CleanupOldLogs deletes audit logs older than the specified number of days

View File

@@ -28,7 +28,11 @@ type AuditLog struct {
User *User `bun:"rel:belongs-to,join:user_id=id"` User *User `bun:"rel:belongs-to,join:user_id=id"`
} }
// TODO: add AuditLogs to match list style with PageOpts type AuditLogs struct {
AuditLogs []*AuditLog
Total int
PageOpts PageOpts
}
// CreateAuditLog creates a new audit log entry // CreateAuditLog creates a new audit log entry
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
@@ -54,12 +58,12 @@ type AuditLogFilters struct {
} }
// GetAuditLogs retrieves audit logs with optional filters and pagination // GetAuditLogs retrieves audit logs with optional filters and pagination
// TODO: change this to use db.PageOpts func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *AuditLogFilters) ([]*AuditLog, int, error) { pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at")
query := tx.NewSelect(). query := tx.NewSelect().
Model((*AuditLog)(nil)). Model((*AuditLog)(nil)).
Relation("User"). Relation("User").
Order("created_at DESC") OrderBy(pageOpts.OrderBy, pageOpts.Order)
// Apply filters if provided // Apply filters if provided
if filters != nil { if filters != nil {
@@ -80,48 +84,52 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *Au
// Get total count // Get total count
total, err := query.Count(ctx) total, err := query.Count(ctx)
if err != nil { if err != nil {
return nil, 0, errors.Wrap(err, "query.Count") return nil, errors.Wrap(err, "query.Count")
} }
// Get paginated results // Get paginated results
var logs []*AuditLog logs := new([]*AuditLog)
err = query. err = query.
Limit(limit). Offset(pageOpts.PerPage*(pageOpts.Page-1)).
Offset(offset). Limit(pageOpts.PerPage).
Scan(ctx, &logs) Scan(ctx, &logs)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, 0, errors.Wrap(err, "query.Scan") return nil, errors.Wrap(err, "query.Scan")
} }
return logs, total, nil list := &AuditLogs{
AuditLogs: *logs,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
} }
// GetAuditLogsByUser retrieves audit logs for a specific user // GetAuditLogsByUser retrieves audit logs for a specific user
// TODO: change this to use db.PageOpts func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) {
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, limit, offset int) ([]*AuditLog, int, error) {
if userID <= 0 { if userID <= 0 {
return nil, 0, errors.New("userID must be positive") return nil, errors.New("userID must be positive")
} }
filters := &AuditLogFilters{ filters := &AuditLogFilters{
UserID: &userID, UserID: &userID,
} }
return GetAuditLogs(ctx, tx, limit, offset, filters) return GetAuditLogs(ctx, tx, pageOpts, filters)
} }
// GetAuditLogsByAction retrieves audit logs for a specific action // GetAuditLogsByAction retrieves audit logs for a specific action
// TODO: change this to use db.PageOpts func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) {
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, limit, offset int) ([]*AuditLog, int, error) {
if action == "" { if action == "" {
return nil, 0, errors.New("action cannot be empty") return nil, errors.New("action cannot be empty")
} }
filters := &AuditLogFilters{ filters := &AuditLogFilters{
Action: &action, Action: &action,
} }
return GetAuditLogs(ctx, tx, limit, offset, filters) return GetAuditLogs(ctx, tx, pageOpts, filters)
} }
// CleanupOldAuditLogs deletes audit logs older than the specified timestamp // CleanupOldAuditLogs deletes audit logs older than the specified timestamp

View File

@@ -15,6 +15,25 @@ type OrderOpts struct {
Label string Label string
} }
func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts {
if p == nil {
p = new(PageOpts)
}
if p.Page == 0 {
p.Page = page
}
if p.PerPage == 0 {
p.PerPage = perpage
}
if p.Order == "" {
p.Order = order
}
if p.OrderBy == "" {
p.OrderBy = orderby
}
return p
}
// TotalPages calculates the total number of pages // TotalPages calculates the total number of pages
func (p *PageOpts) TotalPages(total int) int { func (p *PageOpts) TotalPages(total int) int {
if p.PerPage == 0 { if p.PerPage == 0 {

View File

@@ -20,6 +20,8 @@ type Permission struct {
Action string `bun:"action,notnull"` Action string `bun:"action,notnull"`
IsSystem bool `bun:"is_system,default:false"` IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"` CreatedAt int64 `bun:"created_at,notnull"`
Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"`
} }
// GetPermissionByName queries the database for a permission matching the given name // GetPermissionByName queries the database for a permission matching the given name

View File

@@ -3,6 +3,7 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -18,10 +19,18 @@ type Role struct {
Description string `bun:"description"` Description string `bun:"description"`
IsSystem bool `bun:"is_system,default:false"` IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"` CreatedAt int64 `bun:"created_at,notnull"`
UpdatedAt int64 `bun:"updated_at,notnull"` UpdatedAt *int64 `bun:"updated_at"`
// Relations (loaded on demand) // Relations (loaded on demand)
Permissions []*Permission `bun:"m2m:role_permissions,join:Role=Permission"` Users []User `bun:"m2m:user_roles,join:Role=User"`
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
}
type RolePermission struct {
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
PermissionID int `bun:",pk"`
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
} }
// GetRoleByName queries the database for a role matching the given name // GetRoleByName queries the database for a role matching the given name
@@ -99,6 +108,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
if role == nil { if role == nil {
return errors.New("role cannot be nil") return errors.New("role cannot be nil")
} }
role.CreatedAt = time.Now().Unix()
_, err := tx.NewInsert(). _, err := tx.NewInsert().
Model(role). Model(role).
@@ -160,23 +170,23 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
} }
// AddPermissionToRole grants a permission to a role // AddPermissionToRole grants a permission to a role
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int, createdAt int64) error { func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
if roleID <= 0 { if roleID <= 0 {
return errors.New("roleID must be positive") return errors.New("roleID must be positive")
} }
if permissionID <= 0 { if permissionID <= 0 {
return errors.New("permissionID must be positive") return errors.New("permissionID must be positive")
} }
rolePerm := &RolePermission{
// TODO: use proper m2m table RoleID: roleID,
// also make createdAt automatic in table so not required as input here PermissionID: permissionID,
_, err := tx.ExecContext(ctx, ` }
INSERT INTO role_permissions (role_id, permission_id, created_at) _, err := tx.NewInsert().
VALUES ($1, $2, $3) Model(rolePerm).
ON CONFLICT (role_id, permission_id) DO NOTHING On("CONFLICT (role_id, permission_id) DO NOTHING").
`, roleID, permissionID, createdAt) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.ExecContext") return errors.Wrap(err, "tx.NewInsert")
} }
return nil return nil
@@ -191,13 +201,13 @@ func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permission
return errors.New("permissionID must be positive") return errors.New("permissionID must be positive")
} }
// TODO: use proper m2m table _, err := tx.NewDelete().
_, err := tx.ExecContext(ctx, ` Model((*RolePermission)(nil)).
DELETE FROM role_permissions Where("role_id = ?", roleID).
WHERE role_id = $1 AND permission_id = $2 Where("permission_id = ?", permissionID).
`, roleID, permissionID) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.ExecContext") return errors.Wrap(err, "tx.NewDelete")
} }
return nil return nil

View File

@@ -23,7 +23,7 @@ type Season struct {
} }
type SeasonList struct { type SeasonList struct {
Seasons []Season Seasons []*Season
Total int Total int
PageOpts PageOpts PageOpts PageOpts
} }
@@ -50,24 +50,10 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
} }
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
if pageOpts == nil { pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date")
pageOpts = &PageOpts{} seasons := new([]*Season)
}
if pageOpts.Page == 0 {
pageOpts.Page = 1
}
if pageOpts.PerPage == 0 {
pageOpts.PerPage = 10
}
if pageOpts.Order == "" {
pageOpts.Order = bun.OrderDesc
}
if pageOpts.OrderBy == "" {
pageOpts.OrderBy = "start_date"
}
seasons := []Season{}
err := tx.NewSelect(). err := tx.NewSelect().
Model(&seasons). Model(seasons).
OrderBy(pageOpts.OrderBy, pageOpts.Order). OrderBy(pageOpts.OrderBy, pageOpts.Order).
Offset(pageOpts.PerPage * (pageOpts.Page - 1)). Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
Limit(pageOpts.PerPage). Limit(pageOpts.PerPage).
@@ -76,13 +62,13 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonLis
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
total, err := tx.NewSelect(). total, err := tx.NewSelect().
Model(&seasons). Model(seasons).
Count(ctx) Count(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
sl := &SeasonList{ sl := &SeasonList{
Seasons: seasons, Seasons: *seasons,
Total: total, Total: total,
PageOpts: *pageOpts, PageOpts: *pageOpts,
} }

View File

@@ -23,6 +23,8 @@ type User struct {
Username string `bun:"username,unique"` // Username (unique) Username string `bun:"username,unique"` // Username (unique)
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
DiscordID string `bun:"discord_id,unique"` DiscordID string `bun:"discord_id,unique"`
Roles []*Role `bun:"m2m:user_roles,join:User=Role"`
} }
type Users struct { type Users struct {
@@ -124,28 +126,35 @@ func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, er
return count == 0, nil return count == 0, nil
} }
// GetRoles loads and returns all roles for this user // GetRoles loads all the roles for this user
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
if u == nil { if u == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
return GetUserRoles(ctx, tx, u.ID)
err := tx.NewSelect().
Model(u).
Relation("Roles").
Where("id = ?", u.ID).
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return u.Roles, nil
} }
// GetPermissions loads and returns all permissions for this user (via roles) // GetPermissions loads and returns all permissions for this user
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
if u == nil { if u == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
// TODO: use proper m2m tables and relations instead of join
var permissions []*Permission var permissions []*Permission
err := tx.NewSelect(). err := tx.NewSelect().
Model(&permissions). Model(&permissions).
Join("JOIN role_permissions AS rp ON rp.permission_id = p.id"). Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
Where("ur.user_id = ?", u.ID). Where("ur.user_id = ?", u.ID).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
@@ -192,22 +201,8 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
} }
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
if pageOpts == nil { pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id")
pageOpts = &PageOpts{} users := new([]*User)
}
if pageOpts.Page == 0 {
pageOpts.Page = 1
}
if pageOpts.PerPage == 0 {
pageOpts.PerPage = 50
}
if pageOpts.Order == "" {
pageOpts.Order = bun.OrderAsc
}
if pageOpts.OrderBy == "" {
pageOpts.OrderBy = "id"
}
users := []*User{}
err := tx.NewSelect(). err := tx.NewSelect().
Model(users). Model(users).
OrderBy(pageOpts.OrderBy, pageOpts.Order). OrderBy(pageOpts.OrderBy, pageOpts.Order).
@@ -224,7 +219,7 @@ func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
list := &Users{ list := &Users{
Users: users, Users: *users,
Total: total, Total: total,
PageOpts: *pageOpts, PageOpts: *pageOpts,
} }

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"time"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -10,42 +9,14 @@ import (
) )
type UserRole struct { type UserRole struct {
bun.BaseModel `bun:"table:user_roles,alias:ur"` UserID int `bun:",pk"`
User *User `bun:"rel:belongs-to,join:user_id=id"`
ID int `bun:"id,pk,autoincrement"` RoleID int `bun:",pk"`
UserID int `bun:"user_id,notnull"` Role *Role `bun:"rel:belongs-to,join:role_id=id"`
RoleID int `bun:"role_id,notnull"`
GrantedBy *int `bun:"granted_by"`
GrantedAt int64 `bun:"granted_at,notnull"` // TODO: default now
ExpiresAt *int64 `bun:"expires_at"`
// Relations
User *User `bun:"rel:belongs-to,join:user_id=id"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
}
// GetUserRoles loads all roles for a given user
func GetUserRoles(ctx context.Context, tx bun.Tx, userID int) ([]*Role, error) {
if userID <= 0 {
return nil, errors.New("userID must be positive")
}
var roles []*Role
err := tx.NewSelect().
Model(&roles).
// TODO: why are we joining? can we do relation?
Join("JOIN user_roles AS ur ON ur.role_id = r.id").
Where("ur.user_id = ?", userID).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return roles, nil
} }
// AssignRole grants a role to a user // AssignRole grants a role to a user
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *int) error { func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
if userID <= 0 { if userID <= 0 {
return errors.New("userID must be positive") return errors.New("userID must be positive")
} }
@@ -53,16 +24,16 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *i
return errors.New("roleID must be positive") return errors.New("roleID must be positive")
} }
now := time.Now().Unix() userRole := &UserRole{
UserID: userID,
// TODO: use proper m2m table instead of raw SQL RoleID: roleID,
_, err := tx.ExecContext(ctx, ` }
INSERT INTO user_roles (user_id, role_id, granted_by, granted_at) _, err := tx.NewInsert().
VALUES ($1, $2, $3, $4) Model(userRole).
ON CONFLICT (user_id, role_id) DO NOTHING On("CONFLICT (user_id, role_id) DO NOTHING").
`, userID, roleID, grantedBy, now) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.ExecContext") return errors.Wrap(err, "tx.NewInsert")
} }
return nil return nil
@@ -77,13 +48,13 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
return errors.New("roleID must be positive") return errors.New("roleID must be positive")
} }
// TODO: use proper m2m table instead of raw sql _, err := tx.NewDelete().
_, err := tx.ExecContext(ctx, ` Model((*UserRole)(nil)).
DELETE FROM user_roles Where("user_id = ?", userID).
WHERE user_id = $1 AND role_id = $2 Where("role_id = ?", roleID).
`, userID, roleID) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.ExecContext") return errors.Wrap(err, "tx.NewDelete")
} }
return nil return nil
@@ -97,18 +68,19 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
if roleName == "" { if roleName == "" {
return false, errors.New("roleName cannot be empty") return false, errors.New("roleName cannot be empty")
} }
user := new(User)
// TODO: use proper m2m table instead of TableExpr and Join? err := tx.NewSelect().
count, err := tx.NewSelect(). Model(user).
TableExpr("user_roles AS ur"). Relation("Roles").
Join("JOIN roles AS r ON r.id = ur.role_id"). Where("u.id = ? ", userID).
Where("ur.user_id = ?", userID). Scan(ctx)
Where("r.name = ?", roleName).
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
Count(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "tx.NewSelect") return false, errors.Wrap(err, "tx.NewSelect")
} }
for _, role := range user.Roles {
return count > 0, nil if role.Name == roleName {
return true, nil
}
}
return false, nil
} }

View File

@@ -46,8 +46,8 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error
return errors.New("admin role not found in database") return errors.New("admin role not found in database")
} }
// Grant admin role (nil grantedBy = system granted) // Grant admin role
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil) err = db.AssignRole(ctx, tx, user.ID, adminRole.ID)
if err != nil { if err != nil {
return errors.Wrap(err, "db.AssignRole") return errors.Wrap(err, "db.AssignRole")
} }

View File

@@ -0,0 +1,24 @@
package handlers
import (
"net/http"
"strconv"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/uptrace/bun"
)
func PermTester(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := db.CurrentUser(r.Context())
tx, _ := conn.BeginTx(r.Context(), nil)
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
tx.Rollback()
if err != nil {
throwInternalServiceError(s, w, r, "Error", err)
}
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
})
}

View File

@@ -18,8 +18,11 @@ type Checker struct {
} }
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) { func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) {
if conn == nil || s == nil { if conn == nil {
return nil, errors.New("arguments cannot be nil") return nil, errors.New("conn cannot be nil")
}
if s == nil {
return nil, errors.New("server cannot be nil")
} }
return &Checker{conn: conn, s: s}, nil return &Checker{conn: conn, s: s}, nil
} }

View File

@@ -22,7 +22,7 @@ templ Global(title string) {
<meta charset="UTF-8"/> <meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>{ title }</title> <title>{ title }</title>
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/> <link rel="icon" type="image/x-icon" href="/static/assets/favicon.ico"/>
<link href="/static/css/output.css" rel="stylesheet"/> <link href="/static/css/output.css" rel="stylesheet"/>
<script src="/static/vendored/htmx@2.0.8.min.js"></script> <script src="/static/vendored/htmx@2.0.8.min.js"></script>
<script src="/static/vendored/htmx-ext-ws.min.js"></script> <script src="/static/vendored/htmx-ext-ws.min.js"></script>

View File

@@ -84,7 +84,7 @@ templ SeasonsList(seasons *db.SeasonList) {
<!-- Header: Name + Status Badge --> <!-- Header: Name + Status Badge -->
<div class="flex justify-between items-start mb-3"> <div class="flex justify-between items-start mb-3">
<h3 class="text-xl font-bold text-text">{ s.Name }</h3> <h3 class="text-xl font-bold text-text">{ s.Name }</h3>
@season.StatusBadge(&s, true, true) @season.StatusBadge(s, true, true)
</div> </div>
<!-- Info Row: Short Name + Start Date --> <!-- Info Row: Short Name + Start Date -->
<div class="flex items-center gap-3 text-sm"> <div class="flex items-center gap-3 text-sm">
@@ -107,5 +107,5 @@ templ SeasonsList(seasons *db.SeasonList) {
} }
func formatDate(t time.Time) string { func formatDate(t time.Time) string {
return t.Format("02/01/2006") // DD/MM/YYYY return t.Format("02/01/2006") // DD/MM/YYYY
} }

View File

@@ -1,3 +1,4 @@
// Package oauth provides OAuth utilities for generating and checking secure state tokens
package oauth package oauth
import ( import (