fixed relationship issues
This commit is contained in:
2
Makefile
2
Makefile
@@ -82,7 +82,7 @@ migrate-create:
|
|||||||
|
|
||||||
# Reset database (DESTRUCTIVE - dev only!)
|
# Reset database (DESTRUCTIVE - dev only!)
|
||||||
reset-db:
|
reset-db:
|
||||||
@echo "⚠️ WARNING: This will DELETE ALL DATA!"
|
@echo "⚠️ WARNING - This will DELETE ALL DATA!"
|
||||||
make build
|
make build
|
||||||
./bin/${BINARY_NAME}${SUFFIX} --reset-db
|
./bin/${BINARY_NAME}${SUFFIX} --reset-db
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,18 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
"github.com/uptrace/bun/dialect/pgdialect"
|
"github.com/uptrace/bun/dialect/pgdialect"
|
||||||
"github.com/uptrace/bun/driver/pgdriver"
|
"github.com/uptrace/bun/driver/pgdriver"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func() error, err error) {
|
func setupBun(cfg *config.Config) (conn *bun.DB, close func() error) {
|
||||||
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
|
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
|
||||||
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
|
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
|
||||||
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
||||||
@@ -26,30 +24,19 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
|
|||||||
|
|
||||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
conn = bun.NewDB(sqldb, pgdialect.New())
|
||||||
close = sqldb.Close
|
close = sqldb.Close
|
||||||
|
return conn, close
|
||||||
err = loadModels(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "loadModels")
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, close, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadModels(ctx context.Context, conn *bun.DB) error {
|
func registerDBModels(conn *bun.DB) {
|
||||||
models := []any{
|
models := []any{
|
||||||
|
(*db.RolePermission)(nil),
|
||||||
|
(*db.UserRole)(nil),
|
||||||
(*db.User)(nil),
|
(*db.User)(nil),
|
||||||
(*db.DiscordToken)(nil),
|
(*db.DiscordToken)(nil),
|
||||||
|
(*db.Season)(nil),
|
||||||
|
(*db.Role)(nil),
|
||||||
|
(*db.Permission)(nil),
|
||||||
|
(*db.AuditLog)(nil),
|
||||||
}
|
}
|
||||||
|
conn.RegisterModel(models...)
|
||||||
for _, model := range models {
|
|
||||||
_, err := conn.NewCreateTable().
|
|
||||||
Model(model).
|
|
||||||
IfNotExists().
|
|
||||||
Exec(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.NewCreateTable")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func setupHTTPServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize permissions checker
|
// Initialize permissions checker
|
||||||
perms, err := rbac.NewChecker(bun, server)
|
perms, err := rbac.NewChecker(bun, httpServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "rbac.NewChecker")
|
return nil, errors.Wrap(err, "rbac.NewChecker")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
stderrors "errors"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
|
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
|
||||||
"git.haelnorr.com/h/oslstats/internal/backup"
|
"git.haelnorr.com/h/oslstats/internal/backup"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
@@ -21,11 +23,8 @@ import (
|
|||||||
|
|
||||||
// runMigrations executes database migrations
|
// runMigrations executes database migrations
|
||||||
func runMigrations(ctx context.Context, cfg *config.Config, command string) error {
|
func runMigrations(ctx context.Context, cfg *config.Config, command string) error {
|
||||||
conn, close, err := setupBun(ctx, cfg)
|
conn, close := setupBun(cfg)
|
||||||
if err != nil {
|
defer func() { _ = close() }()
|
||||||
return errors.Wrap(err, "setupBun")
|
|
||||||
}
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
migrator := migrate.NewMigrator(conn, migrations.Migrations)
|
migrator := migrate.NewMigrator(conn, migrations.Migrations)
|
||||||
|
|
||||||
@@ -36,7 +35,14 @@ func runMigrations(ctx context.Context, cfg *config.Config, command string) erro
|
|||||||
|
|
||||||
switch command {
|
switch command {
|
||||||
case "up":
|
case "up":
|
||||||
return migrateUp(ctx, migrator, conn, cfg)
|
err := migrateUp(ctx, migrator, conn, cfg)
|
||||||
|
if err != nil {
|
||||||
|
err2 := migrateRollback(ctx, migrator, conn, cfg)
|
||||||
|
if err2 != nil {
|
||||||
|
return stderrors.Join(errors.Wrap(err2, "error while rolling back after migration error"), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
case "rollback":
|
case "rollback":
|
||||||
return migrateRollback(ctx, migrator, conn, cfg)
|
return migrateRollback(ctx, migrator, conn, cfg)
|
||||||
case "status":
|
case "status":
|
||||||
@@ -171,8 +177,8 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
|
|||||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||||
fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT")
|
_, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT")
|
||||||
fmt.Fprintln(w, "------\t---------\t-----\t-----------")
|
_, _ = fmt.Fprintln(w, "------\t---------\t-----\t-----------")
|
||||||
|
|
||||||
appliedCount := 0
|
appliedCount := 0
|
||||||
for _, m := range ms {
|
for _, m := range ms {
|
||||||
@@ -189,10 +195,10 @@ func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt)
|
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Flush()
|
_ = w.Flush()
|
||||||
|
|
||||||
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
|
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
|
||||||
appliedCount, len(ms)-appliedCount)
|
appliedCount, len(ms)-appliedCount)
|
||||||
@@ -299,12 +305,12 @@ func init() {
|
|||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, dbConn *bun.DB) error {
|
||||||
// TODO: Add your migration code here
|
// Add your migration code here
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
// DOWN migration
|
// DOWN migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, dbConn *bun.DB) error {
|
||||||
// TODO: Add your rollback code here
|
// Add your rollback code here
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -326,7 +332,7 @@ func init() {
|
|||||||
|
|
||||||
// resetDatabase drops and recreates all tables (destructive)
|
// resetDatabase drops and recreates all tables (destructive)
|
||||||
func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
||||||
fmt.Println("⚠️ WARNING: This will DELETE ALL DATA in the database!")
|
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
|
||||||
fmt.Print("Type 'yes' to continue: ")
|
fmt.Print("Type 'yes' to continue: ")
|
||||||
|
|
||||||
reader := bufio.NewReader(os.Stdin)
|
reader := bufio.NewReader(os.Stdin)
|
||||||
@@ -340,11 +346,8 @@ func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
|||||||
fmt.Println("❌ Reset cancelled")
|
fmt.Println("❌ Reset cancelled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
conn, close, err := setupBun(ctx, cfg)
|
conn, close := setupBun(cfg)
|
||||||
if err != nil {
|
defer func() { _ = close() }()
|
||||||
return errors.Wrap(err, "setupBun")
|
|
||||||
}
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
models := []any{
|
models := []any{
|
||||||
(*db.User)(nil),
|
(*db.User)(nil),
|
||||||
|
|||||||
@@ -12,20 +12,11 @@ func init() {
|
|||||||
Migrations.MustRegister(
|
Migrations.MustRegister(
|
||||||
// UP migration
|
// UP migration
|
||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, dbConn *bun.DB) error {
|
||||||
// Create roles table using raw SQL to avoid m2m relationship issues
|
dbConn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
|
||||||
// Bun tries to resolve relationships when creating tables from models
|
// Create permissions table
|
||||||
// TODO: use proper m2m table instead of raw sql
|
_, err := dbConn.NewCreateTable().
|
||||||
_, err := dbConn.ExecContext(ctx, `
|
Model((*db.Role)(nil)).
|
||||||
CREATE TABLE roles (
|
Exec(ctx)
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
name VARCHAR(50) UNIQUE NOT NULL,
|
|
||||||
display_name VARCHAR(100) NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
is_system BOOLEAN DEFAULT FALSE,
|
|
||||||
created_at BIGINT NOT NULL,
|
|
||||||
updated_at BIGINT NOT NULL
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -39,7 +30,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for permissions
|
// Create indexes for permissions
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.Permission)(nil)).
|
Model((*db.Permission)(nil)).
|
||||||
Index("idx_permissions_resource").
|
Index("idx_permissions_resource").
|
||||||
@@ -49,7 +39,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.Permission)(nil)).
|
Model((*db.Permission)(nil)).
|
||||||
Index("idx_permissions_action").
|
Index("idx_permissions_action").
|
||||||
@@ -59,22 +48,13 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create role_permissions join table (Bun doesn't auto-create m2m tables)
|
_, err = dbConn.NewCreateTable().
|
||||||
// TODO: use proper m2m table instead of raw sql
|
Model((*db.RolePermission)(nil)).
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
Exec(ctx)
|
||||||
CREATE TABLE role_permissions (
|
|
||||||
id SERIAL PRIMARY KEY,
|
|
||||||
role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE CASCADE,
|
|
||||||
permission_id INTEGER NOT NULL REFERENCES permissions(id) ON DELETE CASCADE,
|
|
||||||
created_at BIGINT NOT NULL,
|
|
||||||
UNIQUE(role_id, permission_id)
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
_, err = dbConn.ExecContext(ctx, `
|
||||||
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
|
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
|
||||||
`)
|
`)
|
||||||
@@ -82,7 +62,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
_, err = dbConn.ExecContext(ctx, `
|
||||||
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
|
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
|
||||||
`)
|
`)
|
||||||
@@ -99,7 +78,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for user_roles
|
// Create indexes for user_roles
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.UserRole)(nil)).
|
Model((*db.UserRole)(nil)).
|
||||||
Index("idx_user_roles_user").
|
Index("idx_user_roles_user").
|
||||||
@@ -109,7 +87,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.UserRole)(nil)).
|
Model((*db.UserRole)(nil)).
|
||||||
Index("idx_user_roles_role").
|
Index("idx_user_roles_role").
|
||||||
@@ -128,7 +105,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create indexes for audit_log
|
// Create indexes for audit_log
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_user").
|
Index("idx_audit_log_user").
|
||||||
@@ -138,7 +114,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_action").
|
Index("idx_audit_log_action").
|
||||||
@@ -148,7 +123,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_resource").
|
Index("idx_audit_log_resource").
|
||||||
@@ -158,7 +132,6 @@ func init() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: why do we need this?
|
|
||||||
_, err = dbConn.NewCreateIndex().
|
_, err = dbConn.NewCreateIndex().
|
||||||
Model((*db.AuditLog)(nil)).
|
Model((*db.AuditLog)(nil)).
|
||||||
Index("idx_audit_log_created").
|
Index("idx_audit_log_created").
|
||||||
@@ -176,12 +149,12 @@ func init() {
|
|||||||
DisplayName: "Administrator",
|
DisplayName: "Administrator",
|
||||||
Description: "Full system access with all permissions",
|
Description: "Full system access with all permissions",
|
||||||
IsSystem: true,
|
IsSystem: true,
|
||||||
CreatedAt: now, // TODO: this should be defaulted in table
|
CreatedAt: now,
|
||||||
UpdatedAt: now, // TODO: this should be defaulted in table
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewInsert().
|
_, err = dbConn.NewInsert().
|
||||||
Model(adminRole).
|
Model(adminRole).
|
||||||
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -192,9 +165,7 @@ func init() {
|
|||||||
DisplayName: "User",
|
DisplayName: "User",
|
||||||
Description: "Standard user with basic permissions",
|
Description: "Standard user with basic permissions",
|
||||||
IsSystem: true,
|
IsSystem: true,
|
||||||
CreatedAt: now, // TODO: this should be defaulted in table
|
CreatedAt: now,
|
||||||
UpdatedAt: now, // TODO: this should be defaulted in table
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbConn.NewInsert().
|
_, err = dbConn.NewInsert().
|
||||||
@@ -205,7 +176,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Seed system permissions
|
// Seed system permissions
|
||||||
// TODO: timestamps for created should be defaulted in table
|
|
||||||
permissionsData := []*db.Permission{
|
permissionsData := []*db.Permission{
|
||||||
{Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now},
|
{Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now},
|
||||||
{Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now},
|
{Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now},
|
||||||
@@ -235,11 +205,14 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert role_permission mapping
|
// Insert role_permission mapping
|
||||||
// TODO: use proper m2m table, and default now in table settings
|
adminRolePerms := &db.RolePermission{
|
||||||
_, err = dbConn.ExecContext(ctx, `
|
RoleID: adminRole.ID,
|
||||||
INSERT INTO role_permissions (role_id, permission_id, created_at)
|
PermissionID: wildcardPerm.ID,
|
||||||
VALUES ($1, $2, $3)
|
}
|
||||||
`, adminRole.ID, wildcardPerm.ID, now)
|
_, err = dbConn.NewInsert().
|
||||||
|
Model(adminRolePerms).
|
||||||
|
On("CONFLICT (role_id, permission_id) DO NOTHING").
|
||||||
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -250,7 +223,6 @@ func init() {
|
|||||||
func(ctx context.Context, dbConn *bun.DB) error {
|
func(ctx context.Context, dbConn *bun.DB) error {
|
||||||
// Drop tables in reverse order
|
// Drop tables in reverse order
|
||||||
// Use raw SQL to avoid relationship resolution issues
|
// Use raw SQL to avoid relationship resolution issues
|
||||||
// TODO: surely we can use proper bun methods?
|
|
||||||
tables := []string{
|
tables := []string{
|
||||||
"audit_log",
|
"audit_log",
|
||||||
"user_roles",
|
"user_roles",
|
||||||
|
|||||||
@@ -30,6 +30,11 @@ func addRoutes(
|
|||||||
) error {
|
) error {
|
||||||
// Create the routes
|
// Create the routes
|
||||||
pageroutes := []hws.Route{
|
pageroutes := []hws.Route{
|
||||||
|
{
|
||||||
|
Path: "/permtest",
|
||||||
|
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||||
|
Handler: handlers.PermTester(s, conn),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Path: "/static/",
|
Path: "/static/",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
@@ -63,8 +68,7 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/notification-tester",
|
Path: "/notification-tester",
|
||||||
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||||
Handler: handlers.NotifyTester(s),
|
Handler: perms.RequireAdmin(s)(handlers.NotifyTester(s)),
|
||||||
// TODO: add login protection
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/seasons",
|
Path: "/seasons",
|
||||||
|
|||||||
@@ -25,10 +25,8 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
|
|||||||
// Setup the database connection
|
// Setup the database connection
|
||||||
logger.Debug().Msg("Config loaded and logger started")
|
logger.Debug().Msg("Config loaded and logger started")
|
||||||
logger.Debug().Msg("Connecting to database")
|
logger.Debug().Msg("Connecting to database")
|
||||||
bun, closedb, err := setupBun(ctx, cfg)
|
bun, closedb := setupBun(cfg)
|
||||||
if err != nil {
|
registerDBModels(bun)
|
||||||
return errors.Wrap(err, "setupDBConn")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup embedded files
|
// Setup embedded files
|
||||||
logger.Debug().Msg("Getting embedded files")
|
logger.Debug().Msg("Getting embedded files")
|
||||||
|
|||||||
@@ -104,39 +104,37 @@ func (l *Logger) log(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRecentLogs retrieves recent audit logs with pagination
|
// GetRecentLogs retrieves recent audit logs with pagination
|
||||||
// TODO: change this to user db.PageOpts
|
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||||
func (l *Logger) GetRecentLogs(ctx context.Context, limit, offset int) ([]*db.AuditLog, int, error) {
|
|
||||||
tx, err := l.conn.BeginTx(ctx, nil)
|
tx, err := l.conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errors.Wrap(err, "conn.BeginTx")
|
return nil, errors.Wrap(err, "conn.BeginTx")
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
logs, total, err := db.GetAuditLogs(ctx, tx, limit, offset, nil)
|
logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = tx.Commit() // read only transaction
|
_ = tx.Commit() // read only transaction
|
||||||
return logs, total, nil
|
return logs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogsByUser retrieves audit logs for a specific user
|
// GetLogsByUser retrieves audit logs for a specific user
|
||||||
// TODO: change this to user db.PageOpts
|
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, limit, offset int) ([]*db.AuditLog, int, error) {
|
|
||||||
tx, err := l.conn.BeginTx(ctx, nil)
|
tx, err := l.conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errors.Wrap(err, "conn.BeginTx")
|
return nil, errors.Wrap(err, "conn.BeginTx")
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
logs, total, err := db.GetAuditLogsByUser(ctx, tx, userID, limit, offset)
|
logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = tx.Commit() // read only transaction
|
_ = tx.Commit() // read only transaction
|
||||||
return logs, total, nil
|
return logs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupOldLogs deletes audit logs older than the specified number of days
|
// CleanupOldLogs deletes audit logs older than the specified number of days
|
||||||
|
|||||||
@@ -28,7 +28,11 @@ type AuditLog struct {
|
|||||||
User *User `bun:"rel:belongs-to,join:user_id=id"`
|
User *User `bun:"rel:belongs-to,join:user_id=id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add AuditLogs to match list style with PageOpts
|
type AuditLogs struct {
|
||||||
|
AuditLogs []*AuditLog
|
||||||
|
Total int
|
||||||
|
PageOpts PageOpts
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAuditLog creates a new audit log entry
|
// CreateAuditLog creates a new audit log entry
|
||||||
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
|
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
|
||||||
@@ -54,12 +58,12 @@ type AuditLogFilters struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAuditLogs retrieves audit logs with optional filters and pagination
|
// GetAuditLogs retrieves audit logs with optional filters and pagination
|
||||||
// TODO: change this to use db.PageOpts
|
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
|
||||||
func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *AuditLogFilters) ([]*AuditLog, int, error) {
|
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at")
|
||||||
query := tx.NewSelect().
|
query := tx.NewSelect().
|
||||||
Model((*AuditLog)(nil)).
|
Model((*AuditLog)(nil)).
|
||||||
Relation("User").
|
Relation("User").
|
||||||
Order("created_at DESC")
|
OrderBy(pageOpts.OrderBy, pageOpts.Order)
|
||||||
|
|
||||||
// Apply filters if provided
|
// Apply filters if provided
|
||||||
if filters != nil {
|
if filters != nil {
|
||||||
@@ -80,48 +84,52 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, limit, offset int, filters *Au
|
|||||||
// Get total count
|
// Get total count
|
||||||
total, err := query.Count(ctx)
|
total, err := query.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, errors.Wrap(err, "query.Count")
|
return nil, errors.Wrap(err, "query.Count")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get paginated results
|
// Get paginated results
|
||||||
var logs []*AuditLog
|
logs := new([]*AuditLog)
|
||||||
err = query.
|
err = query.
|
||||||
Limit(limit).
|
Offset(pageOpts.PerPage*(pageOpts.Page-1)).
|
||||||
Offset(offset).
|
Limit(pageOpts.PerPage).
|
||||||
Scan(ctx, &logs)
|
Scan(ctx, &logs)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return nil, 0, errors.Wrap(err, "query.Scan")
|
return nil, errors.Wrap(err, "query.Scan")
|
||||||
}
|
}
|
||||||
|
|
||||||
return logs, total, nil
|
list := &AuditLogs{
|
||||||
|
AuditLogs: *logs,
|
||||||
|
Total: total,
|
||||||
|
PageOpts: *pageOpts,
|
||||||
|
}
|
||||||
|
|
||||||
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuditLogsByUser retrieves audit logs for a specific user
|
// GetAuditLogsByUser retrieves audit logs for a specific user
|
||||||
// TODO: change this to use db.PageOpts
|
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) {
|
||||||
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, limit, offset int) ([]*AuditLog, int, error) {
|
|
||||||
if userID <= 0 {
|
if userID <= 0 {
|
||||||
return nil, 0, errors.New("userID must be positive")
|
return nil, errors.New("userID must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
filters := &AuditLogFilters{
|
filters := &AuditLogFilters{
|
||||||
UserID: &userID,
|
UserID: &userID,
|
||||||
}
|
}
|
||||||
|
|
||||||
return GetAuditLogs(ctx, tx, limit, offset, filters)
|
return GetAuditLogs(ctx, tx, pageOpts, filters)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuditLogsByAction retrieves audit logs for a specific action
|
// GetAuditLogsByAction retrieves audit logs for a specific action
|
||||||
// TODO: change this to use db.PageOpts
|
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) {
|
||||||
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, limit, offset int) ([]*AuditLog, int, error) {
|
|
||||||
if action == "" {
|
if action == "" {
|
||||||
return nil, 0, errors.New("action cannot be empty")
|
return nil, errors.New("action cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
filters := &AuditLogFilters{
|
filters := &AuditLogFilters{
|
||||||
Action: &action,
|
Action: &action,
|
||||||
}
|
}
|
||||||
|
|
||||||
return GetAuditLogs(ctx, tx, limit, offset, filters)
|
return GetAuditLogs(ctx, tx, pageOpts, filters)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupOldAuditLogs deletes audit logs older than the specified timestamp
|
// CleanupOldAuditLogs deletes audit logs older than the specified timestamp
|
||||||
|
|||||||
@@ -15,6 +15,25 @@ type OrderOpts struct {
|
|||||||
Label string
|
Label string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts {
|
||||||
|
if p == nil {
|
||||||
|
p = new(PageOpts)
|
||||||
|
}
|
||||||
|
if p.Page == 0 {
|
||||||
|
p.Page = page
|
||||||
|
}
|
||||||
|
if p.PerPage == 0 {
|
||||||
|
p.PerPage = perpage
|
||||||
|
}
|
||||||
|
if p.Order == "" {
|
||||||
|
p.Order = order
|
||||||
|
}
|
||||||
|
if p.OrderBy == "" {
|
||||||
|
p.OrderBy = orderby
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
// TotalPages calculates the total number of pages
|
// TotalPages calculates the total number of pages
|
||||||
func (p *PageOpts) TotalPages(total int) int {
|
func (p *PageOpts) TotalPages(total int) int {
|
||||||
if p.PerPage == 0 {
|
if p.PerPage == 0 {
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ type Permission struct {
|
|||||||
Action string `bun:"action,notnull"`
|
Action string `bun:"action,notnull"`
|
||||||
IsSystem bool `bun:"is_system,default:false"`
|
IsSystem bool `bun:"is_system,default:false"`
|
||||||
CreatedAt int64 `bun:"created_at,notnull"`
|
CreatedAt int64 `bun:"created_at,notnull"`
|
||||||
|
|
||||||
|
Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPermissionByName queries the database for a permission matching the given name
|
// GetPermissionByName queries the database for a permission matching the given name
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -18,10 +19,18 @@ type Role struct {
|
|||||||
Description string `bun:"description"`
|
Description string `bun:"description"`
|
||||||
IsSystem bool `bun:"is_system,default:false"`
|
IsSystem bool `bun:"is_system,default:false"`
|
||||||
CreatedAt int64 `bun:"created_at,notnull"`
|
CreatedAt int64 `bun:"created_at,notnull"`
|
||||||
UpdatedAt int64 `bun:"updated_at,notnull"`
|
UpdatedAt *int64 `bun:"updated_at"`
|
||||||
|
|
||||||
// Relations (loaded on demand)
|
// Relations (loaded on demand)
|
||||||
Permissions []*Permission `bun:"m2m:role_permissions,join:Role=Permission"`
|
Users []User `bun:"m2m:user_roles,join:Role=User"`
|
||||||
|
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RolePermission struct {
|
||||||
|
RoleID int `bun:",pk"`
|
||||||
|
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
||||||
|
PermissionID int `bun:",pk"`
|
||||||
|
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoleByName queries the database for a role matching the given name
|
// GetRoleByName queries the database for a role matching the given name
|
||||||
@@ -99,6 +108,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
if role == nil {
|
if role == nil {
|
||||||
return errors.New("role cannot be nil")
|
return errors.New("role cannot be nil")
|
||||||
}
|
}
|
||||||
|
role.CreatedAt = time.Now().Unix()
|
||||||
|
|
||||||
_, err := tx.NewInsert().
|
_, err := tx.NewInsert().
|
||||||
Model(role).
|
Model(role).
|
||||||
@@ -160,23 +170,23 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPermissionToRole grants a permission to a role
|
// AddPermissionToRole grants a permission to a role
|
||||||
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int, createdAt int64) error {
|
func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID int) error {
|
||||||
if roleID <= 0 {
|
if roleID <= 0 {
|
||||||
return errors.New("roleID must be positive")
|
return errors.New("roleID must be positive")
|
||||||
}
|
}
|
||||||
if permissionID <= 0 {
|
if permissionID <= 0 {
|
||||||
return errors.New("permissionID must be positive")
|
return errors.New("permissionID must be positive")
|
||||||
}
|
}
|
||||||
|
rolePerm := &RolePermission{
|
||||||
// TODO: use proper m2m table
|
RoleID: roleID,
|
||||||
// also make createdAt automatic in table so not required as input here
|
PermissionID: permissionID,
|
||||||
_, err := tx.ExecContext(ctx, `
|
}
|
||||||
INSERT INTO role_permissions (role_id, permission_id, created_at)
|
_, err := tx.NewInsert().
|
||||||
VALUES ($1, $2, $3)
|
Model(rolePerm).
|
||||||
ON CONFLICT (role_id, permission_id) DO NOTHING
|
On("CONFLICT (role_id, permission_id) DO NOTHING").
|
||||||
`, roleID, permissionID, createdAt)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.ExecContext")
|
return errors.Wrap(err, "tx.NewInsert")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -191,13 +201,13 @@ func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permission
|
|||||||
return errors.New("permissionID must be positive")
|
return errors.New("permissionID must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use proper m2m table
|
_, err := tx.NewDelete().
|
||||||
_, err := tx.ExecContext(ctx, `
|
Model((*RolePermission)(nil)).
|
||||||
DELETE FROM role_permissions
|
Where("role_id = ?", roleID).
|
||||||
WHERE role_id = $1 AND permission_id = $2
|
Where("permission_id = ?", permissionID).
|
||||||
`, roleID, permissionID)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.ExecContext")
|
return errors.Wrap(err, "tx.NewDelete")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ type Season struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SeasonList struct {
|
type SeasonList struct {
|
||||||
Seasons []Season
|
Seasons []*Season
|
||||||
Total int
|
Total int
|
||||||
PageOpts PageOpts
|
PageOpts PageOpts
|
||||||
}
|
}
|
||||||
@@ -50,24 +50,10 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
|
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
|
||||||
if pageOpts == nil {
|
pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date")
|
||||||
pageOpts = &PageOpts{}
|
seasons := new([]*Season)
|
||||||
}
|
|
||||||
if pageOpts.Page == 0 {
|
|
||||||
pageOpts.Page = 1
|
|
||||||
}
|
|
||||||
if pageOpts.PerPage == 0 {
|
|
||||||
pageOpts.PerPage = 10
|
|
||||||
}
|
|
||||||
if pageOpts.Order == "" {
|
|
||||||
pageOpts.Order = bun.OrderDesc
|
|
||||||
}
|
|
||||||
if pageOpts.OrderBy == "" {
|
|
||||||
pageOpts.OrderBy = "start_date"
|
|
||||||
}
|
|
||||||
seasons := []Season{}
|
|
||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(&seasons).
|
Model(seasons).
|
||||||
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
||||||
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
|
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
|
||||||
Limit(pageOpts.PerPage).
|
Limit(pageOpts.PerPage).
|
||||||
@@ -76,13 +62,13 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonLis
|
|||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
total, err := tx.NewSelect().
|
total, err := tx.NewSelect().
|
||||||
Model(&seasons).
|
Model(seasons).
|
||||||
Count(ctx)
|
Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
sl := &SeasonList{
|
sl := &SeasonList{
|
||||||
Seasons: seasons,
|
Seasons: *seasons,
|
||||||
Total: total,
|
Total: total,
|
||||||
PageOpts: *pageOpts,
|
PageOpts: *pageOpts,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ type User struct {
|
|||||||
Username string `bun:"username,unique"` // Username (unique)
|
Username string `bun:"username,unique"` // Username (unique)
|
||||||
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
||||||
DiscordID string `bun:"discord_id,unique"`
|
DiscordID string `bun:"discord_id,unique"`
|
||||||
|
|
||||||
|
Roles []*Role `bun:"m2m:user_roles,join:User=Role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Users struct {
|
type Users struct {
|
||||||
@@ -124,28 +126,35 @@ func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, er
|
|||||||
return count == 0, nil
|
return count == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoles loads and returns all roles for this user
|
// GetRoles loads all the roles for this user
|
||||||
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
||||||
if u == nil {
|
if u == nil {
|
||||||
return nil, errors.New("user cannot be nil")
|
return nil, errors.New("user cannot be nil")
|
||||||
}
|
}
|
||||||
return GetUserRoles(ctx, tx, u.ID)
|
|
||||||
|
err := tx.NewSelect().
|
||||||
|
Model(u).
|
||||||
|
Relation("Roles").
|
||||||
|
Where("id = ?", u.ID).
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
|
}
|
||||||
|
return u.Roles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPermissions loads and returns all permissions for this user (via roles)
|
// GetPermissions loads and returns all permissions for this user
|
||||||
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
|
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
|
||||||
if u == nil {
|
if u == nil {
|
||||||
return nil, errors.New("user cannot be nil")
|
return nil, errors.New("user cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use proper m2m tables and relations instead of join
|
|
||||||
var permissions []*Permission
|
var permissions []*Permission
|
||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(&permissions).
|
Model(&permissions).
|
||||||
Join("JOIN role_permissions AS rp ON rp.permission_id = p.id").
|
Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
|
||||||
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
|
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
|
||||||
Where("ur.user_id = ?", u.ID).
|
Where("ur.user_id = ?", u.ID).
|
||||||
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
|
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -192,22 +201,8 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
|
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
|
||||||
if pageOpts == nil {
|
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id")
|
||||||
pageOpts = &PageOpts{}
|
users := new([]*User)
|
||||||
}
|
|
||||||
if pageOpts.Page == 0 {
|
|
||||||
pageOpts.Page = 1
|
|
||||||
}
|
|
||||||
if pageOpts.PerPage == 0 {
|
|
||||||
pageOpts.PerPage = 50
|
|
||||||
}
|
|
||||||
if pageOpts.Order == "" {
|
|
||||||
pageOpts.Order = bun.OrderAsc
|
|
||||||
}
|
|
||||||
if pageOpts.OrderBy == "" {
|
|
||||||
pageOpts.OrderBy = "id"
|
|
||||||
}
|
|
||||||
users := []*User{}
|
|
||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(users).
|
Model(users).
|
||||||
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
||||||
@@ -224,7 +219,7 @@ func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error
|
|||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
list := &Users{
|
list := &Users{
|
||||||
Users: users,
|
Users: *users,
|
||||||
Total: total,
|
Total: total,
|
||||||
PageOpts: *pageOpts,
|
PageOpts: *pageOpts,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -10,42 +9,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type UserRole struct {
|
type UserRole struct {
|
||||||
bun.BaseModel `bun:"table:user_roles,alias:ur"`
|
UserID int `bun:",pk"`
|
||||||
|
User *User `bun:"rel:belongs-to,join:user_id=id"`
|
||||||
ID int `bun:"id,pk,autoincrement"`
|
RoleID int `bun:",pk"`
|
||||||
UserID int `bun:"user_id,notnull"`
|
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
||||||
RoleID int `bun:"role_id,notnull"`
|
|
||||||
GrantedBy *int `bun:"granted_by"`
|
|
||||||
GrantedAt int64 `bun:"granted_at,notnull"` // TODO: default now
|
|
||||||
ExpiresAt *int64 `bun:"expires_at"`
|
|
||||||
|
|
||||||
// Relations
|
|
||||||
User *User `bun:"rel:belongs-to,join:user_id=id"`
|
|
||||||
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserRoles loads all roles for a given user
|
|
||||||
func GetUserRoles(ctx context.Context, tx bun.Tx, userID int) ([]*Role, error) {
|
|
||||||
if userID <= 0 {
|
|
||||||
return nil, errors.New("userID must be positive")
|
|
||||||
}
|
|
||||||
|
|
||||||
var roles []*Role
|
|
||||||
err := tx.NewSelect().
|
|
||||||
Model(&roles).
|
|
||||||
// TODO: why are we joining? can we do relation?
|
|
||||||
Join("JOIN user_roles AS ur ON ur.role_id = r.id").
|
|
||||||
Where("ur.user_id = ?", userID).
|
|
||||||
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
|
|
||||||
Scan(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
return roles, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignRole grants a role to a user
|
// AssignRole grants a role to a user
|
||||||
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *int) error {
|
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
||||||
if userID <= 0 {
|
if userID <= 0 {
|
||||||
return errors.New("userID must be positive")
|
return errors.New("userID must be positive")
|
||||||
}
|
}
|
||||||
@@ -53,16 +24,16 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, grantedBy *i
|
|||||||
return errors.New("roleID must be positive")
|
return errors.New("roleID must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
userRole := &UserRole{
|
||||||
|
UserID: userID,
|
||||||
// TODO: use proper m2m table instead of raw SQL
|
RoleID: roleID,
|
||||||
_, err := tx.ExecContext(ctx, `
|
}
|
||||||
INSERT INTO user_roles (user_id, role_id, granted_by, granted_at)
|
_, err := tx.NewInsert().
|
||||||
VALUES ($1, $2, $3, $4)
|
Model(userRole).
|
||||||
ON CONFLICT (user_id, role_id) DO NOTHING
|
On("CONFLICT (user_id, role_id) DO NOTHING").
|
||||||
`, userID, roleID, grantedBy, now)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.ExecContext")
|
return errors.Wrap(err, "tx.NewInsert")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -77,13 +48,13 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|||||||
return errors.New("roleID must be positive")
|
return errors.New("roleID must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use proper m2m table instead of raw sql
|
_, err := tx.NewDelete().
|
||||||
_, err := tx.ExecContext(ctx, `
|
Model((*UserRole)(nil)).
|
||||||
DELETE FROM user_roles
|
Where("user_id = ?", userID).
|
||||||
WHERE user_id = $1 AND role_id = $2
|
Where("role_id = ?", roleID).
|
||||||
`, userID, roleID)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.ExecContext")
|
return errors.Wrap(err, "tx.NewDelete")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -97,18 +68,19 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
|
|||||||
if roleName == "" {
|
if roleName == "" {
|
||||||
return false, errors.New("roleName cannot be empty")
|
return false, errors.New("roleName cannot be empty")
|
||||||
}
|
}
|
||||||
|
user := new(User)
|
||||||
// TODO: use proper m2m table instead of TableExpr and Join?
|
err := tx.NewSelect().
|
||||||
count, err := tx.NewSelect().
|
Model(user).
|
||||||
TableExpr("user_roles AS ur").
|
Relation("Roles").
|
||||||
Join("JOIN roles AS r ON r.id = ur.role_id").
|
Where("u.id = ? ", userID).
|
||||||
Where("ur.user_id = ?", userID).
|
Scan(ctx)
|
||||||
Where("r.name = ?", roleName).
|
|
||||||
Where("ur.expires_at IS NULL OR ur.expires_at > ?", time.Now().Unix()).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
return false, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
|
for _, role := range user.Roles {
|
||||||
return count > 0, nil
|
if role.Name == roleName {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error
|
|||||||
return errors.New("admin role not found in database")
|
return errors.New("admin role not found in database")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grant admin role (nil grantedBy = system granted)
|
// Grant admin role
|
||||||
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)
|
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.AssignRole")
|
return errors.Wrap(err, "db.AssignRole")
|
||||||
}
|
}
|
||||||
|
|||||||
24
internal/handlers/permtest.go
Normal file
24
internal/handlers/permtest.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PermTester(s *hws.Server, conn *bun.DB) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := db.CurrentUser(r.Context())
|
||||||
|
tx, _ := conn.BeginTx(r.Context(), nil)
|
||||||
|
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
|
||||||
|
tx.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
throwInternalServiceError(s, w, r, "Error", err)
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -18,8 +18,11 @@ type Checker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) {
|
func NewChecker(conn *bun.DB, s *hws.Server) (*Checker, error) {
|
||||||
if conn == nil || s == nil {
|
if conn == nil {
|
||||||
return nil, errors.New("arguments cannot be nil")
|
return nil, errors.New("conn cannot be nil")
|
||||||
|
}
|
||||||
|
if s == nil {
|
||||||
|
return nil, errors.New("server cannot be nil")
|
||||||
}
|
}
|
||||||
return &Checker{conn: conn, s: s}, nil
|
return &Checker{conn: conn, s: s}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ templ Global(title string) {
|
|||||||
<meta charset="UTF-8"/>
|
<meta charset="UTF-8"/>
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||||
<title>{ title }</title>
|
<title>{ title }</title>
|
||||||
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/>
|
<link rel="icon" type="image/x-icon" href="/static/assets/favicon.ico"/>
|
||||||
<link href="/static/css/output.css" rel="stylesheet"/>
|
<link href="/static/css/output.css" rel="stylesheet"/>
|
||||||
<script src="/static/vendored/htmx@2.0.8.min.js"></script>
|
<script src="/static/vendored/htmx@2.0.8.min.js"></script>
|
||||||
<script src="/static/vendored/htmx-ext-ws.min.js"></script>
|
<script src="/static/vendored/htmx-ext-ws.min.js"></script>
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ templ SeasonsList(seasons *db.SeasonList) {
|
|||||||
<!-- Header: Name + Status Badge -->
|
<!-- Header: Name + Status Badge -->
|
||||||
<div class="flex justify-between items-start mb-3">
|
<div class="flex justify-between items-start mb-3">
|
||||||
<h3 class="text-xl font-bold text-text">{ s.Name }</h3>
|
<h3 class="text-xl font-bold text-text">{ s.Name }</h3>
|
||||||
@season.StatusBadge(&s, true, true)
|
@season.StatusBadge(s, true, true)
|
||||||
</div>
|
</div>
|
||||||
<!-- Info Row: Short Name + Start Date -->
|
<!-- Info Row: Short Name + Start Date -->
|
||||||
<div class="flex items-center gap-3 text-sm">
|
<div class="flex items-center gap-3 text-sm">
|
||||||
@@ -107,5 +107,5 @@ templ SeasonsList(seasons *db.SeasonList) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func formatDate(t time.Time) string {
|
func formatDate(t time.Time) string {
|
||||||
return t.Format("02/01/2006") // DD/MM/YYYY
|
return t.Format("02/01/2006") // DD/MM/YYYY
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package oauth provides OAuth utilities for generating and checking secure state tokens
|
||||||
package oauth
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
Reference in New Issue
Block a user