Merge pull request 'league' (#1) from league into master

Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
h
2026-02-15 19:59:31 +11:00
211 changed files with 17568 additions and 1578 deletions

View File

@@ -3,7 +3,7 @@ testdata_dir = "testdata"
tmp_dir = "tmp" tmp_dir = "tmp"
[build] [build]
args_bin = [] args_bin = ["--dev"]
bin = "./tmp/main" bin = "./tmp/main"
cmd = "go build -o ./tmp/main ./cmd/oslstats" cmd = "go build -o ./tmp/main ./cmd/oslstats"
delay = 1000 delay = 1000
@@ -14,7 +14,7 @@ tmp_dir = "tmp"
follow_symlink = false follow_symlink = false
full_bin = "" full_bin = ""
include_dir = [] include_dir = []
include_ext = ["go", "templ"] include_ext = ["go", "templ", "js"]
include_file = [] include_file = []
kill_delay = "0s" kill_delay = "0s"
log = "build-errors.log" log = "build-errors.log"

7
.gitignore vendored
View File

@@ -8,3 +8,10 @@ tmp/
static/css/output.css static/css/output.css
internal/view/**/*_templ.go internal/view/**/*_templ.go
internal/view/**/*_templ.txt internal/view/**/*_templ.txt
cmd/test/*
.opencode
# Database backups (compressed)
backups/*.sql.gz
backups/*.sql
!backups/.gitkeep

124
.test.env Normal file
View File

@@ -0,0 +1,124 @@
# Environment Configuration
# Generated by ezconf
#
# Variables marked as (required) must be set
# Variables with defaults can be left commented out to use the default value
# HLog Configuration
###################
# Log level for the logger - trace, debug, info, warn, error, fatal, panic (default: info)
LOG_LEVEL=trace
# Output destination for logs - console, file, or both (default: console)
# LOG_OUTPUT=console
# Directory path for log files (required)
LOG_DIR=
# Name of the log file (required)
LOG_FILE_NAME=
# Append to existing log file or overwrite (default: true)
# LOG_APPEND=true
# HWS Configuration
##################
# Host to listen on (default: 127.0.0.1)
# HWS_HOST=127.0.0.1
# Port to listen on (default: 3000)
HWS_PORT=3333
# Flag for GZIP compression on requests (default: false)
# HWS_GZIP=false
# Timeout for reading request headers in seconds (default: 2)
# HWS_READ_HEADER_TIMEOUT=2
# Timeout for writing requests in seconds (default: 10)
# HWS_WRITE_TIMEOUT=10
# Timeout for idle connections in seconds (default: 120)
# HWS_IDLE_TIMEOUT=120
# Delay in seconds before server shutsdown when Shutdown is called (default: 5)
# HWS_SHUTDOWN_DELAY=5
# HWSAuth Configuration
######################
# Enable SSL secure cookies (default: false)
# HWSAUTH_SSL=false
# Full server address for SSL (required)
HWSAUTH_TRUSTED_HOST=http://127.0.0.1:3000
# Secret key for signing JWT tokens (required)
HWSAUTH_SECRET_KEY=/2epovpAmHFwdmlCxHRnihT50ZQtrGF/wK7+wiJdFLI=
# Access token expiry in minutes (default: 5)
# HWSAUTH_ACCESS_TOKEN_EXPIRY=5
# Refresh token expiry in minutes (default: 1440)
# HWSAUTH_REFRESH_TOKEN_EXPIRY=1440
# Token fresh time in minutes (default: 5)
# HWSAUTH_TOKEN_FRESH_TIME=5
# Redirect destination for authenticated users (default: "/profile")
# HWSAUTH_LANDING_PAGE="/profile"
# Database type (postgres, mysql, sqlite, mariadb) (default: "postgres")
# HWSAUTH_DATABASE_TYPE="postgres"
# Database version string (default: "15")
HWSAUTH_DATABASE_VERSION=18
# Custom JWT blacklist table name (default: "jwtblacklist")
# HWSAUTH_JWT_TABLE_NAME="jwtblacklist"
# DB Configuration
#################
# Database user for authentication (required)
DB_USER=pgdev
# Database password for authentication (required)
DB_PASSWORD=pgdevuser
# Database host address (required)
DB_HOST=10.3.0.60
# Database port (default: 5432)
# DB_PORT=5432
# Database name to connect to (required)
DB_NAME=oslstats_test
# SSL mode for connection (default: disable)
# DB_SSL=disable
# Number of backups to keep (default: 10)
# DB_BACKUP_RETENTION=10
# Discord Configuration
######################
# Discord application client ID (required)
DISCORD_CLIENT_ID=1463459682235580499
# Discord application client secret (required)
DISCORD_CLIENT_SECRET=pinbGa9IkgYQfeBIfBuosor6ODK-JTON
# Path for the OAuth redirect handler (required)
DISCORD_REDIRECT_PATH=auth/callback
# Token for the discord bot (required)
DISCORD_BOT_TOKEN=MTQ2MzQ1OTY4MjIzNTU4MDQ5OQ.GK-9Q6.Z876_JG7oUIKFwKp5snxUjAzloxVjy7KP37TX4
# OAuth Configuration
####################
# Private key for signing OAuth state tokens (required)
OAUTH_PRIVATE_KEY=b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZWQyNTUxOQAAACDtDHHkeGp1POc0z6/vDj8SK48lVeuGswu/8UO4oBcYSAAAAJj7edqp+3naqQAAAAtzc2gtZWQyNTUxOQAAACDtDHHkeGp1POc0z6/vDj8SK48lVeuGswu/8UO4oBcYSAAAAEAuqALdQqnaDFb5PvuUN4ng1d191hsirOhnahsT0aJFV+0MceR4anU85zTPr+8OPxIrjyVV64azC7/xQ7igFxhIAAAAEWhhZWxub3JyQGZsYWdzaGlwAQIDBA==
# RBAC Configuration
###################
# Discord ID to grant admin role on first login (required)
ADMIN_DISCORD_ID=202990104170463241

342
AGENTS.md
View File

@@ -9,123 +9,231 @@ This document provides guidelines for AI coding agents and developers working on
**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates **Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates
**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries **Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries
## Build, Test, and Development Commands ## Build and Development Commands
### Building ### Building
NEVER BUILD MANUALLY
```bash ```bash
# Full production build (tailwind → templ → go generate → go build) # Full production build (tailwind → templ → go generate → go build)
make build just build
# Build and run # Build and run
make run just run
# Clean build artifacts
make clean
``` ```
### Development Mode ### Development Mode
```bash ```bash
# Watch mode with hot reload (templ, air, tailwindcss in parallel) # Watch mode with hot reload (templ, air, tailwindcss in parallel)
make dev just dev
# Development server runs on: # Development server runs on:
# - Proxy: http://localhost:3000 (use this) # - Proxy: http://localhost:3000 (use this)
# - App: http://localhost:3333 (internal) # - App: http://localhost:3333 (internal)
``` ```
### Testing ### Database Migrations
**oslstats uses Bun's migration framework for safe, incremental schema changes.**
#### Quick Reference
**New Migration System**: Migrations now accept a count parameter. Default is 1 migration at a time.
```bash ```bash
# Run all tests # Show migration status
go test ./... just migrate status
# Run tests for a specific package # Run 1 migration (default, with automatic backup)
go test ./pkg/oauth just migrate up 1
# OR just
just migrate up
# Run a single test function # Run 3 migrations
go test ./pkg/oauth -run TestGenerateState_Success just migrate up 3
# Run tests with verbose output # Run all pending migrations
go test -v ./pkg/oauth just migrate up all
# Run tests with coverage # Run with a specific environment file
go test -cover ./... just migrate up 3 .test.env
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out # Rollback works the same for all arguments
just migrate down 2 .test.env
# Create new migration
just migrate new add_email_to_users
# Dev: Reset database (DESTRUCTIVE - deletes all data)
just reset-db
``` ```
### Database #### Creating a New Migration
**Example: Adding an email field to users table**
1. **Generate migration file:**
```bash
just migrate new add_leagues_and_slap_version
```
Creates: `cmd/oslstats/migrations/20250124150030_add_leagues_and_slap_version.go`
2. **Edit the migration file:**
```go
package migrations
import (
"context"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP: Add email column
func(ctx context.Context, db *bun.DB) error {
_, err := dbConn.NewAddColumn().
Model((*db.Season)(nil)).
ColumnExpr("slap_version VARCHAR NOT NULL").
IfNotExists().
Exec(ctx)
if err != nil {
return err
}
// Create leagues table
_, err = dbConn.NewCreateTable().
Model((*db.League)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create season_leagues join table
_, err = dbConn.NewCreateTable().
Model((*db.SeasonLeague)(nil)).
Exec(ctx)
return err
},
// DOWN: Remove email column (for rollback)
func(ctx context.Context, db *bun.DB) error {
// Drop season_leagues join table first
_, err := dbConn.NewDropTable().
Model((*db.SeasonLeague)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
// Drop leagues table
_, err = dbConn.NewDropTable().
Model((*db.League)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
// Remove slap_version column from seasons table
_, err = dbConn.NewDropColumn().
Model((*db.Season)(nil)).
ColumnExpr("slap_version").
Exec(ctx)
return err
},
)
}
```
3. **Update the model** (`internal/db/user.go`):
```go
type Season struct {
bun.BaseModel `bun:"table:seasons,alias:s"`
ID int `bun:"id,pk,autoincrement"`
Name string `bun:"name,unique"`
SlapVersion string `bun:"slap_version"` // NEW FIELD
}
```
4. **Apply the migration:**
```bash
just migrate up 1
```
Output:
```
[INFO] Step 1/5: Validating migrations...
[INFO] Migration validation passed ✓
[INFO] Step 2/5: Checking for pending migrations...
[INFO] Running 1 migration(s):
📋 20250124150030_add_email_to_users
[INFO] Step 3/5: Creating backup...
[INFO] Backup created: backups/20250124_150145_pre_migration.sql.gz (2.3 MB)
[INFO] Step 4/5: Acquiring migration lock...
[INFO] Migration lock acquired
[INFO] Step 5/5: Applying migrations...
[INFO] Migrated to group 2
✅ 20250124150030_add_email_to_users
[INFO] Migration lock released
```
#### Environment Variables
```bash ```bash
# Run migrations # Backup directory (default: backups)
make migrate DB_BACKUP_DIR=backups
# OR
./bin/oslstats --migrate # Number of backups to keep (default: 10)
DB_BACKUP_RETENTION=10
``` ```
#### Troubleshooting
**"pg_dump not found"**
- Migrations will still run, but backups will be skipped
- Install PostgreSQL client tools for backups:
```bash
# Ubuntu/Debian
sudo apt-get install postgresql-client
# macOS
brew install postgresql
# Arch
sudo pacman -S postgresql-libs
```
**"migration already in progress"**
- Another instance is running migrations
- Wait for it to complete (max 5 minutes)
- If stuck, check for hung database connections
**"migration build failed"**
- Migration file has syntax errors
- Fix the errors and try again
- Use `go build ./cmd/oslstats/migrations` to debug
### Configuration Management ### Configuration Management
```bash ```bash
# Generate .env template file # Generate .env template file
make genenv just genenv
# OR with custom output: make genenv OUT=.env.example # OR with custom output: just genenv .env.example
# Show environment variable documentation # Show environment variable documentation
make envdoc just envdoc
# Show current environment values # Show current environment values
make showenv just showenv
``` ```
## Code Style Guidelines ## Code Style Guidelines
### Import Organization
Organize imports in **3 groups** separated by blank lines:
```go
import (
// 1. Standard library
"context"
"net/http"
"fmt"
// 2. External dependencies
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
// 3. Internal packages
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
```
### Naming Conventions
**Variables**:
- Local: `camelCase` (userAgentKey, httpServer, dbConn)
- Exported: `PascalCase` (Config, User, Token)
- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r`
**Functions**:
- Exported: `PascalCase` (GetConfig, NewStore, GenerateState)
- Private: `camelCase` (throwError, shouldShowDetails, loadModels)
- HTTP handlers: Return `http.Handler`, use dependency injection pattern
- Database functions: Use `bun.Tx` as parameter for transactions
**Types**:
- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession)
- Use `-er` suffix for interfaces (implied from usage)
**Files**:
- Prefer single word: `config.go`, `oauth.go`, `errors.go`
- Don't use snake_case except for tests: `state_test.go`
- Test files: `*_test.go` alongside source files
### Error Handling ### Error Handling
**Always wrap errors** with context using `github.com/pkg/errors`: **Always wrap errors** with context using `github.com/pkg/errors`:
```go ```go
if err != nil { if err != nil {
return errors.Wrap(err, "operation_name") return errors.Wrap(err, "package.FunctionName")
} }
``` ```
@@ -234,94 +342,14 @@ func ConfigFromEnv() (any, error) {
- Use inline comments for ENV var documentation in Config structs - Use inline comments for ENV var documentation in Config structs
- Explain security-critical code flows - Explain security-critical code flows
### Testing
**Test File Location**: Place `*_test.go` files alongside source files
**Test Naming**:
```go
func TestFunctionName_Scenario(t *testing.T)
func TestGenerateState_Success(t *testing.T)
func TestVerifyState_WrongUserAgentKey(t *testing.T)
```
**Test Structure**:
- Use subtests with `t.Run()` for related scenarios
- Use table-driven tests for multiple similar cases
- Create helper functions for common setup (e.g., `testConfig()`)
- Test happy paths, error cases, edge cases, and security properties
**Test Categories** (from pkg/oauth/state_test.go example):
1. Happy path tests
2. Error handling (nil params, empty fields, malformed input)
3. Security tests (MITM, CSRF, replay attacks, tampering)
4. Edge cases (concurrency, constant-time comparison)
5. Integration tests (round-trip verification)
### Security
**Critical Practices**:
- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons
- Implement CSRF protection via state tokens
- Store sensitive cookies as HttpOnly
- Use separate logging levels for security violations (WARN)
- Validate all inputs at function boundaries
- Use parameterized queries (Bun ORM handles this)
- Never commit secrets (.env, keys/ are gitignored)
## Project Structure
```
oslstats/
├── cmd/oslstats/ # Application entry point
│ ├── main.go # Entry point with flag parsing
│ ├── run.go # Server initialization & graceful shutdown
│ ├── httpserver.go # HTTP server setup
│ ├── routes.go # Route registration
│ ├── middleware.go # Middleware registration
│ ├── auth.go # Authentication setup
│ └── db.go # Database connection & migrations
├── internal/ # Private application code
│ ├── config/ # Configuration aggregation
│ ├── db/ # Database models & queries (Bun ORM)
│ ├── discord/ # Discord OAuth integration
│ ├── handlers/ # HTTP request handlers
│ ├── session/ # Session store (in-memory)
│ └── view/ # Templ templates
│ ├── component/ # Reusable UI components
│ ├── layout/ # Page layouts
│ └── page/ # Full pages
├── pkg/ # Reusable packages
│ ├── contexts/ # Context key definitions
│ ├── embedfs/ # Embedded static files
│ └── oauth/ # OAuth state management
├── bin/ # Compiled binaries (gitignored)
├── keys/ # Private keys (gitignored)
├── tmp/ # Air hot reload temp files (gitignored)
├── Makefile # Build automation
├── .air.toml # Hot reload configuration
└── go.mod # Go module definition
```
## Key Dependencies
- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt)
- **github.com/a-h/templ** - Type-safe HTML templating
- **github.com/uptrace/bun** - PostgreSQL ORM
- **github.com/bwmarrin/discordgo** - Discord API client
- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf)
- **github.com/joho/godotenv** - .env file loading
## Notes for AI Agents ## Notes for AI Agents
1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css) 1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css)
2. **Database operations** should use `bun.Tx` for transaction safety 2. **Database operations** should use `bun.Tx` for transaction safety
3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes 3. **Templates** are written in templ, not Go html/template - run `just templ` after changes
4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/ 4. **Static files** are embedded via `//go:embed` - check internal/embedfs/
5. **Error messages** should be descriptive and use errors.Wrap for context 5. **Error messages** should be descriptive and use errors.Wrap for context
6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples) 6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples)
7. **Air proxy** runs on port 3000 during development; app runs on 3333 7. **Air proxy** runs on port 3000 during development; app runs on 3333
8. **Test coverage** is currently limited - prioritize testing security-critical code 8. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples
9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples 9. When in plan mode, always use the interactive question tool if available
10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern
11. When in plan mode, always use the interactive question tool if available

View File

@@ -1,39 +0,0 @@
# Makefile
.PHONY: build
BINARY_NAME=oslstats
build:
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
go mod tidy && \
templ generate && \
go generate ./cmd/${BINARY_NAME} && \
go build -ldflags="-w -s" -o ./bin/${BINARY_NAME}${SUFFIX} ./cmd/${BINARY_NAME}
run:
make build
./bin/${BINARY_NAME}${SUFFIX}
dev:
templ generate --watch &\
air &\
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
clean:
go clean
genenv:
make build
./bin/${BINARY_NAME} --genenv ${OUT}
envdoc:
make build
./bin/${BINARY_NAME} --envdoc
showenv:
make build
./bin/${BINARY_NAME} --showenv
migrate:
make build
./bin/${BINARY_NAME}${SUFFIX} --migrate

View File

@@ -1,54 +0,0 @@
package main
import (
"context"
"database/sql"
"fmt"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func() error, err error) {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
conn = bun.NewDB(sqldb, pgdialect.New())
close = sqldb.Close
err = loadModels(ctx, conn, cfg.Flags.MigrateDB)
if err != nil {
return nil, nil, errors.Wrap(err, "loadModels")
}
return conn, close, nil
}
func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error {
models := []any{
(*db.User)(nil),
(*db.DiscordToken)(nil),
}
for _, model := range models {
_, err := conn.NewCreateTable().
Model(model).
IfNotExists().
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.NewCreateTable")
}
if resetDB {
err = conn.ResetModel(ctx, model)
if err != nil {
return errors.Wrap(err, "db.ResetModel")
}
}
}
return nil
}

View File

@@ -5,12 +5,19 @@ import (
"fmt" "fmt"
"os" "os"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db/migrate"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func main() { func main() {
flags := config.SetupFlags() flags, err := config.SetupFlags()
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err)
os.Exit(1)
}
ctx := context.Background() ctx := context.Background()
cfg, loader, err := config.GetConfig(flags) cfg, loader, err := config.GetConfig(flags)
@@ -18,29 +25,66 @@ func main() {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config")) fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
os.Exit(1) os.Exit(1)
} }
// Handle utility flags
if flags.EnvDoc || flags.ShowEnv { if flags.EnvDoc || flags.ShowEnv {
loader.PrintEnvVarsStdout(flags.ShowEnv) if err = loader.PrintEnvVarsStdout(flags.ShowEnv); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to print env doc"))
}
return return
} }
if flags.GenEnv != "" { if flags.GenEnv != "" {
loader.GenerateEnvFile(flags.GenEnv, true) if err = loader.GenerateEnvFile(flags.GenEnv, true); err != nil {
return fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to generate env file"))
}
if flags.MigrateDB {
_, closedb, err := setupBun(ctx, cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
} }
closedb()
return return
} }
if err := run(ctx, os.Stdout, cfg); err != nil { // Setup the logger
fmt.Fprintf(os.Stderr, "%s\n", err) logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to init logger"))
os.Exit(1) os.Exit(1)
} }
// Handle migration file creation (doesn't need DB connection)
if flags.MigrateCreate != "" {
if err := migrate.CreateMigration(flags.MigrateCreate); err != nil {
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration")
}
return
}
// Handle commands that need database connection
if flags.MigrateUp != "" || flags.MigrateRollback != "" ||
flags.MigrateStatus || flags.MigrateDryRun ||
flags.ResetDB {
var command, countStr string
// Route to appropriate command
if flags.MigrateUp != "" {
command = "up"
countStr = flags.MigrateUp
} else if flags.MigrateRollback != "" {
command = "rollback"
countStr = flags.MigrateRollback
} else if flags.MigrateStatus {
command = "status"
}
if flags.ResetDB {
err = migrate.ResetDatabase(ctx, cfg)
} else {
err = migrate.RunMigrations(ctx, cfg, command, countStr)
}
if err != nil {
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "dbFlags"))).Msg("Error migrating database")
}
return
}
// Normal server startup
if err := run(ctx, logger, cfg); err != nil {
logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "run"))).Msg("Error starting server")
}
} }

View File

@@ -1,24 +0,0 @@
package main
import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func addMiddleware(
server *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
) error {
err := server.AddMiddleware(
auth.Authenticate(),
)
if err != nil {
return errors.Wrap(err, "server.AddMiddleware")
}
return nil
}

View File

@@ -1,75 +0,0 @@
package main
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/store"
)
func addRoutes(
server *hws.Server,
staticFS *http.FileSystem,
cfg *config.Config,
conn *bun.DB,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
store *store.Store,
discordAPI *discord.APIClient,
) error {
// Create the routes
routes := []hws.Route{
{
Path: "/static/",
Method: hws.MethodGET,
Handler: http.StripPrefix("/static/", handlers.StaticFS(staticFS, server)),
},
{
Path: "/",
Method: hws.MethodGET,
Handler: handlers.Index(server),
},
{
Path: "/login",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)),
},
{
Path: "/auth/callback",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)),
},
{
Path: "/register",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)),
},
{
Path: "/logout",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)),
},
}
htmxRoutes := []hws.Route{
{
Path: "/htmx/isusernameunique",
Method: hws.MethodPOST,
Handler: handlers.IsUsernameUnique(server, conn, cfg, store),
},
}
// Register the routes with the server
err := server.AddRoutes(append(routes, htmxRoutes...)...)
if err != nil {
return errors.Wrap(err, "server.AddRoutes")
}
return nil
}

View File

@@ -2,7 +2,7 @@ package main
import ( import (
"context" "context"
"io" "fmt"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
@@ -12,30 +12,22 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/embedfs"
"git.haelnorr.com/h/oslstats/internal/server"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/embedfs"
) )
// Initializes and runs the server // Initializes and runs the server
func run(ctx context.Context, w io.Writer, cfg *config.Config) error { func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel() defer cancel()
// Setup the logger
logger, err := hlog.NewLogger(cfg.HLOG, w)
if err != nil {
return errors.Wrap(err, "hlog.NewLogger")
}
// Setup the database connection // Setup the database connection
logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database") logger.Debug().Msg("Connecting to database")
bun, closedb, err := setupBun(ctx, cfg) conn := db.NewDB(cfg.DB)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
defer closedb()
// Setup embedded files // Setup embedded files
logger.Debug().Msg("Getting embedded files") logger.Debug().Msg("Getting embedded files")
@@ -56,7 +48,7 @@ func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
} }
logger.Debug().Msg("Setting up HTTP server") logger.Debug().Msg("Setting up HTTP server")
httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI) httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI)
if err != nil { if err != nil {
return errors.Wrap(err, "setupHttpServer") return errors.Wrap(err, "setupHttpServer")
} }
@@ -73,11 +65,16 @@ func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
wg.Go(func() { wg.Go(func() {
<-ctx.Done() <-ctx.Done()
shutdownCtx := context.Background() shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 60*time.Second)
defer cancel() defer cancel()
logger.Info().Msg("Shut down requested, waiting 60 seconds...")
err := httpServer.Shutdown(shutdownCtx) err := httpServer.Shutdown(shutdownCtx)
if err != nil { if err != nil {
logger.Error().Err(err).Msg("Graceful shutdown failed") logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Error during HTTP server shutdown")
}
err = conn.Close()
if err != nil {
logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "closedb"))).Msg("Error during database close")
} }
}) })
wg.Wait() wg.Wait()

18
go.mod
View File

@@ -1,14 +1,16 @@
module git.haelnorr.com/h/oslstats module git.haelnorr.com/h/oslstats
go 1.25.5 go 1.25.6
require ( require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/ezconf v0.1.1
git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.3.1 git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/hwsauth v0.5.2 git.haelnorr.com/h/golib/hwsauth v0.6.1
git.haelnorr.com/h/golib/notify v0.1.0
github.com/a-h/templ v0.3.977 github.com/a-h/templ v0.3.977
github.com/coder/websocket v1.8.14
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/uptrace/bun v1.2.16 github.com/uptrace/bun v1.2.16
@@ -17,13 +19,15 @@ require (
) )
require ( require (
github.com/gobwas/glob v0.2.3 // indirect
github.com/gorilla/websocket v1.4.2 // indirect github.com/gorilla/websocket v1.4.2 // indirect
golang.org/x/crypto v0.45.0 // indirect golang.org/x/crypto v0.47.0 // indirect
) )
require ( require (
git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
git.haelnorr.com/h/timefmt v0.1.0
github.com/bwmarrin/discordgo v0.29.0 github.com/bwmarrin/discordgo v0.29.0
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
@@ -38,9 +42,9 @@ require (
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/sys v0.40.0 // indirect golang.org/x/sys v0.41.0 // indirect
k8s.io/apimachinery v0.35.0 // indirect k8s.io/apimachinery v0.35.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20260108192941-914a6e750570 // indirect k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
mellium.im/sasl v0.3.2 // indirect mellium.im/sasl v0.3.2 // indirect
) )

32
go.sum
View File

@@ -6,23 +6,31 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI
git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI= git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag= git.haelnorr.com/h/golib/hwsauth v0.6.1 h1:3BiM6hwuYDjgfu02hshvUtr592DnWi9Epj//3N13ti0=
git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= git.haelnorr.com/h/golib/hwsauth v0.6.1/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
git.haelnorr.com/h/timefmt v0.1.0 h1:ULDkWEtFIV+FkkoV0q9n62Spj+HDdtFL9QeAdGIEp+o=
git.haelnorr.com/h/timefmt v0.1.0/go.mod h1:12gXXYLP4w9Fa9ZkbZWdvKV6RyZEzwAm9mN+WB3oXpw=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg= github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo= github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
@@ -71,25 +79,25 @@ go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwEx
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= k8s.io/apimachinery v0.35.1 h1:yxO6gV555P1YV0SANtnTjXYfiivaTPvCTKX6w6qdDsU=
k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns= k8s.io/apimachinery v0.35.1/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY= k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0= mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0=
mellium.im/sasl v0.3.2/go.mod h1:NKXDi1zkr+BlMHLQjY3ofYuU4KSPFxknb8mfEu6SveY= mellium.im/sasl v0.3.2/go.mod h1:NKXDi1zkr+BlMHLQjY3ofYuU4KSPFxknb8mfEu6SveY=

View File

@@ -1,3 +1,4 @@
// Package config provides the environment based configuration for the program
package config package config
import ( import (
@@ -7,6 +8,7 @@ import (
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/pkg/oauth" "git.haelnorr.com/h/oslstats/pkg/oauth"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -19,10 +21,11 @@ type Config struct {
HLOG *hlog.Config HLOG *hlog.Config
Discord *discord.Config Discord *discord.Config
OAuth *oauth.Config OAuth *oauth.Config
RBAC *rbac.Config
Flags *Flags Flags *Flags
} }
// Load the application configuration and get a pointer to the Config object // GetConfig loads the application configuration and returns a pointer to the Config object
// If doconly is specified, only the loader will be returned // If doconly is specified, only the loader will be returned
func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
err := godotenv.Load(flags.EnvFile) err := godotenv.Load(flags.EnvFile)
@@ -31,14 +34,18 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
} }
loader := ezconf.New() loader := ezconf.New()
loader.RegisterIntegrations( err = loader.RegisterIntegrations(
hlog.NewEZConfIntegration(), hlog.NewEZConfIntegration(),
hws.NewEZConfIntegration(), hws.NewEZConfIntegration(),
hwsauth.NewEZConfIntegration(), hwsauth.NewEZConfIntegration(),
db.NewEZConfIntegration(), db.NewEZConfIntegration(),
discord.NewEZConfIntegration(), discord.NewEZConfIntegration(),
oauth.NewEZConfIntegration(), oauth.NewEZConfIntegration(),
rbac.NewEZConfIntegration(),
) )
if err != nil {
return nil, nil, errors.Wrap(err, "loader.RegisterIntegrations")
}
if err := loader.ParseEnvVars(); err != nil { if err := loader.ParseEnvVars(); err != nil {
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars") return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
} }
@@ -81,6 +88,11 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
return nil, nil, errors.New("OAuth Config not loaded") return nil, nil, errors.New("OAuth Config not loaded")
} }
rbaccfg, ok := loader.GetConfig("rbac")
if !ok {
return nil, nil, errors.New("RBAC Config not loaded")
}
config := &Config{ config := &Config{
DB: dbcfg.(*db.Config), DB: dbcfg.(*db.Config),
HWS: hwscfg.(*hws.Config), HWS: hwscfg.(*hws.Config),
@@ -88,6 +100,7 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
HLOG: hlogcfg.(*hlog.Config), HLOG: hlogcfg.(*hlog.Config),
Discord: discordcfg.(*discord.Config), Discord: discordcfg.(*discord.Config),
OAuth: oauthcfg.(*oauth.Config), OAuth: oauthcfg.(*oauth.Config),
RBAC: rbaccfg.(*rbac.Config),
Flags: flags, Flags: flags,
} }

View File

@@ -2,31 +2,125 @@ package config
import ( import (
"flag" "flag"
"strconv"
"github.com/pkg/errors"
) )
type Flags struct { type Flags struct {
MigrateDB bool // Utility flags
EnvDoc bool EnvDoc bool
ShowEnv bool ShowEnv bool
GenEnv string GenEnv string
EnvFile string EnvFile string
DevMode bool
// Database reset (destructive)
ResetDB bool
// Migration commands
MigrateUp string
MigrateRollback string
MigrateStatus bool
MigrateCreate string
MigrateDryRun bool
// Backup control
MigrateNoBackup bool
} }
func SetupFlags() *Flags { func SetupFlags() (*Flags, error) {
// Parse commandline args // Utility flags
migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models")
envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation") envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation")
showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation") showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation")
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)") genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
envfile := flag.String("envfile", ".env", "Specify a .env file to use for the configuration") envfile := flag.String("envfile", ".env", "Specify a .env file to use for the configuration")
devMode := flag.Bool("dev", false, "Run the server in dev mode")
// Database reset (destructive)
resetDB := flag.Bool("reset-db", false, "⚠️ DESTRUCTIVE: Drop and recreate all tables (dev only)")
// Migration commands
migrateUp := flag.String("migrate-up", "", "Run pending database migrations (usage: --migrate-up [count|all], default: 1)")
migrateRollback := flag.String("migrate-rollback", "", "Rollback migrations (usage: --migrate-rollback [count|all], default: 1)")
migrateStatus := flag.Bool("migrate-status", false, "Show database migration status")
migrateCreate := flag.String("migrate-create", "", "Create a new migration file with the given name")
migrateDryRun := flag.Bool("migrate-dry-run", false, "Preview pending migrations without applying them")
// Backup control
migrateNoBackup := flag.Bool("no-backup", false, "Skip automatic backups (dev only - faster but less safe)")
flag.Parse() flag.Parse()
flags := &Flags{ // Validate: can't use multiple migration commands at once
MigrateDB: *migrateDB, commands := 0
EnvDoc: *envDoc, if *migrateUp != "" {
ShowEnv: *showEnv, commands++
GenEnv: *genEnv,
EnvFile: *envfile,
} }
return flags if *migrateRollback != "" {
commands++
}
if *migrateStatus {
commands++
}
if *migrateDryRun {
commands++
}
if *resetDB {
commands++
}
if commands > 1 {
return nil, errors.New("cannot use multiple migration commands simultaneously")
}
// Validate migration count values
if *migrateUp != "" {
if err := validateMigrationCount(*migrateUp); err != nil {
return nil, errors.Wrap(err, "invalid --migrate-up value")
}
}
if *migrateRollback != "" {
if err := validateMigrationCount(*migrateRollback); err != nil {
return nil, errors.Wrap(err, "invalid --migrate-rollback value")
}
}
flags := &Flags{
EnvDoc: *envDoc,
ShowEnv: *showEnv,
GenEnv: *genEnv,
EnvFile: *envfile,
DevMode: *devMode,
ResetDB: *resetDB,
MigrateUp: *migrateUp,
MigrateRollback: *migrateRollback,
MigrateStatus: *migrateStatus,
MigrateCreate: *migrateCreate,
MigrateDryRun: *migrateDryRun,
MigrateNoBackup: *migrateNoBackup,
}
return flags, nil
}
// validateMigrationCount validates a migration count value
// Valid values: "all" or a positive integer (1, 2, 3, ...)
func validateMigrationCount(value string) error {
if value == "" {
return nil
}
if value == "all" {
return nil
}
// Try parsing as integer
count, err := strconv.Atoi(value)
if err != nil {
return errors.New("must be a positive integer or 'all'")
}
if count < 1 {
return errors.New("must be a positive integer (1 or greater)")
}
return nil
} }

View File

@@ -0,0 +1,16 @@
package contexts
import "context"
func DevMode(ctx context.Context) DevInfo {
devmode, ok := ctx.Value(DevModeKey).(DevInfo)
if !ok {
return DevInfo{}
}
return devmode
}
type DevInfo struct {
WebsocketBase string
HTMXLog bool
}

14
internal/contexts/keys.go Normal file
View File

@@ -0,0 +1,14 @@
// Package contexts provides utilities for loading and extracting structs from contexts
package contexts
type Key string
func (c Key) String() string {
return "oslstats context key " + string(c)
}
var (
DevModeKey Key = Key("devmode")
PermissionCacheKey Key = Key("permissions")
PreviewRoleKey Key = Key("preview-role")
)

View File

@@ -0,0 +1,64 @@
package contexts
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
)
// Permissions retrieves the permission cache from context (type-safe)
func Permissions(ctx context.Context) *PermissionCache {
cache, ok := ctx.Value(PermissionCacheKey).(*PermissionCache)
if !ok {
return nil
}
return cache
}
type PermissionCache struct {
Permissions map[permissions.Permission]bool
Roles map[roles.Role]bool
HasWildcard bool
}
// HasPermission returns true if the cache contains the provided permission
func (p *PermissionCache) HasPermission(perm permissions.Permission) bool {
if p.HasWildcard {
return true
}
_, exists := p.Permissions[perm]
return exists
}
// HasAnyPermission returns true if the cache contains any of the provided permissions
func (p *PermissionCache) HasAnyPermission(perms []permissions.Permission) bool {
if p.HasWildcard {
return true
}
for _, perm := range perms {
_, exists := p.Permissions[perm]
if exists {
return true
}
}
return false
}
// HasAllPermissions returns true only if more than one permission is provided and the cache
// contains all the provided permissions
func (p *PermissionCache) HasAllPermissions(perms []permissions.Permission) bool {
if p.HasWildcard {
return true
}
if len(perms) == 0 {
return false
}
for _, perm := range perms {
_, exists := p.Permissions[perm]
if !exists {
return false
}
}
return true
}

View File

@@ -0,0 +1,25 @@
package contexts
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
)
// WithPreviewRole adds a preview role to the context
func WithPreviewRole(ctx context.Context, role *db.Role) context.Context {
return context.WithValue(ctx, PreviewRoleKey, role)
}
// GetPreviewRole retrieves the preview role from the context, or nil if not present
func GetPreviewRole(ctx context.Context) *db.Role {
if role, ok := ctx.Value(PreviewRoleKey).(*db.Role); ok {
return role
}
return nil
}
// IsPreviewMode returns true if the user is currently in preview mode
func IsPreviewMode(ctx context.Context) bool {
return GetPreviewRole(ctx) != nil
}

148
internal/db/audit.go Normal file
View File

@@ -0,0 +1,148 @@
package db
import (
"net/http"
"reflect"
"strings"
)
type AuditMeta struct {
r *http.Request
u *User
}
func NewAudit(r *http.Request, u *User) *AuditMeta {
if u == nil {
u = CurrentUser(r.Context())
}
return &AuditMeta{r, u}
}
// AuditInfo contains metadata for audit logging
type AuditInfo struct {
Action string // e.g., "seasons.create", "users.update"
ResourceType string // e.g., "season", "user"
ResourceID any // Primary key value (int, string, etc.)
Details any // Changed fields or additional metadata
}
// extractTableName gets the bun table name from a model type using reflection
// Example: Season with `bun:"table:seasons,alias:s"` returns "seasons"
func extractTableName[T any]() string {
var model T
t := reflect.TypeOf(model)
// Handle pointer types
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
// Look for bun.BaseModel field with table tag
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Type.Name() == "BaseModel" {
bunTag := field.Tag.Get("bun")
if bunTag != "" {
// Parse tag: "table:seasons,alias:s" -> "seasons"
for part := range strings.SplitSeq(bunTag, ",") {
part, _ := strings.CutPrefix(part, "table:")
return part
}
}
}
}
// Fallback: use struct name in lowercase + "s"
return strings.ToLower(t.Name()) + "s"
}
// extractResourceType converts a table name to singular resource type
// Example: "seasons" -> "season", "users" -> "user"
func extractResourceType(tableName string) string {
// Simple singularization: remove trailing 's'
if strings.HasSuffix(tableName, "s") && len(tableName) > 1 {
return tableName[:len(tableName)-1]
}
return tableName
}
// buildAction creates a permission-style action string
// Example: ("season", "create") -> "seasons.create"
func buildAction(resourceType, operation string) string {
// Pluralize resource type (simple: add 's')
plural := resourceType
if !strings.HasSuffix(plural, "s") {
plural = plural + "s"
}
return plural + "." + operation
}
// extractPrimaryKey uses reflection to find and return the primary key value from a model
// Returns nil if no primary key is found
func extractPrimaryKey[T any](model *T) any {
if model == nil {
return nil
}
v := reflect.ValueOf(model)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
bunTag := field.Tag.Get("bun")
if bunTag != "" && strings.Contains(bunTag, "pk") {
// Found primary key field
fieldValue := v.Field(i)
if fieldValue.IsValid() && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
}
return nil
}
// extractChangedFields builds a map of field names to their new values
// Only includes fields specified in the columns list
func extractChangedFields[T any](model *T, columns []string) map[string]any {
if model == nil || len(columns) == 0 {
return nil
}
result := make(map[string]any)
v := reflect.ValueOf(model)
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
t := v.Type()
// Build map of bun column names to field names
columnToField := make(map[string]int)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
bunTag := field.Tag.Get("bun")
if bunTag != "" {
// Parse bun tag to get column name (first part before comma)
parts := strings.Split(bunTag, ",")
if len(parts) > 0 && parts[0] != "" {
columnToField[parts[0]] = i
}
}
}
// Extract values for requested columns
for _, col := range columns {
if fieldIdx, ok := columnToField[col]; ok {
fieldValue := v.Field(fieldIdx)
if fieldValue.IsValid() && fieldValue.CanInterface() {
result[col] = fieldValue.Interface()
}
}
}
return result
}

201
internal/db/auditlog.go Normal file
View File

@@ -0,0 +1,201 @@
package db
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type AuditLog struct {
bun.BaseModel `bun:"table:audit_log,alias:al"`
ID int `bun:"id,pk,autoincrement"`
UserID int `bun:"user_id,notnull"`
Action string `bun:"action,notnull"`
ResourceType string `bun:"resource_type,notnull"`
ResourceID *string `bun:"resource_id"`
Details json.RawMessage `bun:"details,type:jsonb"`
IPAddress string `bun:"ip_address"`
UserAgent string `bun:"user_agent"`
Result string `bun:"result,notnull"` // success, denied, error
ErrorMessage *string `bun:"error_message"`
CreatedAt int64 `bun:"created_at,notnull"`
// Relations
User *User `bun:"rel:belongs-to,join:user_id=id"`
}
// CreateAuditLog creates a new audit log entry
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
if log == nil {
return errors.New("log cannot be nil")
}
err := Insert(tx, log).Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
type AuditLogFilter struct {
*ListFilter
}
func NewAuditLogFilter() *AuditLogFilter {
return &AuditLogFilter{
ListFilter: NewListFilter(),
}
}
func (a *AuditLogFilter) UserID(id int) *AuditLogFilter {
a.Equals("al.user_id", id)
return a
}
func (a *AuditLogFilter) Action(action string) *AuditLogFilter {
a.Equals("al.action", action)
return a
}
func (a *AuditLogFilter) ResourceType(resourceType string) *AuditLogFilter {
a.Equals("al.resource_type", resourceType)
return a
}
func (a *AuditLogFilter) Result(result string) *AuditLogFilter {
a.Equals("al.result", result)
return a
}
func (a *AuditLogFilter) UserIDs(ids []int) *AuditLogFilter {
if len(ids) > 0 {
a.In("al.user_id", ids)
}
return a
}
func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter {
fmt.Println(actions)
if len(actions) > 0 {
a.In("al.action", actions)
}
return a
}
func (a *AuditLogFilter) ResourceTypes(resourceTypes []string) *AuditLogFilter {
if len(resourceTypes) > 0 {
a.In("al.resource_type", resourceTypes)
}
return a
}
func (a *AuditLogFilter) Results(results []string) *AuditLogFilter {
if len(results) > 0 {
a.In("al.result", results)
}
return a
}
func (a *AuditLogFilter) DateRange(start, end int64) *AuditLogFilter {
if start > 0 {
a.GreaterEqualThan("al.created_at", start)
}
if end > 0 {
a.LessEqualThan("al.created_at", end)
}
return a
}
// GetAuditLogs retrieves audit logs with optional filters and pagination
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilter) (*List[AuditLog], error) {
defaultPageOpts := &PageOpts{
Page: 1,
PerPage: 10,
Order: bun.OrderDesc,
OrderBy: "created_at",
}
return GetList[AuditLog](tx).
Relation("User").
Filter(filters.filters...).
GetPaged(ctx, pageOpts, defaultPageOpts)
}
// GetAuditLogsByUser retrieves audit logs for a specific user
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*List[AuditLog], error) {
if userID <= 0 {
return nil, errors.New("userID must be positive")
}
filters := NewAuditLogFilter().UserID(userID)
return GetAuditLogs(ctx, tx, pageOpts, filters)
}
// GetAuditLogsByAction retrieves audit logs for a specific action
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*List[AuditLog], error) {
if action == "" {
return nil, errors.New("action cannot be empty")
}
filters := NewAuditLogFilter().Action(action)
return GetAuditLogs(ctx, tx, pageOpts, filters)
}
// GetAuditLogByID retrieves a single audit log by ID
func GetAuditLogByID(ctx context.Context, tx bun.Tx, id int) (*AuditLog, error) {
if id <= 0 {
return nil, errors.New("id must be positive")
}
return GetByField[AuditLog](tx, "al.id", id).Relation("User").Get(ctx)
}
// GetUniqueActions retrieves a list of all unique actions in the audit log
func GetUniqueActions(ctx context.Context, tx bun.Tx) ([]string, error) {
var actions []string
err := tx.NewSelect().
Model((*AuditLog)(nil)).
Column("action").
Distinct().
Order("action ASC").
Scan(ctx, &actions)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return actions, nil
}
// GetUniqueResourceTypes retrieves a list of all unique resource types in the audit log
func GetUniqueResourceTypes(ctx context.Context, tx bun.Tx) ([]string, error) {
var resourceTypes []string
err := tx.NewSelect().
Model((*AuditLog)(nil)).
Column("resource_type").
Distinct().
Order("resource_type ASC").
Scan(ctx, &resourceTypes)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return resourceTypes, nil
}
// CleanupOldAuditLogs deletes audit logs older than the specified timestamp
func CleanupOldAuditLogs(ctx context.Context, tx bun.Tx, olderThan int64) (int, error) {
result, err := tx.NewDelete().
Model((*AuditLog)(nil)).
Where("created_at < ?", olderThan).
Exec(ctx)
if err != nil {
return 0, errors.Wrap(err, "tx.NewDelete")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, errors.Wrap(err, "result.RowsAffected")
}
return int(rowsAffected), nil
}

View File

@@ -0,0 +1,91 @@
package db
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// LogSuccess logs a successful permission-protected action
func LogSuccess(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
) error {
return log(ctx, tx, meta, info, "success", nil)
}
// LogError logs a failed action due to an error
func LogError(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
err error,
) error {
errMsg := err.Error()
return log(ctx, tx, meta, info, "error", &errMsg)
}
func log(
ctx context.Context,
tx bun.Tx,
meta *AuditMeta,
info *AuditInfo,
result string,
errorMessage *string,
) error {
if meta == nil {
return errors.New("audit meta cannot be nil for audit logging")
}
if info == nil {
return errors.New("audit info cannot be nil for audit logging")
}
if meta.u == nil {
return errors.New("user cannot be nil for audit logging")
}
if meta.r == nil {
return errors.New("request cannot be nil for audit logging")
}
// Convert resourceID to string
var resourceIDStr *string
if info.ResourceID != nil {
idStr := fmt.Sprintf("%v", info.ResourceID)
resourceIDStr = &idStr
}
// Marshal details to JSON
var detailsJSON json.RawMessage
if info.Details != nil {
jsonBytes, err := json.Marshal(info.Details)
if err != nil {
return errors.Wrap(err, "json.Marshal details")
}
detailsJSON = jsonBytes
}
// Extract IP and User-Agent from request
ipAddress := meta.r.RemoteAddr
userAgent := meta.r.UserAgent()
log := &AuditLog{
UserID: meta.u.ID,
Action: info.Action,
ResourceType: info.ResourceType,
ResourceID: resourceIDStr,
Details: detailsJSON,
IPAddress: ipAddress,
UserAgent: userAgent,
Result: result,
ErrorMessage: errorMessage,
CreatedAt: time.Now().Unix(),
}
return CreateAuditLog(ctx, tx, log)
}

132
internal/db/backup.go Normal file
View File

@@ -0,0 +1,132 @@
package db
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"sort"
"time"
"github.com/pkg/errors"
)
// CreateBackup creates a compressed PostgreSQL dump before migrations
// Returns backup filename and error
// If pg_dump is not available, returns nil error with warning
func CreateBackup(ctx context.Context, cfg *Config, operation string) (string, error) {
// Check if pg_dump is available
if _, err := exec.LookPath("pg_dump"); err != nil {
fmt.Println("[WARN] pg_dump not found - skipping backup")
fmt.Println("[WARN] Install PostgreSQL client tools for automatic backups:")
fmt.Println("[WARN] Ubuntu/Debian: sudo apt-get install postgresql-client")
fmt.Println("[WARN] macOS: brew install postgresql")
fmt.Println("[WARN] Arch: sudo pacman -S postgresql-libs")
return "", nil // Don't fail, just warn
}
// Ensure backup directory exists
if err := os.MkdirAll(cfg.BackupDir, 0o755); err != nil {
return "", errors.Wrap(err, "failed to create backup directory")
}
// Generate filename: YYYYMMDD_HHmmss_pre_{operation}.sql.gz
timestamp := time.Now().Format("20060102_150405")
filename := filepath.Join(cfg.BackupDir,
fmt.Sprintf("%s_pre_%s.sql.gz", timestamp, operation))
// Check if gzip is available
useGzip := true
if _, err := exec.LookPath("gzip"); err != nil {
fmt.Println("[WARN] gzip not found - using uncompressed backup")
useGzip = false
filename = filepath.Join(cfg.BackupDir,
fmt.Sprintf("%s_pre_%s.sql", timestamp, operation))
}
// Build pg_dump command
var cmd *exec.Cmd
if useGzip {
// Use shell to pipe pg_dump through gzip
pgDumpCmd := fmt.Sprintf(
"pg_dump -h %s -p %d -U %s -d %s --no-owner --no-acl --clean --if-exists | gzip > %s",
cfg.Host,
cfg.Port,
cfg.User,
cfg.DB,
filename,
)
cmd = exec.CommandContext(ctx, "sh", "-c", pgDumpCmd)
} else {
cmd = exec.CommandContext(ctx, "pg_dump",
"-h", cfg.Host,
"-p", fmt.Sprint(cfg.Port),
"-U", cfg.User,
"-d", cfg.DB,
"-f", filename,
"--no-owner",
"--no-acl",
"--clean",
"--if-exists",
)
}
// Set password via environment variable
cmd.Env = append(os.Environ(),
fmt.Sprintf("PGPASSWORD=%s", cfg.Password))
// Run backup
if err := cmd.Run(); err != nil {
return "", errors.Wrap(err, "pg_dump failed")
}
// Get file size for logging
info, err := os.Stat(filename)
if err != nil {
return filename, errors.Wrap(err, "stat backup file")
}
sizeMB := float64(info.Size()) / 1024 / 1024
fmt.Printf("[INFO] Backup created: %s (%.2f MB)\n", filename, sizeMB)
return filename, nil
}
// CleanOldBackups keeps only the N most recent backups
func CleanOldBackups(cfg *Config, keepCount int) error {
// Get all backup files (both .sql and .sql.gz)
sqlFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql"))
if err != nil {
return errors.Wrap(err, "failed to list .sql backups")
}
gzFiles, err := filepath.Glob(filepath.Join(cfg.BackupDir, "*.sql.gz"))
if err != nil {
return errors.Wrap(err, "failed to list .sql.gz backups")
}
files := append(sqlFiles, gzFiles...)
if len(files) <= keepCount {
return nil // Nothing to clean
}
// Sort files by modification time (newest first)
sort.Slice(files, func(i, j int) bool {
iInfo, _ := os.Stat(files[i])
jInfo, _ := os.Stat(files[j])
return iInfo.ModTime().After(jInfo.ModTime())
})
// Delete old backups
for i := keepCount; i < len(files); i++ {
if err := os.Remove(files[i]); err != nil {
fmt.Printf("[WARN] Failed to remove old backup %s: %v\n", files[i], err)
} else {
fmt.Printf("[INFO] Removed old backup: %s\n", filepath.Base(files[i]))
}
}
return nil
}

View File

@@ -12,16 +12,22 @@ type Config struct {
Port uint16 // ENV DB_PORT: Database port (default: 5432) Port uint16 // ENV DB_PORT: Database port (default: 5432)
DB string // ENV DB_NAME: Database name to connect to (required) DB string // ENV DB_NAME: Database name to connect to (required)
SSL string // ENV DB_SSL: SSL mode for connection (default: disable) SSL string // ENV DB_SSL: SSL mode for connection (default: disable)
// Backup configuration
BackupDir string // ENV DB_BACKUP_DIR: Directory for database backups (default: backups)
BackupRetention int // ENV DB_BACKUP_RETENTION: Number of backups to keep (default: 10)
} }
func ConfigFromEnv() (any, error) { func ConfigFromEnv() (any, error) {
cfg := &Config{ cfg := &Config{
User: env.String("DB_USER", ""), User: env.String("DB_USER", ""),
Password: env.String("DB_PASSWORD", ""), Password: env.String("DB_PASSWORD", ""),
Host: env.String("DB_HOST", ""), Host: env.String("DB_HOST", ""),
Port: env.UInt16("DB_PORT", 5432), Port: env.UInt16("DB_PORT", 5432),
DB: env.String("DB_NAME", ""), DB: env.String("DB_NAME", ""),
SSL: env.String("DB_SSL", "disable"), SSL: env.String("DB_SSL", "disable"),
BackupDir: env.String("DB_BACKUP_DIR", "backups"),
BackupRetention: env.Int("DB_BACKUP_RETENTION", 10),
} }
// Validate SSL mode // Validate SSL mode
@@ -50,6 +56,9 @@ func ConfigFromEnv() (any, error) {
if cfg.DB == "" { if cfg.DB == "" {
return nil, errors.New("Envar not set: DB_NAME") return nil, errors.New("Envar not set: DB_NAME")
} }
if cfg.BackupRetention < 1 {
return nil, errors.New("DB_BACKUP_RETENTION must be at least 1")
}
return cfg, nil return cfg, nil
} }

102
internal/db/delete.go Normal file
View File

@@ -0,0 +1,102 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type deleter[T any] struct {
tx bun.Tx
q *bun.DeleteQuery
resourceID any // Store ID before deletion for audit
audit *AuditMeta
auditInfo *AuditInfo
}
type systemType interface {
isSystem() bool
}
func DeleteItem[T any](tx bun.Tx) *deleter[T] {
return &deleter[T]{
tx: tx,
q: tx.NewDelete().
Model((*T)(nil)),
}
}
func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
d.q = d.q.Where(query, args...)
// Try to capture resource ID from WHERE clause if it's a simple "id = ?" pattern
if query == "id = ?" && len(args) > 0 {
d.resourceID = args[0]
}
return d
}
// WithAudit enables audit logging for this delete operation
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
d.audit = meta
d.auditInfo = info
return d
}
func (d *deleter[T]) Delete(ctx context.Context) error {
result, err := d.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.DeleteQuery.Exec")
}
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", d.resourceID)
}
// Handle audit logging if enabled
if d.audit != nil {
if d.auditInfo == nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "delete")
d.auditInfo = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: d.resourceID,
Details: nil, // Delete doesn't need details
}
}
err = LogSuccess(ctx, d.tx, d.audit, d.auditInfo)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
}
return nil
}
func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
return DeleteItem[T](tx).Where("id = ?", id)
}
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
deleter := DeleteByID[T](tx, id)
item, err := GetByID[T](tx, id).Get(ctx)
if err != nil {
return errors.Wrap(err, "GetByID")
}
if (*item).isSystem() {
return errors.New("record is system protected")
}
if audit != nil {
deleter = deleter.WithAudit(audit, nil)
}
return deleter.Delete(ctx)
}

View File

@@ -22,14 +22,14 @@ type DiscordToken struct {
// UpdateDiscordToken adds the provided discord token to the database. // UpdateDiscordToken adds the provided discord token to the database.
// If the user already has a token stored, it will replace that token instead. // If the user already has a token stored, it will replace that token instead.
func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
if token == nil { if token == nil {
return errors.New("token cannot be nil") return errors.New("token cannot be nil")
} }
expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
discordToken := &DiscordToken{ discordToken := &DiscordToken{
DiscordID: user.DiscordID, DiscordID: u.DiscordID,
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
@@ -37,30 +37,28 @@ func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *disc
TokenType: token.TokenType, TokenType: token.TokenType,
} }
_, err := tx.NewInsert(). err := Insert(tx, discordToken).
Model(discordToken). ConflictUpdate([]string{"discord_id"}, "access_token", "refresh_token", "expires_at").
On("CONFLICT (discord_id) DO UPDATE").
Set("access_token = EXCLUDED.access_token").
Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil
} }
// DeleteDiscordTokens deletes a users discord OAuth tokens from the database. // DeleteDiscordTokens deletes a users discord OAuth tokens from the database.
// It returns the DiscordToken so that it can be revoked via the discord API // It returns the DiscordToken so that it can be revoked via the discord API
func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := user.GetDiscordToken(ctx, tx) token, err := u.GetDiscordToken(ctx, tx)
if err != nil { if err != nil {
if IsBadRequest(err) {
return nil, nil // Token doesn't exist - not an error
}
return nil, errors.Wrap(err, "user.GetDiscordToken") return nil, errors.Wrap(err, "user.GetDiscordToken")
} }
_, err = tx.NewDelete(). _, err = tx.NewDelete().
Model((*DiscordToken)(nil)). Model((*DiscordToken)(nil)).
Where("discord_id = ?", user.DiscordID). Where("discord_id = ?", u.DiscordID).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewDelete") return nil, errors.Wrap(err, "tx.NewDelete")
@@ -69,25 +67,18 @@ func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordT
} }
// GetDiscordToken retrieves the users discord token from the database // GetDiscordToken retrieves the users discord token from the database
func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token := new(DiscordToken) return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).Get(ctx)
err := tx.NewSelect().
Model(token).
Where("discord_id = ?", user.DiscordID).
Limit(1).
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return token, nil
} }
// Convert reverts the token back into a *discord.Token // Convert reverts the token back into a *discord.Token
func (t *DiscordToken) Convert() *discord.Token { func (t *DiscordToken) Convert() *discord.Token {
expiresIn := t.ExpiresAt - time.Now().Unix()
expiresIn = max(expiresIn, 0)
token := &discord.Token{ token := &discord.Token{
AccessToken: t.AccessToken, AccessToken: t.AccessToken,
RefreshToken: t.RefreshToken, RefreshToken: t.RefreshToken,
ExpiresIn: int(t.ExpiresAt - time.Now().Unix()), ExpiresIn: int(expiresIn),
Scope: t.Scope, Scope: t.Scope,
TokenType: t.TokenType, TokenType: t.TokenType,
} }

2
internal/db/doc.go Normal file
View File

@@ -0,0 +1,2 @@
// Package db is an internal package for all the database models and related methods
package db

31
internal/db/errors.go Normal file
View File

@@ -0,0 +1,31 @@
package db
import (
"fmt"
"strings"
)
func IsBadRequest(err error) bool {
return strings.Contains(err.Error(), "bad request:")
}
func BadRequest(err string) error {
return fmt.Errorf("bad request: %s", err)
}
func BadRequestNotFound(resource, field string, value any) error {
errStr := fmt.Sprintf("%s with %s=%v not found", resource, field, value)
return BadRequest(errStr)
}
func BadRequestNotAssociated(parent, child, parentField, childField string, parentID, childID any) error {
errStr := fmt.Sprintf("%s with %s=%v not associated to %s with %s=%v",
child, childField, childID, parent, parentField, parentID)
return BadRequest(errStr)
}
func BadRequestAssociated(parent, child, parentField, childField string, parentID, childID any) error {
errStr := fmt.Sprintf("%s with %s=%v already associated to %s with %s=%v",
child, childField, childID, parent, parentField, parentID)
return BadRequest(errStr)
}

282
internal/db/fixture.go Normal file
View File

@@ -0,0 +1,282 @@
package db
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Fixture struct {
bun.BaseModel `bun:"table:fixtures,alias:f"`
ID int `bun:"id,pk,autoincrement"`
SeasonID int `bun:",notnull,unique:round"`
LeagueID int `bun:",notnull,unique:round"`
HomeTeamID int `bun:",notnull,unique:round"`
AwayTeamID int `bun:",notnull,unique:round"`
Round int `bun:"round,unique:round"`
GameWeek *int `bun:"game_week"`
CreatedAt int64 `bun:"created_at,notnull"`
UpdatedAt *int64 `bun:"updated_at"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
HomeTeam *Team `bun:"rel:belongs-to,join:home_team_id=id"`
AwayTeam *Team `bun:"rel:belongs-to,join:away_team_id=id"`
}
func NewFixture(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string,
homeTeamID, awayTeamID, round int, audit *AuditMeta,
) (*Fixture, error) {
season, league, teams, err := GetSeasonLeague(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
return nil, errors.Wrap(err, "GetSeasonLeague")
}
homeTeam, err := GetTeam(ctx, tx, homeTeamID)
if err != nil {
return nil, errors.Wrap(err, "GetTeam")
}
awayTeam, err := GetTeam(ctx, tx, awayTeamID)
if err != nil {
return nil, errors.Wrap(err, "GetTeam")
}
if err = checkTeamsAssociated(season, league, teams, []*Team{homeTeam, awayTeam}); err != nil {
return nil, errors.Wrap(err, "checkTeamsAssociated")
}
fixture := newFixture(season, league, homeTeam, awayTeam, round, time.Now())
err = Insert(tx, fixture).WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "Insert")
}
return fixture, nil
}
func NewRound(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string,
round int, audit *AuditMeta,
) ([]*Fixture, error) {
season, league, teams, err := GetSeasonLeague(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
return nil, errors.Wrap(err, "GetSeasonLeague")
}
fixtures := generateRound(season, league, round, teams)
err = InsertMultiple(tx, fixtures).WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "InsertMultiple")
}
return fixtures, nil
}
func GetFixtures(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Fixture, error) {
season, league, _, err := GetSeasonLeague(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeasonLeague")
}
fixtures, err := GetList[Fixture](tx).
Where("season_id = ?", season.ID).
Where("league_id = ?", league.ID).
Order("game_week ASC NULLS FIRST", "round ASC", "id ASC").
Relation("HomeTeam").
Relation("AwayTeam").
GetAll(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetList")
}
return season, league, fixtures, nil
}
func GetFixture(ctx context.Context, tx bun.Tx, id int) (*Fixture, error) {
return GetByID[Fixture](tx, id).
Relation("Season").
Relation("League").
Relation("HomeTeam").
Relation("AwayTeam").
Get(ctx)
}
func GetFixturesByGameWeek(ctx context.Context, tx bun.Tx, seasonID, leagueID, gameweek int) ([]*Fixture, error) {
fixtures, err := GetList[Fixture](tx).
Where("season_id = ?", seasonID).
Where("league_id = ?", leagueID).
Where("game_week = ?", gameweek).
Order("round ASC", "id ASC").
Relation("HomeTeam").
Relation("AwayTeam").
GetAll(ctx)
if err != nil {
return nil, errors.Wrap(err, "GetList")
}
return fixtures, nil
}
func GetUnallocatedFixtures(ctx context.Context, tx bun.Tx, seasonID, leagueID int) ([]*Fixture, error) {
fixtures, err := GetList[Fixture](tx).
Where("season_id = ?", seasonID).
Where("league_id = ?", leagueID).
Where("game_week IS NULL").
Order("round ASC", "id ASC").
Relation("HomeTeam").
Relation("AwayTeam").
GetAll(ctx)
if err != nil {
return nil, errors.Wrap(err, "GetList")
}
return fixtures, nil
}
func CountUnallocatedFixtures(ctx context.Context, tx bun.Tx, seasonID, leagueID int) (int, error) {
count, err := GetList[Fixture](tx).
Where("season_id = ?", seasonID).
Where("league_id = ?", leagueID).
Where("game_week IS NULL").
Count(ctx)
if err != nil {
return 0, errors.Wrap(err, "GetList")
}
return count, nil
}
func GetMaxGameWeek(ctx context.Context, tx bun.Tx, seasonID, leagueID int) (int, error) {
var maxGameWeek int
err := tx.NewSelect().
Model((*Fixture)(nil)).
Column("game_week").
Where("season_id = ?", seasonID).
Where("league_id = ?", leagueID).
Order("game_week DESC NULLS LAST").
Limit(1).Scan(ctx, &maxGameWeek)
if err != nil {
return 0, errors.Wrap(err, "tx.NewSelect")
}
return maxGameWeek, nil
}
func UpdateFixtureGameWeeks(ctx context.Context, tx bun.Tx, fixtures []*Fixture, audit *AuditMeta) error {
details := []any{}
for _, fixture := range fixtures {
err := UpdateByID(tx, fixture.ID, fixture).
Column("game_week").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "UpdateByID")
}
details = append(details, map[string]any{"fixture_id": fixture.ID, "game_week": fixture.GameWeek})
}
info := &AuditInfo{
"fixtures.manage",
"fixture",
"multiple",
map[string]any{"updated": details},
}
err := LogSuccess(ctx, tx, audit, info)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
return nil
}
func DeleteAllFixtures(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string, audit *AuditMeta) error {
season, league, _, err := GetSeasonLeague(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
return errors.Wrap(err, "GetSeasonLeague")
}
err = DeleteItem[Fixture](tx).
Where("season_id = ?", season.ID).
Where("league_id = ?", league.ID).
WithAudit(audit, nil).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteItem")
}
return nil
}
func DeleteFixture(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
err := DeleteByID[Fixture](tx, id).
WithAudit(audit, nil).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteByID")
}
return nil
}
func newFixture(season *Season, league *League, homeTeam, awayTeam *Team, round int, created time.Time) *Fixture {
return &Fixture{
SeasonID: season.ID,
LeagueID: league.ID,
HomeTeamID: homeTeam.ID,
AwayTeamID: awayTeam.ID,
Round: round,
CreatedAt: created.Unix(),
}
}
func checkTeamsAssociated(season *Season, league *League, teamsIn []*Team, toCheck []*Team) error {
badIDs := []string{}
master := map[int]bool{}
for _, team := range teamsIn {
master[team.ID] = true
}
for _, team := range toCheck {
if !master[team.ID] {
badIDs = append(badIDs, strconv.Itoa(team.ID))
}
}
ids := strings.Join(badIDs, ",")
if len(ids) > 0 {
return BadRequestNotAssociated("season_league", "team",
"season_id,league_id", "ids",
fmt.Sprintf("%v,%v", season.ID, league.ID),
ids)
}
return nil
}
type versus struct {
homeTeam *Team
awayTeam *Team
}
func generateRound(season *Season, league *League, round int, teams []*Team) []*Fixture {
now := time.Now()
numTeams := len(teams)
numGames := numTeams * (numTeams - 1) / 2
fixtures := make([]*Fixture, numGames)
for i, matchup := range allTeamsPlay(teams, round) {
fixtures[i] = newFixture(season, league, matchup.homeTeam, matchup.awayTeam, round, now)
}
return fixtures
}
func allTeamsPlay(teams []*Team, round int) []*versus {
matchups := []*versus{}
if len(teams) < 2 {
return matchups
}
team1 := teams[0]
teams = teams[1:]
matchups = append(matchups, playOtherTeams(team1, teams, round)...)
matchups = append(matchups, allTeamsPlay(teams, round)...)
return matchups
}
func playOtherTeams(team *Team, teams []*Team, round int) []*versus {
matchups := make([]*versus, len(teams))
for i, opponent := range teams {
versus := &versus{}
if i%2+round%2 == 0 {
versus.homeTeam = team
versus.awayTeam = opponent
} else {
versus.homeTeam = opponent
versus.awayTeam = team
}
matchups[i] = versus
}
return matchups
}

70
internal/db/getbyfield.go Normal file
View File

@@ -0,0 +1,70 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type fieldgetter[T any] struct {
q *bun.SelectQuery
field string
value any
model *T
}
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
if g.field == "id" && (g.value).(int) < 1 {
return nil, errors.New("invalid id")
}
err := g.q.
Where("? = ?", bun.Ident(g.field), g.value).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
resource := extractResourceType(extractTableName[T]())
return nil, BadRequestNotFound(resource, g.field, g.value)
}
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
}
return g.model, nil
}
func (g *fieldgetter[T]) Get(ctx context.Context) (*T, error) {
g.q = g.q.Limit(1)
return g.get(ctx)
}
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
g.q = g.q.Relation(name, apply...)
return g
}
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
g.q = g.q.Join(join, args...)
return g
}
// GetByField retrieves a single record by field name
func GetByField[T any](
tx bun.Tx,
field string,
value any,
) *fieldgetter[T] {
model := new(T)
return &fieldgetter[T]{
tx.NewSelect().Model(model),
field,
value,
model,
}
}
func GetByID[T any](
tx bun.Tx,
id int,
) *fieldgetter[T] {
return GetByField[T](tx, "id", id)
}

152
internal/db/getlist.go Normal file
View File

@@ -0,0 +1,152 @@
package db
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type listgetter[T any] struct {
q *bun.SelectQuery
items *[]*T
}
type List[T any] struct {
Items []*T
Total int
PageOpts PageOpts
}
type Filter struct {
Field string
Value any
Comparator Comparator
}
type Comparator string
const (
Equal Comparator = "="
Less Comparator = "<"
LessEqual Comparator = "<="
Greater Comparator = ">"
GreaterEqual Comparator = ">="
In Comparator = "IN"
)
type ListFilter struct {
filters []Filter
}
func NewListFilter() *ListFilter {
return &ListFilter{[]Filter{}}
}
func (f *ListFilter) Equals(field string, value any) {
f.filters = append(f.filters, Filter{field, value, Equal})
}
func (f *ListFilter) LessThan(field string, value any) {
f.filters = append(f.filters, Filter{field, value, Less})
}
func (f *ListFilter) LessEqualThan(field string, value any) {
f.filters = append(f.filters, Filter{field, value, LessEqual})
}
func (f *ListFilter) GreaterThan(field string, value any) {
f.filters = append(f.filters, Filter{field, value, Greater})
}
func (f *ListFilter) GreaterEqualThan(field string, value any) {
f.filters = append(f.filters, Filter{field, value, GreaterEqual})
}
func (f *ListFilter) In(field string, values any) {
f.filters = append(f.filters, Filter{field, values, In})
}
func GetList[T any](tx bun.Tx) *listgetter[T] {
l := &listgetter[T]{
items: new([]*T),
}
l.q = tx.NewSelect().
Model(l.items)
return l
}
func (l *listgetter[T]) String() string {
return l.q.String()
}
func (l *listgetter[T]) Join(join string, args ...any) *listgetter[T] {
l.q = l.q.Join(join, args...)
return l
}
func (l *listgetter[T]) Where(query string, args ...any) *listgetter[T] {
l.q = l.q.Where(query, args...)
return l
}
func (l *listgetter[T]) Order(orders ...string) *listgetter[T] {
l.q = l.q.Order(orders...)
return l
}
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
l.q = l.q.Relation(name, apply...)
return l
}
func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
for _, filter := range filters {
if filter.Comparator == In {
l.q = l.q.Where("? IN (?)", bun.Ident(filter.Field), bun.In(filter.Value))
} else {
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
}
}
fmt.Println(l.q.String())
return l
}
func (l *listgetter[T]) GetPaged(ctx context.Context, pageOpts, defaults *PageOpts) (*List[T], error) {
if defaults == nil {
return nil, errors.New("default pageopts is nil")
}
total, err := l.q.Count(ctx)
if err != nil {
return nil, errors.Wrap(err, "query.Count")
}
l.q, pageOpts = setPageOpts(l.q, pageOpts, defaults, total)
err = l.q.Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "query.Scan")
}
list := &List[T]{
Items: *l.items,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
}
func (l *listgetter[T]) Count(ctx context.Context) (int, error) {
count, err := l.q.Count(ctx)
if err != nil {
return 0, errors.Wrap(err, "query.Count")
}
return count, nil
}
func (l *listgetter[T]) GetAll(ctx context.Context) ([]*T, error) {
err := l.q.Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "query.Scan")
}
return *l.items, nil
}

123
internal/db/insert.go Normal file
View File

@@ -0,0 +1,123 @@
package db
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type inserter[T any] struct {
tx bun.Tx
q *bun.InsertQuery
model *T
models []*T
isBulk bool
audit *AuditMeta
auditInfo *AuditInfo
}
// Insert creates an inserter for a single model
// The model will have all fields populated after Exec() via Returning("*")
func Insert[T any](tx bun.Tx, model *T) *inserter[T] {
if model == nil {
panic("model cannot be nil")
}
return &inserter[T]{
tx: tx,
q: tx.NewInsert().Model(model).Returning("*"),
model: model,
isBulk: false,
}
}
// InsertMultiple creates an inserter for bulk insert
// All models will have fields populated after Exec() via Returning("*")
func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
if len(models) == 0 {
panic("models cannot be nil or empty")
}
return &inserter[T]{
tx: tx,
q: tx.NewInsert().Model(&models).Returning("*"),
models: models,
isBulk: true,
}
}
func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] {
fieldstr := strings.Join(conflicts, ", ")
i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", fieldstr))
return i
}
func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] {
fieldstr := strings.Join(conflicts, ", ")
i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", fieldstr))
for _, column := range columns {
i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", column, column))
}
return i
}
// Returning overrides the default Returning("*") clause
// Example: .Returning("id", "created_at")
func (i *inserter[T]) Returning(columns ...string) *inserter[T] {
if len(columns) == 0 {
return i
}
// Build column list as single string
columnList := strings.Join(columns, ", ")
i.q = i.q.Returning(columnList)
return i
}
// WithAudit enables audit logging for this insert operation
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
func (i *inserter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *inserter[T] {
i.audit = meta
i.auditInfo = info
return i
}
// Exec executes the insert and optionally logs to audit
// Returns an error if insert fails or if audit callback fails (triggering rollback)
func (i *inserter[T]) Exec(ctx context.Context) error {
// Execute insert
_, err := i.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.InsertQuery.Exec")
}
// Handle audit logging if enabled
if i.audit != nil {
if i.auditInfo == nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "create")
i.auditInfo = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: nil,
Details: nil,
}
if i.isBulk {
i.auditInfo.Details = map[string]any{
"count": len(i.models),
}
} else {
i.auditInfo.ResourceID = extractPrimaryKey(i.model)
i.auditInfo.Details = i.model
}
}
err = LogSuccess(ctx, i.tx, i.audit, i.auditInfo)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
}
return nil
}

19
internal/db/isunique.go Normal file
View File

@@ -0,0 +1,19 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func IsUnique(ctx context.Context, tx bun.Tx, model any, field, value string) (bool, error) {
count, err := tx.NewSelect().
Model(model).
Where("? = ?", bun.Ident(field), value).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count == 0, nil
}

45
internal/db/league.go Normal file
View File

@@ -0,0 +1,45 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type League struct {
bun.BaseModel `bun:"table:leagues,alias:l"`
ID int `bun:"id,pk,autoincrement" json:"id"`
Name string `bun:"name,unique,notnull" json:"name"`
ShortName string `bun:"short_name,unique,notnull" json:"short_name"`
Description string `bun:"description" json:"description"`
Seasons []Season `bun:"m2m:season_leagues,join:League=Season" json:"-"`
Teams []Team `bun:"m2m:team_participations,join:League=Team" json:"-"`
}
func GetLeagues(ctx context.Context, tx bun.Tx) ([]*League, error) {
return GetList[League](tx).Relation("Seasons").GetAll(ctx)
}
func GetLeague(ctx context.Context, tx bun.Tx, shortname string) (*League, error) {
if shortname == "" {
return nil, errors.New("shortname cannot be empty")
}
return GetByField[League](tx, "short_name", shortname).Relation("Seasons").Get(ctx)
}
func NewLeague(ctx context.Context, tx bun.Tx, name, shortname, description string, audit *AuditMeta) (*League, error) {
league := &League{
Name: name,
ShortName: shortname,
Description: description,
}
err := Insert(tx, league).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "db.Insert")
}
return league, nil
}

View File

@@ -0,0 +1,511 @@
// Package migrate provides functions for managing database migrations
package migrate
import (
"bufio"
"context"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"text/tabwriter"
"time"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/db/migrations"
"github.com/pkg/errors"
"github.com/uptrace/bun/migrate"
)
// RunMigrations executes database migrations
func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error {
conn := db.NewDB(cfg.DB)
defer func() { _ = conn.Close() }()
migrator := migrate.NewMigrator(conn.DB, migrations.Migrations)
// Initialize migration tables
if err := migrator.Init(ctx); err != nil {
return errors.Wrap(err, "migrator.Init")
}
switch command {
case "up":
err := migrateUp(ctx, migrator, conn, cfg, countStr)
if err != nil {
// On error, automatically rollback the migrations that were just applied
fmt.Println("[WARN] Migration failed, attempting automatic rollback...")
// We need to figure out how many migrations were applied in this batch
// For now, we'll skip automatic rollback since it's complex with the new count system
// The user can manually rollback if needed
return err
}
return err
case "rollback":
return migrateRollback(ctx, migrator, conn, cfg, countStr)
case "status":
return migrateStatus(ctx, migrator)
default:
return fmt.Errorf("unknown migration command: %s", command)
}
}
// migrateUp runs pending migrations
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
// Parse count parameter
count, all, err := parseMigrationCount(countStr)
if err != nil {
return errors.Wrap(err, "parse migration count")
}
fmt.Println("[INFO] Step 1/5: Validating migrations...")
if err := validateMigrations(ctx); err != nil {
return err
}
fmt.Println("[INFO] Migration validation passed ✓")
fmt.Println("[INFO] Step 2/5: Checking for pending migrations...")
// Check for pending migrations using MigrationsWithStatus (read-only)
ms, err := migrator.MigrationsWithStatus(ctx)
if err != nil {
return errors.Wrap(err, "get migration status")
}
unapplied := ms.Unapplied()
if len(unapplied) == 0 {
fmt.Println("[INFO] No pending migrations")
return nil
}
// Select which migrations to apply
toApply := selectMigrationsToApply(unapplied, count, all)
if len(toApply) == 0 {
fmt.Println("[INFO] No migrations to run")
return nil
}
// Print what we're about to do
if all {
fmt.Printf("[INFO] Running all %d pending migration(s):\n", len(toApply))
} else {
fmt.Printf("[INFO] Running %d migration(s):\n", len(toApply))
}
for _, m := range toApply {
fmt.Printf(" 📋 %s\n", m.Name)
}
// Create backup unless --no-backup flag is set
if !cfg.Flags.MigrateNoBackup {
fmt.Println("[INFO] Step 3/5: Creating backup...")
_, err := db.CreateBackup(ctx, cfg.DB, "migration")
if err != nil {
return errors.Wrap(err, "create backup")
}
// Clean old backups
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
}
} else {
fmt.Println("[INFO] Step 3/5: Skipping backup (--no-backup flag set)")
}
// Acquire migration lock
fmt.Println("[INFO] Step 4/5: Acquiring migration lock...")
if err := acquireMigrationLock(ctx, conn); err != nil {
return errors.Wrap(err, "acquire migration lock")
}
defer releaseMigrationLock(ctx, conn)
fmt.Println("[INFO] Migration lock acquired")
// Run migrations
fmt.Println("[INFO] Step 5/5: Applying migrations...")
group, err := executeUpMigrations(ctx, migrator, toApply)
if err != nil {
return errors.Wrap(err, "execute migrations")
}
if group.IsZero() {
fmt.Println("[INFO] No migrations to run")
return nil
}
fmt.Printf("[INFO] Migrated to group %d\n", group.ID)
for _, migration := range group.Migrations {
fmt.Printf(" ✅ %s\n", migration.Name)
}
return nil
}
// migrateRollback rolls back migrations
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
// Parse count parameter
count, all, err := parseMigrationCount(countStr)
if err != nil {
return errors.Wrap(err, "parse migration count")
}
// Get all migrations with status
ms, err := migrator.MigrationsWithStatus(ctx)
if err != nil {
return errors.Wrap(err, "get migration status")
}
applied := ms.Applied()
if len(applied) == 0 {
fmt.Println("[INFO] No migrations to rollback")
return nil
}
// Select which migrations to rollback
toRollback := selectMigrationsToRollback(applied, count, all)
if len(toRollback) == 0 {
fmt.Println("[INFO] No migrations to rollback")
return nil
}
// Print what we're about to do
if all {
fmt.Printf("[INFO] Rolling back all %d migration(s):\n", len(toRollback))
} else {
fmt.Printf("[INFO] Rolling back %d migration(s):\n", len(toRollback))
}
for _, m := range toRollback {
fmt.Printf(" 📋 %s (group %d)\n", m.Name, m.GroupID)
}
// Create backup unless --no-backup flag is set
if !cfg.Flags.MigrateNoBackup {
fmt.Println("[INFO] Creating backup before rollback...")
_, err := db.CreateBackup(ctx, cfg.DB, "rollback")
if err != nil {
return errors.Wrap(err, "create backup")
}
// Clean old backups
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
}
} else {
fmt.Println("[INFO] Skipping backup (--no-backup flag set)")
}
// Acquire migration lock
fmt.Println("[INFO] Acquiring migration lock...")
if err := acquireMigrationLock(ctx, conn); err != nil {
return errors.Wrap(err, "acquire migration lock")
}
defer releaseMigrationLock(ctx, conn)
fmt.Println("[INFO] Migration lock acquired")
// Rollback
fmt.Println("[INFO] Executing rollback...")
rolledBack, err := executeDownMigrations(ctx, migrator, toRollback)
if err != nil {
return errors.Wrap(err, "execute rollback")
}
fmt.Printf("[INFO] Successfully rolled back %d migration(s)\n", len(rolledBack))
for _, migration := range rolledBack {
fmt.Printf(" ↩️ %s\n", migration.Name)
}
return nil
}
// migrateStatus shows migration status
func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
ms, err := migrator.MigrationsWithStatus(ctx)
if err != nil {
return errors.Wrap(err, "get migration status")
}
fmt.Println("╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ DATABASE MIGRATION STATUS ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', 0)
_, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tCOMMENT")
_, _ = fmt.Fprintln(w, "----------\t---------------\t-----\t---------------------------")
appliedCount := 0
for _, m := range ms {
status := "⏳ Pending"
group := "-"
if m.GroupID > 0 {
status = "✅ Applied"
appliedCount++
group = fmt.Sprint(m.GroupID)
}
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, m.Comment)
}
_ = w.Flush()
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
appliedCount, len(ms)-appliedCount)
return nil
}
// validateMigrations ensures migrations compile before running
func validateMigrations(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "go", "build",
"-o", "/dev/null", "./internal/db/migrations")
output, err := cmd.CombinedOutput()
if err != nil {
fmt.Println("[ERROR] Migration validation failed!")
fmt.Println(string(output))
return errors.Wrap(err, "migration build failed")
}
return nil
}
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
func acquireMigrationLock(ctx context.Context, conn *db.DB) error {
const lockID = 1234567890 // Arbitrary unique ID for migration lock
const timeoutSeconds = 300 // 5 minutes
// Set statement timeout for this session
_, err := conn.ExecContext(ctx,
fmt.Sprintf("SET statement_timeout = '%ds'", timeoutSeconds))
if err != nil {
return errors.Wrap(err, "set timeout")
}
var acquired bool
err = conn.NewRaw("SELECT pg_try_advisory_lock(?)", lockID).
Scan(ctx, &acquired)
if err != nil {
return errors.Wrap(err, "pg_try_advisory_lock")
}
if !acquired {
return errors.New("migration already in progress (could not acquire lock)")
}
return nil
}
// releaseMigrationLock releases the migration lock
func releaseMigrationLock(ctx context.Context, conn *db.DB) {
const lockID = 1234567890
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
if err != nil {
fmt.Printf("[WARN] Failed to release migration lock: %v\n", err)
} else {
fmt.Println("[INFO] Migration lock released")
}
}
// CreateMigration generates a new migration file
func CreateMigration(name string) error {
if name == "" {
return errors.New("migration name cannot be empty")
}
// Sanitize name (replace spaces with underscores, lowercase)
name = strings.ToLower(strings.ReplaceAll(name, " ", "_"))
// Generate timestamp
timestamp := time.Now().Format("20060102150405")
filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name)
// Template
template := `package migrations
import (
"context"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here
return nil
},
)
}
`
// Write file
if err := os.WriteFile(filename, []byte(template), 0o644); err != nil {
return errors.Wrap(err, "write migration file")
}
fmt.Printf("✅ Created migration: %s\n", filename)
fmt.Println("📝 Next steps:")
fmt.Println(" 1. Edit the file and implement the UP and DOWN functions")
fmt.Println(" 2. Run: just migrate up")
return nil
}
// parseMigrationCount parses a migration count string
// Returns: (count, all, error)
// - "" (empty) → (1, false, nil) - default to 1
// - "all" → (0, true, nil) - special case for all
// - "5" → (5, false, nil) - specific count
// - "invalid" → (0, false, error)
func parseMigrationCount(value string) (int, bool, error) {
// Default to 1 if empty
if value == "" {
return 1, false, nil
}
// Special case for "all"
if value == "all" {
return 0, true, nil
}
// Parse as integer
count, err := strconv.Atoi(value)
if err != nil {
return 0, false, errors.New("migration count must be a positive integer or 'all'")
}
if count < 1 {
return 0, false, errors.New("migration count must be a positive integer (1 or greater)")
}
return count, false, nil
}
// selectMigrationsToApply returns the subset of unapplied migrations to run
func selectMigrationsToApply(unapplied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice {
if all {
return unapplied
}
count = min(count, len(unapplied))
return unapplied[:count]
}
// selectMigrationsToRollback returns the subset of applied migrations to rollback
// Returns migrations in reverse chronological order (most recent first)
func selectMigrationsToRollback(applied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice {
if len(applied) == 0 || all {
return applied
}
count = min(count, len(applied))
return applied[:count]
}
// executeUpMigrations executes a subset of UP migrations
func executeUpMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (*migrate.MigrationGroup, error) {
if len(migrations) == 0 {
return &migrate.MigrationGroup{}, nil
}
// Get the next group ID
ms, err := migrator.MigrationsWithStatus(ctx)
if err != nil {
return nil, errors.Wrap(err, "get migration status")
}
lastGroup := ms.LastGroup()
groupID := int64(1)
if lastGroup.ID > 0 {
groupID = lastGroup.ID + 1
}
// Create the migration group
group := &migrate.MigrationGroup{
ID: groupID,
Migrations: make(migrate.MigrationSlice, 0, len(migrations)),
}
// Execute each migration
for i := range migrations {
migration := &migrations[i]
migration.GroupID = groupID
// Mark as applied before execution (Bun's default behavior)
if err := migrator.MarkApplied(ctx, migration); err != nil {
return group, errors.Wrap(err, "mark applied")
}
// Add to group
group.Migrations = append(group.Migrations, *migration)
// Execute the UP function
if migration.Up != nil {
if err := migration.Up(ctx, migrator, migration); err != nil {
return group, errors.Wrap(err, fmt.Sprintf("migration %s failed", migration.Name))
}
}
}
return group, nil
}
// executeDownMigrations executes a subset of DOWN migrations
func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (migrate.MigrationSlice, error) {
rolledBack := make(migrate.MigrationSlice, 0, len(migrations))
// Execute each migration in order (already reversed)
for i := range migrations {
migration := &migrations[i]
// Execute the DOWN function
if migration.Down != nil {
if err := migration.Down(ctx, migrator, migration); err != nil {
return rolledBack, errors.Wrap(err, fmt.Sprintf("rollback %s failed", migration.Name))
}
}
// Mark as unapplied after execution
if err := migrator.MarkUnapplied(ctx, migration); err != nil {
return rolledBack, errors.Wrap(err, "mark unapplied")
}
rolledBack = append(rolledBack, *migration)
}
return rolledBack, nil
}
// ResetDatabase drops and recreates all tables (destructive)
func ResetDatabase(ctx context.Context, cfg *config.Config) error {
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
fmt.Print("Type 'yes' to continue: ")
reader := bufio.NewReader(os.Stdin)
response, err := reader.ReadString('\n')
if err != nil {
return errors.Wrap(err, "read input")
}
response = strings.TrimSpace(response)
if response != "yes" {
fmt.Println("❌ Reset cancelled")
return nil
}
conn := db.NewDB(cfg.DB)
defer func() { _ = conn.Close() }()
models := conn.RegisterModels()
for _, model := range models {
if err := conn.ResetModel(ctx, model); err != nil {
return errors.Wrap(err, "reset model")
}
}
fmt.Println("✅ Database reset complete")
return nil
}

View File

@@ -0,0 +1,47 @@
package migrations
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP: Create initial tables (users, discord_tokens)
func(ctx context.Context, conn *bun.DB) error {
// Create users table
_, err := conn.NewCreateTable().
Model((*db.User)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create discord_tokens table
_, err = conn.NewCreateTable().
Model((*db.DiscordToken)(nil)).
Exec(ctx)
return err
},
// DOWN: Drop tables in reverse order
func(ctx context.Context, conn *bun.DB) error {
// Drop discord_tokens first (has foreign key to users)
_, err := conn.NewDropTable().
Model((*db.DiscordToken)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
// Drop users table
_, err = conn.NewDropTable().
Model((*db.User)(nil)).
IfExists().
Exec(ctx)
return err
},
)
}

View File

@@ -0,0 +1,34 @@
package migrations
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
_, err := conn.NewCreateTable().
Model((*db.Season)(nil)).
Exec(ctx)
if err != nil {
return err
}
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
_, err := conn.NewDropTable().
Model((*db.Season)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
return nil
},
)
}

View File

@@ -0,0 +1,253 @@
package migrations
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
conn.RegisterModel((*db.RolePermission)(nil), (*db.UserRole)(nil))
// Create permissions table
_, err := conn.NewCreateTable().
Model((*db.Role)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create permissions table
_, err = conn.NewCreateTable().
Model((*db.Permission)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for permissions
_, err = conn.NewCreateIndex().
Model((*db.Permission)(nil)).
Index("idx_permissions_resource").
Column("resource").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateIndex().
Model((*db.Permission)(nil)).
Index("idx_permissions_action").
Column("action").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateTable().
Model((*db.RolePermission)(nil)).
Exec(ctx)
if err != nil {
return err
}
_, err = conn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_role ON role_permissions(role_id)
`)
if err != nil {
return err
}
_, err = conn.ExecContext(ctx, `
CREATE INDEX idx_role_permissions_permission ON role_permissions(permission_id)
`)
if err != nil {
return err
}
// Create user_roles table
_, err = conn.NewCreateTable().
Model((*db.UserRole)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for user_roles
_, err = conn.NewCreateIndex().
Model((*db.UserRole)(nil)).
Index("idx_user_roles_user").
Column("user_id").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateIndex().
Model((*db.UserRole)(nil)).
Index("idx_user_roles_role").
Column("role_id").
Exec(ctx)
if err != nil {
return err
}
// Create audit_log table
_, err = conn.NewCreateTable().
Model((*db.AuditLog)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create indexes for audit_log
_, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_user").
Column("user_id").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_action").
Column("action").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_resource").
Column("resource_type", "resource_id").
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateIndex().
Model((*db.AuditLog)(nil)).
Index("idx_audit_log_created").
Column("created_at").
Exec(ctx)
if err != nil {
return err
}
err = seedSystemRBAC(ctx, conn)
if err != nil {
return err
}
return nil
},
// DOWN migration
func(ctx context.Context, dbConn *bun.DB) error {
// Drop tables in reverse order
// Use raw SQL to avoid relationship resolution issues
tables := []string{
"audit_log",
"user_roles",
"role_permissions",
"permissions",
"roles",
}
for _, table := range tables {
_, err := dbConn.ExecContext(ctx, "DROP TABLE IF EXISTS "+table+" CASCADE")
if err != nil {
return err
}
}
return nil
},
)
}
func seedSystemRBAC(ctx context.Context, conn *bun.DB) error {
// Seed system roles
now := time.Now().Unix()
adminRole := &db.Role{
Name: "admin",
DisplayName: "Administrator",
Description: "Full system access with all permissions",
IsSystem: true,
CreatedAt: now,
}
_, err := conn.NewInsert().
Model(adminRole).
Returning("id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "dbConn.NewInsert")
}
userRole := &db.Role{
Name: "user",
DisplayName: "User",
Description: "Standard user with basic permissions",
IsSystem: true,
CreatedAt: now,
}
_, err = conn.NewInsert().
Model(userRole).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "dbConn.NewInsert")
}
// Seed system permissions
permissionsData := []*db.Permission{
{Name: "*", DisplayName: "Wildcard (All Permissions)", Description: "Grants access to all permissions, past, present, and future", Resource: "*", Action: "*", IsSystem: true, CreatedAt: now},
{Name: "seasons.create", DisplayName: "Create Seasons", Description: "Create new seasons", Resource: "seasons", Action: "create", IsSystem: true, CreatedAt: now},
{Name: "seasons.update", DisplayName: "Update Seasons", Description: "Update existing seasons", Resource: "seasons", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "seasons.delete", DisplayName: "Delete Seasons", Description: "Delete seasons", Resource: "seasons", Action: "delete", IsSystem: true, CreatedAt: now},
{Name: "users.update", DisplayName: "Update Users", Description: "Update user information", Resource: "users", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "users.ban", DisplayName: "Ban Users", Description: "Ban users from the system", Resource: "users", Action: "ban", IsSystem: true, CreatedAt: now},
{Name: "users.manage_roles", DisplayName: "Manage User Roles", Description: "Assign and revoke user roles", Resource: "users", Action: "manage_roles", IsSystem: true, CreatedAt: now},
}
_, err = conn.NewInsert().
Model(&permissionsData).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "dbConn.NewInsert")
}
// Grant wildcard permission to admin role using Bun
// First, get the IDs
var wildcardPerm db.Permission
err = conn.NewSelect().
Model(&wildcardPerm).
Where("name = ?", "*").
Scan(ctx)
if err != nil {
return err
}
// Insert role_permission mapping
adminRolePerms := &db.RolePermission{
RoleID: adminRole.ID,
PermissionID: wildcardPerm.ID,
}
_, err = conn.NewInsert().
Model(adminRolePerms).
On("CONFLICT (role_id, permission_id) DO NOTHING").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "dbConn.NewInsert")
}
return nil
}

View File

@@ -0,0 +1,67 @@
package migrations
import (
"context"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/db"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add slap_version column to seasons table
_, err := conn.NewAddColumn().
Model((*db.Season)(nil)).
ColumnExpr("slap_version VARCHAR NOT NULL DEFAULT 'rebound'").
IfNotExists().
Exec(ctx)
if err != nil {
return err
}
// Create leagues table
_, err = conn.NewCreateTable().
Model((*db.League)(nil)).
Exec(ctx)
if err != nil {
return err
}
// Create season_leagues join table
_, err = conn.NewCreateTable().
Model((*db.SeasonLeague)(nil)).
Exec(ctx)
return err
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Drop season_leagues join table first
_, err := conn.NewDropTable().
Model((*db.SeasonLeague)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
// Drop leagues table
_, err = conn.NewDropTable().
Model((*db.League)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
// Remove slap_version column from seasons table
_, err = conn.NewDropColumn().
Model((*db.Season)(nil)).
ColumnExpr("slap_version").
Exec(ctx)
return err
},
)
}

View File

@@ -0,0 +1,49 @@
package migrations
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here
_, err := conn.NewCreateTable().
Model((*db.Team)(nil)).
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewCreateTable().
Model((*db.TeamParticipation)(nil)).
Exec(ctx)
if err != nil {
return err
}
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here
_, err := conn.NewDropTable().
Model((*db.TeamParticipation)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
_, err = conn.NewDropTable().
Model((*db.Team)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
return nil
},
)
}

View File

@@ -0,0 +1,44 @@
package migrations
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here
now := time.Now().Unix()
permissionsData := []*db.Permission{
{Name: "seasons.add_league", DisplayName: "Add Leagues to Season", Description: "Assign an existing league to Seasons", Resource: "seasons", Action: "add_league", IsSystem: true, CreatedAt: now},
{Name: "seasons.remove_league", DisplayName: "Remove Leagues from a Season", Description: "Remove an assigned league league from Seasons", Resource: "seasons", Action: "remove_league", IsSystem: true, CreatedAt: now},
{Name: "leagues.create", DisplayName: "Create Leagues", Description: "Create new leagues", Resource: "leagues", Action: "create", IsSystem: true, CreatedAt: now},
{Name: "leagues.update", DisplayName: "Update Leagues", Description: "Update existing leagues", Resource: "leagues", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "leagues.delete", DisplayName: "Delete Leagues", Description: "Delete leagues", Resource: "leagues", Action: "delete", IsSystem: true, CreatedAt: now},
{Name: "teams.create", DisplayName: "Create Teams", Description: "Create new teams", Resource: "teams", Action: "create", IsSystem: true, CreatedAt: now},
{Name: "teams.update", DisplayName: "Update Teams", Description: "Update existing teams", Resource: "teams", Action: "update", IsSystem: true, CreatedAt: now},
{Name: "teams.delete", DisplayName: "Delete Teams", Description: "Delete teams", Resource: "teams", Action: "delete", IsSystem: true, CreatedAt: now},
{Name: "teams.add_to_league", DisplayName: "Add Teams to League", Description: "Add an existing team to a league/season", Resource: "teams", Action: "add_to_league", IsSystem: true, CreatedAt: now},
}
_, err := conn.NewInsert().
Model(&permissionsData).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "conn.NewInsert")
}
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here
return nil
},
)
}

View File

@@ -0,0 +1,52 @@
package migrations
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here
_, err := conn.NewCreateTable().
Model((*db.Fixture)(nil)).
IfNotExists().
Exec(ctx)
if err != nil {
return err
}
now := time.Now().Unix()
permissionsData := []*db.Permission{
{Name: "fixtures.create", DisplayName: "Create Fixtures", Description: "Create new fixtures", Resource: "fixtures", Action: "create", IsSystem: true, CreatedAt: now},
{Name: "fixtures.manage", DisplayName: "Manage Fixtures", Description: "Manage fixtures", Resource: "fixtures", Action: "manage", IsSystem: true, CreatedAt: now},
{Name: "fixtures.delete", DisplayName: "Delete Fixtures", Description: "Delete fixtures", Resource: "fixtures", Action: "delete", IsSystem: true, CreatedAt: now},
}
_, err = conn.NewInsert().
Model(&permissionsData).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "conn.NewInsert")
}
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here
_, err := conn.NewDropTable().
Model((*db.Fixture)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
return nil
},
)
}

View File

@@ -0,0 +1,9 @@
// Package migrations defines the database migrations to apply when using the migrate tags
package migrations
import (
"github.com/uptrace/bun/migrate"
)
// Migrations is the collection of all database migrations
var Migrations = migrate.NewMigrations()

200
internal/db/paginate.go Normal file
View File

@@ -0,0 +1,200 @@
package db
import (
"net/http"
"strings"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/uptrace/bun"
)
type PageOpts struct {
Page int
PerPage int
Order bun.Order
OrderBy string
}
type OrderOpts struct {
Order bun.Order
OrderBy string
Label string
}
func GetPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request) (*PageOpts, bool) {
var getter validation.Getter
switch r.Method {
case "GET":
getter = validation.NewQueryGetter(r)
case "POST":
var ok bool
getter, ok = validation.ParseFormOrError(s, w, r)
if !ok {
return nil, false
}
default:
return nil, false
}
return getPageOpts(s, w, r, getter), true
}
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *PageOpts {
page := g.Int("page").Optional().Min(1).Value
perPage := g.Int("per_page").Optional().Min(1).Max(100).Value
order := g.String("order").TrimSpace().ToUpper().Optional().AllowedValues([]string{"ASC", "DESC"}).Value
orderBy := g.String("order_by").TrimSpace().Optional().ToLower().Value
valid := g.ValidateAndError(s, w, r)
if !valid {
return nil
}
pageOpts := &PageOpts{
Page: page,
PerPage: perPage,
Order: bun.Order(order),
OrderBy: orderBy,
}
return pageOpts
}
func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
if p == nil {
p = new(PageOpts)
}
if p.Page <= 0 {
p.Page = d.Page
}
if p.PerPage == 0 {
p.PerPage = d.PerPage
}
maxpage := p.TotalPages(totalitems)
if p.Page > maxpage && maxpage > 0 {
p.Page = maxpage
}
if p.Order == "" {
p.Order = d.Order
}
if p.OrderBy == "" {
p.OrderBy = d.OrderBy
}
p.OrderBy = sanitiseOrderBy(p.OrderBy)
q = q.OrderBy(p.OrderBy, p.Order).
Limit(p.PerPage).
Offset(p.PerPage * (p.Page - 1))
return q, p
}
func sanitiseOrderBy(orderby string) string {
result := strings.ToLower(orderby)
var builder strings.Builder
for _, r := range result {
if isValidChar(r) {
builder.WriteRune(r)
}
}
sanitized := builder.String()
if sanitized == "" {
return "_"
}
if !isValidFirstChar(rune(sanitized[0])) {
sanitized = "_" + sanitized
}
if len(sanitized) > 63 {
sanitized = sanitized[:63]
}
return sanitized
}
func isValidChar(r rune) bool {
return (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_'
}
func isValidFirstChar(r rune) bool {
return (r >= 'a' && r <= 'z') || r == '_'
}
// TotalPages calculates the total number of pages
func (p *PageOpts) TotalPages(total int) int {
if p.PerPage == 0 {
return 0
}
pages := total / p.PerPage
if total%p.PerPage > 0 {
pages++
}
return pages
}
// HasPrevPage checks if there is a previous page
func (p *PageOpts) HasPrevPage() bool {
return p.Page > 1
}
// HasNextPage checks if there is a next page
func (p *PageOpts) HasNextPage(total int) bool {
return p.Page < p.TotalPages(total)
}
// GetPageRange returns an array of page numbers to display
// maxButtons controls how many page buttons to show
func (p *PageOpts) GetPageRange(total int, maxButtons int) []int {
totalPages := p.TotalPages(total)
if totalPages == 0 {
return []int{}
}
// If total pages is less than max buttons, show all pages
if totalPages <= maxButtons {
pages := make([]int, totalPages)
for i := range totalPages {
pages[i] = i + 1
}
return pages
}
// Calculate range around current page
halfButtons := maxButtons / 2
start := p.Page - halfButtons
end := p.Page + halfButtons
// Adjust if at beginning
if start < 1 {
start = 1
end = maxButtons
}
// Adjust if at end
if end > totalPages {
end = totalPages
start = totalPages - maxButtons + 1
}
pages := make([]int, 0, maxButtons)
for i := start; i <= end; i++ {
pages = append(pages, i)
}
return pages
}
// StartItem returns the number of the first item on the current page
func (p *PageOpts) StartItem() int {
if p.Page < 1 {
return 0
}
return (p.Page-1)*p.PerPage + 1
}
// EndItem returns the number of the last item on the current page
func (p *PageOpts) EndItem(total int) int {
end := p.Page * p.PerPage
if end > total {
return total
}
return end
}

96
internal/db/permission.go Normal file
View File

@@ -0,0 +1,96 @@
package db
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Permission struct {
bun.BaseModel `bun:"table:permissions,alias:p"`
ID int `bun:"id,pk,autoincrement"`
Name permissions.Permission `bun:"name,unique,notnull"`
DisplayName string `bun:"display_name,notnull"`
Description string `bun:"description"`
Resource string `bun:"resource,notnull"`
Action string `bun:"action,notnull"`
IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"`
Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"`
}
func (p Permission) isSystem() bool {
return p.IsSystem
}
// GetPermissionByName queries the database for a permission matching the given name
// Returns a BadRequestNotFound error if no permission is found
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
return GetByField[Permission](tx, "name", name).Get(ctx)
}
// GetPermissionByID queries the database for a permission matching the given ID
// Returns a BadRequestNotFound error if no permission is found
func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) {
if id <= 0 {
return nil, errors.New("id must be positive")
}
return GetByID[Permission](tx, id).Get(ctx)
}
// GetPermissionsByResource queries for all permissions for a given resource
func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ([]*Permission, error) {
if resource == "" {
return nil, errors.New("resource cannot be empty")
}
return GetList[Permission](tx).
Where("resource = ?", resource).GetAll(ctx)
}
// ListAllPermissions returns all permissions
func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
return GetList[Permission](tx).GetAll(ctx)
}
// CreatePermission creates a new permission
func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
if perm == nil {
return errors.New("permission cannot be nil")
}
if perm.Name == "" {
return errors.New("name cannot be empty")
}
if perm.DisplayName == "" {
return errors.New("display name cannot be empty")
}
if perm.Resource == "" {
return errors.New("resource cannot be empty")
}
if perm.Action == "" {
return errors.New("action cannot be empty")
}
err := Insert(tx, perm).
Returning("id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
// DeletePermission deletes a permission (checks IsSystem protection)
func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 {
return errors.New("id must be positive")
}
return DeleteWithProtection[Permission](ctx, tx, id, nil)
}

137
internal/db/role.go Normal file
View File

@@ -0,0 +1,137 @@
package db
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Role struct {
bun.BaseModel `bun:"table:roles,alias:r"`
ID int `bun:"id,pk,autoincrement"`
Name roles.Role `bun:"name,unique,notnull"`
DisplayName string `bun:"display_name,notnull"`
Description string `bun:"description"`
IsSystem bool `bun:"is_system,default:false"`
CreatedAt int64 `bun:"created_at,notnull"`
UpdatedAt *int64 `bun:"updated_at"`
// Relations (loaded on demand)
Users []User `bun:"m2m:user_roles,join:Role=User"`
Permissions []Permission `bun:"m2m:role_permissions,join:Role=Permission"`
}
func (r Role) isSystem() bool {
return r.IsSystem
}
// GetRoleByName queries the database for a role matching the given name
// Returns a BadRequestNotFound error if no role is found
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
return GetByField[Role](tx, "name", name).Relation("Permissions").Get(ctx)
}
// GetRoleByID queries the database for a role matching the given ID
// Returns a BadRequestNotFound error if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
}
// ListAllRoles returns all roles
func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
return GetList[Role](tx).GetAll(ctx)
}
// GetRoles returns a paginated list of roles
func GetRoles(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Role], error) {
defaults := &PageOpts{
Page: 1,
PerPage: 25,
Order: bun.OrderAsc,
OrderBy: "display_name",
}
return GetList[Role](tx).GetPaged(ctx, pageOpts, defaults)
}
// CreateRole creates a new role
func CreateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
if role == nil {
return errors.New("role cannot be nil")
}
role.CreatedAt = time.Now().Unix()
err := Insert(tx, role).
Returning("id").
WithAudit(audit, nil).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
// UpdateRole updates an existing role
func UpdateRole(ctx context.Context, tx bun.Tx, role *Role, audit *AuditMeta) error {
if role == nil {
return errors.New("role cannot be nil")
}
if role.ID <= 0 {
return errors.New("role id must be positive")
}
err := Update(tx, role).
WherePK().
WithAudit(audit, nil).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Update")
}
return nil
}
// DeleteRole deletes a role (checks IsSystem protection)
// Also cleans up join table entries in role_permissions and user_roles
func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error {
if id <= 0 {
return errors.New("id must be positive")
}
// First check if role exists and is not system
role, err := GetRoleByID(ctx, tx, id)
if err != nil {
return errors.Wrap(err, "GetRoleByID")
}
if role.IsSystem {
return errors.New("cannot delete system roles")
}
// Delete role_permissions entries
_, err = tx.NewDelete().
Model((*RolePermission)(nil)).
Where("role_id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "delete role_permissions")
}
// Delete user_roles entries
_, err = tx.NewDelete().
Model((*UserRole)(nil)).
Where("role_id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "delete user_roles")
}
// Finally delete the role
return DeleteWithProtection[Role](ctx, tx, id, audit)
}

View File

@@ -0,0 +1,99 @@
package db
import (
"context"
"slices"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type RolePermission struct {
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
PermissionID int `bun:",pk"`
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
}
func (r *Role) UpdatePermissions(ctx context.Context, tx bun.Tx, newPermissionsIDs []int, audit *AuditMeta) error {
addPerms, removePerms, err := detectChangedPermissions(ctx, tx, r, newPermissionsIDs)
if err != nil {
return errors.Wrap(err, "detectChangedPermissions")
}
addedPerms := []string{}
removedPerms := []string{}
for _, perm := range addPerms {
rolePerm := &RolePermission{
RoleID: r.ID,
PermissionID: perm.ID,
}
err := Insert(tx, rolePerm).
ConflictNothing("role_id", "permission_id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
addedPerms = append(addedPerms, perm.Name.String())
}
for _, perm := range removePerms {
err := DeleteItem[RolePermission](tx).
Where("role_id = ?", r.ID).
Where("permission_id = ?", perm.ID).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteItem")
}
removedPerms = append(removedPerms, perm.Name.String())
}
// Log the permission changes
if len(addedPerms) > 0 || len(removedPerms) > 0 {
details := map[string]any{
"role_name": string(r.Name),
}
if len(addedPerms) > 0 {
details["added_permissions"] = addedPerms
}
if len(removedPerms) > 0 {
details["removed_permissions"] = removedPerms
}
info := &AuditInfo{
"roles.update_permissions",
"role",
r.ID,
details,
}
err = LogSuccess(ctx, tx, audit, info)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
}
return nil
}
func detectChangedPermissions(ctx context.Context, tx bun.Tx, role *Role, permissionIDs []int) ([]*Permission, []*Permission, error) {
allPermissions, err := ListAllPermissions(ctx, tx)
if err != nil {
return nil, nil, errors.Wrap(err, "ListAllPermissions")
}
// Build map of current permissions
currentPermIDs := make(map[int]bool)
for _, perm := range role.Permissions {
currentPermIDs[perm.ID] = true
}
var addedPerms []*Permission
var removedPerms []*Permission
// Determine what to add and remove
for _, perm := range allPermissions {
hasNow := currentPermIDs[perm.ID]
shouldHave := slices.Contains(permissionIDs, perm.ID)
if shouldHave && !hasNow {
addedPerms = append(addedPerms, perm)
} else if !shouldHave && hasNow {
removedPerms = append(removedPerms, perm)
}
}
return addedPerms, removedPerms, nil
}

183
internal/db/season.go Normal file
View File

@@ -0,0 +1,183 @@
package db
import (
"context"
"strings"
"time"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// SeasonStatus represents the current status of a season
type SeasonStatus string
const (
// StatusUpcoming means the season has not started yet
StatusUpcoming SeasonStatus = "upcoming"
// StatusInProgress means the regular season is active
StatusInProgress SeasonStatus = "in_progress"
// StatusFinalsSoon means regular season ended, finals upcoming
StatusFinalsSoon SeasonStatus = "finals_soon"
// StatusFinals means finals are in progress
StatusFinals SeasonStatus = "finals"
// StatusCompleted means the season has finished
StatusCompleted SeasonStatus = "completed"
)
type Season struct {
bun.BaseModel `bun:"table:seasons,alias:s"`
ID int `bun:"id,pk,autoincrement" json:"id"`
Name string `bun:"name,unique,notnull" json:"name"`
ShortName string `bun:"short_name,unique,notnull" json:"short_name"`
StartDate time.Time `bun:"start_date,notnull" json:"start_date"`
EndDate bun.NullTime `bun:"end_date" json:"end_date"`
FinalsStartDate bun.NullTime `bun:"finals_start_date" json:"finals_start_date"`
FinalsEndDate bun.NullTime `bun:"finals_end_date" json:"finals_end_date"`
SlapVersion string `bun:"slap_version,notnull,default:'rebound'" json:"slap_version"`
Leagues []League `bun:"m2m:season_leagues,join:Season=League" json:"-"`
Teams []Team `bun:"m2m:team_participations,join:Season=Team" json:"-"`
}
// NewSeason creats a new season
func NewSeason(ctx context.Context, tx bun.Tx, name, version, shortname string,
start time.Time, audit *AuditMeta,
) (*Season, error) {
season := &Season{
Name: name,
ShortName: strings.ToUpper(shortname),
StartDate: start.Truncate(time.Hour * 24),
SlapVersion: version,
}
err := Insert(tx, season).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.WithMessage(err, "db.Insert")
}
return season, nil
}
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
defaults := &PageOpts{
1,
10,
bun.OrderDesc,
"start_date",
}
return GetList[Season](tx).Relation("Leagues").GetPaged(ctx, pageOpts, defaults)
}
func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) {
if shortname == "" {
return nil, errors.New("short_name not provided")
}
return GetByField[Season](tx, "short_name", shortname).Relation("Leagues").Relation("Teams").Get(ctx)
}
// Update updates the season struct. It does not insert to the database
func (s *Season) Update(ctx context.Context, tx bun.Tx, version string,
start, end, finalsStart, finalsEnd time.Time, audit *AuditMeta,
) error {
s.SlapVersion = version
s.StartDate = start.Truncate(time.Hour * 24)
if !end.IsZero() {
s.EndDate.Time = end.Truncate(time.Hour * 24)
}
if !finalsStart.IsZero() {
s.FinalsStartDate.Time = finalsStart.Truncate(time.Hour * 24)
}
if !finalsEnd.IsZero() {
s.FinalsEndDate.Time = finalsEnd.Truncate(time.Hour * 24)
}
return Update(tx, s).WherePK().
Column("slap_version", "start_date", "end_date", "finals_start_date", "finals_end_date").
WithAudit(audit, nil).Exec(ctx)
}
func (s *Season) MapTeamsToLeagues(ctx context.Context, tx bun.Tx) ([]LeagueWithTeams, error) {
// For each league, get the teams
leaguesWithTeams := make([]LeagueWithTeams, len(s.Leagues))
for i, league := range s.Leagues {
var teams []*Team
err := tx.NewSelect().
Model(&teams).
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
Where("tp.season_id = ? AND tp.league_id = ?", s.ID, league.ID).
Order("t.name ASC").
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
leaguesWithTeams[i] = LeagueWithTeams{
League: &league,
Teams: teams,
}
}
return leaguesWithTeams, nil
}
type LeagueWithTeams struct {
League *League
Teams []*Team
}
// GetStatus returns the current status of the season based on dates
func (s *Season) GetStatus() SeasonStatus {
now := time.Now()
if now.Before(s.StartDate) {
return StatusUpcoming
}
if !s.FinalsStartDate.IsZero() {
if !s.FinalsEndDate.IsZero() && now.After(s.FinalsEndDate.Time) {
return StatusCompleted
}
if now.After(s.FinalsStartDate.Time) {
return StatusFinals
}
if !s.EndDate.IsZero() && now.After(s.EndDate.Time) {
return StatusFinalsSoon
}
return StatusInProgress
}
if !s.EndDate.IsZero() && now.After(s.EndDate.Time) {
return StatusCompleted
}
return StatusInProgress
}
// GetDefaultTab returns the default tab to show based on the season status
func (s *Season) GetDefaultTab() string {
switch s.GetStatus() {
case StatusInProgress:
return "table"
case StatusUpcoming:
return "teams"
default:
return "finals"
}
}
func (s *Season) HasLeague(league *League) bool {
for _, league_ := range s.Leagues {
if league_.ID == league.ID {
return true
}
}
return false
}
func (s *Season) GetLeague(leagueShortName string) (*League, error) {
for _, league := range s.Leagues {
if league.ShortName == leagueShortName {
return &league, nil
}
}
return nil, BadRequestNotAssociated("season", "league",
"id", "short_name", s.ID, leagueShortName)
}

111
internal/db/seasonleague.go Normal file
View File

@@ -0,0 +1,111 @@
package db
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type SeasonLeague struct {
SeasonID int `bun:",pk"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
LeagueID int `bun:",pk"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
}
// GetSeasonLeague retrieves a specific season-league combination with teams
func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string) (*Season, *League, []*Team, error) {
if seasonShortName == "" {
return nil, nil, nil, errors.New("season short_name cannot be empty")
}
if leagueShortName == "" {
return nil, nil, nil, errors.New("league short_name cannot be empty")
}
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason")
}
league, err := season.GetLeague(leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "season.GetLeague")
}
// Get all teams participating in this season+league
var teams []*Team
err = tx.NewSelect().
Model(&teams).
Join("INNER JOIN team_participations AS tp ON tp.team_id = t.id").
Where("tp.season_id = ? AND tp.league_id = ?", season.ID, league.ID).
Order("t.name ASC").
Scan(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "tx.Select teams")
}
return season, league, teams, nil
}
func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShortName string, audit *AuditMeta) error {
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return errors.Wrap(err, "GetSeason")
}
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return errors.Wrap(err, "GetLeague")
}
if season.HasLeague(league) {
return BadRequestAssociated("season", "league",
"id", "id", season.ID, league.ID)
}
seasonLeague := &SeasonLeague{
SeasonID: season.ID,
LeagueID: league.ID,
}
info := &AuditInfo{
string(permissions.SeasonsAddLeague),
"season",
season.ID,
map[string]any{"league_id": league.ID},
}
err = Insert(tx, seasonLeague).WithAudit(audit, info).Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName string, audit *AuditMeta) error {
league, err := s.GetLeague(leagueShortName)
if err != nil {
return errors.Wrap(err, "s.GetLeague")
}
info := &AuditInfo{
string(permissions.SeasonsRemoveLeague),
"season",
s.ID,
map[string]any{"league_id": league.ID},
}
err = DeleteItem[SeasonLeague](tx).
Where("season_id = ?", s.ID).
Where("league_id = ?", league.ID).
WithAudit(audit, info).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "db.DeleteItem")
}
return nil
}
func (t *Team) InTeams(teams []*Team) bool {
for _, team := range teams {
if t.ID == team.ID {
return true
}
}
return false
}

56
internal/db/setup.go Normal file
View File

@@ -0,0 +1,56 @@
package db
import (
"database/sql"
"fmt"
"time"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
type DB struct {
*bun.DB
}
func (db *DB) Close() error {
return db.DB.Close()
}
func (db *DB) RegisterModels() []any {
models := []any{
(*RolePermission)(nil),
(*UserRole)(nil),
(*SeasonLeague)(nil),
(*TeamParticipation)(nil),
(*User)(nil),
(*DiscordToken)(nil),
(*Season)(nil),
(*League)(nil),
(*Team)(nil),
(*Role)(nil),
(*Permission)(nil),
(*AuditLog)(nil),
(*Fixture)(nil),
}
db.RegisterModel(models...)
return models
}
func NewDB(cfg *Config) *DB {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
sqldb.SetMaxOpenConns(25)
sqldb.SetMaxIdleConns(10)
sqldb.SetConnMaxLifetime(5 * time.Minute)
sqldb.SetConnMaxIdleTime(5 * time.Minute)
db := &DB{
bun.NewDB(sqldb, pgdialect.New()),
}
db.RegisterModels()
return db
}

73
internal/db/team.go Normal file
View File

@@ -0,0 +1,73 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Team struct {
bun.BaseModel `bun:"table:teams,alias:t"`
ID int `bun:"id,pk,autoincrement" json:"id"`
Name string `bun:"name,unique,notnull" json:"name"`
ShortName string `bun:"short_name,notnull,unique:short_names" json:"short_name"`
AltShortName string `bun:"alt_short_name,notnull,unique:short_names" json:"alt_short_name"`
Color string `bun:"color" json:"color,omitempty"`
Seasons []Season `bun:"m2m:team_participations,join:Team=Season" json:"-"`
Leagues []League `bun:"m2m:team_participations,join:Team=League" json:"-"`
}
func NewTeam(ctx context.Context, tx bun.Tx, name, shortName, altShortName, color string, audit *AuditMeta) (*Team, error) {
team := &Team{
Name: name,
ShortName: shortName,
AltShortName: altShortName,
Color: color,
}
err := Insert(tx, team).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "db.Insert")
}
return team, nil
}
func ListTeams(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Team], error) {
defaults := &PageOpts{
1,
10,
bun.OrderAsc,
"name",
}
return GetList[Team](tx).GetPaged(ctx, pageOpts, defaults)
}
func GetTeam(ctx context.Context, tx bun.Tx, id int) (*Team, error) {
if id == 0 {
return nil, errors.New("id not provided")
}
return GetByID[Team](tx, id).Relation("Seasons").Relation("Leagues").Get(ctx)
}
func TeamShortNamesUnique(ctx context.Context, tx bun.Tx, shortName, altShortName string) (bool, error) {
// Check if this combination of short_name and alt_short_name exists
count, err := tx.NewSelect().
Model((*Team)(nil)).
Where("short_name = ? AND alt_short_name = ?", shortName, altShortName).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.Select")
}
return count == 0, nil
}
func (t *Team) InSeason(seasonID int) bool {
for _, season := range t.Seasons {
if season.ID == seasonID {
return true
}
}
return false
}

View File

@@ -0,0 +1,56 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type TeamParticipation struct {
SeasonID int `bun:",pk,unique:season_team"`
Season *Season `bun:"rel:belongs-to,join:season_id=id"`
LeagueID int `bun:",pk"`
League *League `bun:"rel:belongs-to,join:league_id=id"`
TeamID int `bun:",pk,unique:season_team"`
Team *Team `bun:"rel:belongs-to,join:team_id=id"`
}
func NewTeamParticipation(ctx context.Context, tx bun.Tx,
seasonShortName, leagueShortName string, teamID int, audit *AuditMeta,
) (*Team, *Season, *League, error) {
season, err := GetSeason(ctx, tx, seasonShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason")
}
league, err := season.GetLeague(leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "season.GetLeague")
}
team, err := GetTeam(ctx, tx, teamID)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetTeam")
}
if team.InSeason(season.ID) {
return nil, nil, nil, BadRequestAssociated("season", "team",
"id", "id", season.ID, team.ID)
}
participation := &TeamParticipation{
SeasonID: season.ID,
LeagueID: league.ID,
TeamID: team.ID,
}
info := &AuditInfo{
"teams.join_season",
"team",
teamID,
map[string]any{"season_id": season.ID, "league_id": league.ID},
}
err = Insert(tx, participation).
WithAudit(audit, info).Exec(ctx)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "db.Insert")
}
return team, season, league, nil
}

113
internal/db/txhelpers.go Normal file
View File

@@ -0,0 +1,113 @@
package db
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// TxFunc is a function that runs within a database transaction
type (
TxFunc func(ctx context.Context, tx bun.Tx) (bool, error)
TxFuncSilent func(ctx context.Context, tx bun.Tx) error
)
var timeout = 15 * time.Second
// WithReadTx executes a read-only transaction with automatic rollback
// Returns true if successful, false if error was thrown to client
func (db *DB) WithReadTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := db.withTx(ctx, fn, false)
if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// WithTxFailSilently executes a transaction with automatic rollback
// Returns true if successful, false if error occured.
// Does not throw any errors to the client.
func (db *DB) WithTxFailSilently(
ctx context.Context,
fn TxFuncSilent,
) error {
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
err := fn(ctx, tx)
return err == nil, err
}
_, err := db.withTx(ctx, fnc, true)
return err
}
// WithWriteTx executes a write transaction with automatic rollback on error
// Commits only if fn returns nil. Returns true if successful.
func (db *DB) WithWriteTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := db.withTx(ctx, fn, true)
if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// WithNotifyTx executes a transaction with notification-based error handling
// Uses notifyInternalServiceError instead of throwInternalServiceError
func (db *DB) WithNotifyTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := db.withTx(ctx, fn, true)
if err != nil {
notify.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// withTx executes a transaction with automatic rollback on error
func (db *DB) withTx(
ctx context.Context,
fn TxFunc,
write bool,
) (bool, error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
ok, err := fn(ctx, tx)
if err != nil || !ok {
return false, err
}
if write {
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
} else {
_ = tx.Commit()
}
return true, nil
}

122
internal/db/update.go Normal file
View File

@@ -0,0 +1,122 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type updater[T any] struct {
tx bun.Tx
q *bun.UpdateQuery
model *T
columns []string
audit *AuditMeta
auditInfo *AuditInfo
}
// Update creates an updater for a model
// You must specify which columns to update via .Column() or use .WherePK()
func Update[T any](tx bun.Tx, model *T) *updater[T] {
if model == nil {
panic("model cannot be nil")
}
return &updater[T]{
tx: tx,
q: tx.NewUpdate().Model(model),
model: model,
}
}
// UpdateByID creates an updater with an ID where clause
// You must still specify which columns to update via .Column()
func UpdateByID[T any](tx bun.Tx, id int, model *T) *updater[T] {
if id <= 0 {
panic("id must be positive")
}
return Update(tx, model).Where("id = ?", id)
}
// Column specifies which columns to update
// Example: .Column("start_date", "end_date")
func (u *updater[T]) Column(columns ...string) *updater[T] {
u.columns = append(u.columns, columns...)
u.q = u.q.Column(columns...)
return u
}
// Where adds a WHERE clause
// Example: .Where("id = ?", 123)
func (u *updater[T]) Where(query string, args ...any) *updater[T] {
u.q = u.q.Where(query, args...)
return u
}
// WherePK adds a WHERE clause on the primary key
// The model must have its primary key field populated
func (u *updater[T]) WherePK() *updater[T] {
u.q = u.q.WherePK()
return u
}
// Set adds a raw SET clause for complex updates
// Example: .Set("updated_at = NOW()")
func (u *updater[T]) Set(query string, args ...any) *updater[T] {
u.q = u.q.Set(query, args...)
return u
}
// WithAudit enables audit logging for this update operation
// If the provided *AuditInfo is nil, will use reflection to automatically work out the details
func (u *updater[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *updater[T] {
u.audit = meta
u.auditInfo = info
return u
}
// Exec executes the update and optionally logs to audit
// Returns an error if update fails or if audit callback fails (triggering rollback)
func (u *updater[T]) Exec(ctx context.Context) error {
// Build audit details BEFORE update (captures changed fields)
var details map[string]any
if u.audit != nil && len(u.columns) > 0 {
details = extractChangedFields(u.model, u.columns)
}
// Execute update
result, err := u.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.UpdateQuery.Exec")
}
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", extractPrimaryKey(u.model))
}
// Handle audit logging if enabled
if u.audit != nil {
if u.auditInfo == nil {
tableName := extractTableName[T]()
resourceType := extractResourceType(tableName)
action := buildAction(resourceType, "update")
u.auditInfo = &AuditInfo{
Action: action,
ResourceType: resourceType,
ResourceID: extractPrimaryKey(u.model),
Details: details, // Changed fields only
}
}
err = LogSuccess(ctx, u.tx, u.audit, u.auditInfo)
if err != nil {
return errors.Wrap(err, "LogSuccess")
}
}
return nil
}

View File

@@ -2,46 +2,35 @@ package db
import ( import (
"context" "context"
"fmt"
"time" "time"
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
var CurrentUser hwsauth.ContextLoader[*User]
type User struct { type User struct {
bun.BaseModel `bun:"table:users,alias:u"` bun.BaseModel `bun:"table:users,alias:u"`
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) ID int `bun:"id,pk,autoincrement" json:"id"`
Username string `bun:"username,unique"` // Username (unique) Username string `bun:"username,unique" json:"username"`
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database CreatedAt int64 `bun:"created_at" json:"created_at"`
DiscordID string `bun:"discord_id,unique"` DiscordID string `bun:"discord_id,unique" json:"discord_id"`
Roles []*Role `bun:"m2m:user_roles,join:User=Role" json:"-"`
} }
func (user *User) GetID() int { func (u *User) GetID() int {
return user.ID return u.ID
} }
// Change the user's username var CurrentUser hwsauth.ContextLoader[*User]
func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error {
_, err := tx.NewUpdate().
Model(user).
Set("username = ?", newUsername).
Where("id = ?", user.ID).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.Update")
}
user.Username = newUsername
return nil
}
// CreateUser creates a new user with the given username and password // CreateUser creates a new user with the given username and password
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User, audit *AuditMeta) (*User, error) {
if discorduser == nil { if discorduser == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
@@ -50,81 +39,113 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
DiscordID: discorduser.ID, DiscordID: discorduser.ID,
} }
audit.u = user
_, err := tx.NewInsert(). err := Insert(tx, user).
Model(user). WithAudit(audit, nil).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.Insert") return nil, errors.Wrap(err, "db.Insert")
} }
return user, nil return user, nil
} }
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
fmt.Printf("user id requested: %v", id) return GetByID[User](tx, id).Get(ctx)
user := new(User)
err := tx.NewSelect().
Model(user).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, nil
}
return nil, errors.Wrap(err, "tx.Select")
}
return user, nil
} }
// GetUserByUsername queries the database for a user matching the given username // GetUserByUsername queries the database for a user matching the given username
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) { func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
user := new(User) if username == "" {
err := tx.NewSelect(). return nil, errors.New("username not provided")
Model(user).
Where("username = ?", username).
Limit(1).
Scan(ctx)
if err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, nil
}
return nil, errors.Wrap(err, "tx.Select")
} }
return user, nil return GetByField[User](tx, "username", username).Get(ctx)
} }
// GetUserByDiscordID queries the database for a user matching the given discord id // GetUserByDiscordID queries the database for a user matching the given discord id
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
user := new(User) if discordID == "" {
err := tx.NewSelect(). return nil, errors.New("discord_id not provided")
Model(user).
Where("discord_id = ?", discordID).
Limit(1).
Scan(ctx)
if err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, nil
}
return nil, errors.Wrap(err, "tx.Select")
} }
return user, nil return GetByField[User](tx, "discord_id", discordID).Get(ctx)
} }
// IsUsernameUnique checks if the given username is unique (not already taken) // GetRoles loads all the roles for this user
// Returns true if the username is available, false if it's taken func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) { if u == nil {
count, err := tx.NewSelect(). return nil, errors.New("user cannot be nil")
Model((*User)(nil)).
Where("username = ?", username).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.Count")
} }
return count == 0, nil u, err := GetByField[User](tx, "id", u.ID).
Relation("Roles").Get(ctx)
if err != nil {
return nil, errors.Wrap(err, "GetByField")
}
return u.Roles, nil
}
// GetPermissions loads and returns all permissions for this user
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
if u == nil {
return nil, errors.New("user cannot be nil")
}
return GetList[Permission](tx).
Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
Where("ur.user_id = ?", u.ID).
GetAll(ctx)
}
// HasPermission checks if user has a specific permission (including wildcard check)
func (u *User) HasPermission(ctx context.Context, tx bun.Tx, permissionName permissions.Permission) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
if permissionName == "" {
return false, errors.New("permissionName cannot be empty")
}
perms, err := u.GetPermissions(ctx, tx)
if err != nil {
return false, err
}
for _, p := range perms {
if p.Name == permissionName || p.Name == permissions.Wildcard {
return true, nil
}
}
return false, nil
}
// HasRole checks if user has a specific role
func (u *User) HasRole(ctx context.Context, tx bun.Tx, roleName roles.Role) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
return HasRole(ctx, tx, u.ID, roleName)
}
// IsAdmin is a convenience method to check if user has admin role
func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
if u == nil {
return false, errors.New("user cannot be nil")
}
return u.HasRole(ctx, tx, "admin")
}
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
defaults := &PageOpts{1, 50, bun.OrderAsc, "id"}
return GetList[User](tx).GetPaged(ctx, pageOpts, defaults)
}
// GetUsersWithRoles queries the database for users with their roles preloaded
func GetUsersWithRoles(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
defaults := &PageOpts{1, 25, bun.OrderAsc, "id"}
return GetList[User](tx).Relation("Roles").GetPaged(ctx, pageOpts, defaults)
} }

103
internal/db/userrole.go Normal file
View File

@@ -0,0 +1,103 @@
package db
import (
"context"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type UserRole struct {
UserID int `bun:",pk"`
User *User `bun:"rel:belongs-to,join:user_id=id"`
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
}
// AssignRole grants a role to a user
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
if userID <= 0 {
return errors.New("userID must be positive")
}
if roleID <= 0 {
return errors.New("roleID must be positive")
}
userRole := &UserRole{
UserID: userID,
RoleID: roleID,
}
details := map[string]any{
"action": "grant",
"role_id": roleID,
}
info := &AuditInfo{
string(permissions.UsersManageRoles),
"user",
userID,
details,
}
err := Insert(tx, userRole).
ConflictNothing("user_id", "role_id").
WithAudit(audit, info).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.Insert")
}
return nil
}
// RevokeRole removes a role from a user
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int, audit *AuditMeta) error {
if userID <= 0 {
return errors.New("userID must be positive")
}
if roleID <= 0 {
return errors.New("roleID must be positive")
}
details := map[string]any{
"action": "revoke",
"role_id": roleID,
}
info := &AuditInfo{
string(permissions.UsersManageRoles),
"user",
userID,
details,
}
err := DeleteItem[UserRole](tx).
Where("user_id = ?", userID).
Where("role_id = ?", roleID).
WithAudit(audit, info).
Delete(ctx)
if err != nil {
return errors.Wrap(err, "DeleteItem")
}
return nil
}
// HasRole checks if a user has a specific role
func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (bool, error) {
if userID <= 0 {
return false, errors.New("userID must be positive")
}
if roleName == "" {
return false, errors.New("roleName cannot be empty")
}
user, err := GetByID[User](tx, userID).
Relation("Roles").Get(ctx)
if err != nil {
return false, errors.Wrap(err, "GetByID")
}
for _, role := range user.Roles {
if role.Name == roleName {
return true, nil
}
}
return false, nil
}

View File

@@ -10,26 +10,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type OAuthSession struct {
*discordgo.Session
}
func NewOAuthSession(token *Token) (*OAuthSession, error) {
session, err := discordgo.New("Bearer " + token.AccessToken)
if err != nil {
return nil, errors.Wrap(err, "discordgo.New")
}
return &OAuthSession{Session: session}, nil
}
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
user, err := s.User("@me")
if err != nil {
return nil, errors.Wrap(err, "s.User")
}
return user, nil
}
// APIClient is an HTTP client wrapper that handles Discord API rate limits // APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct { type APIClient struct {
cfg *Config cfg *Config
@@ -38,6 +18,7 @@ type APIClient struct {
mu sync.RWMutex mu sync.RWMutex
buckets map[string]*RateLimitState buckets map[string]*RateLimitState
trustedHost string trustedHost string
bot *BotSession
} }
// NewAPIClient creates a new Discord API client with rate limit handling // NewAPIClient creates a new Discord API client with rate limit handling
@@ -51,11 +32,20 @@ func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APICli
if trustedhost == "" { if trustedhost == "" {
return nil, errors.New("trustedhost cannot be empty") return nil, errors.New("trustedhost cannot be empty")
} }
bot, err := newBotSession(cfg)
if err != nil {
return nil, errors.Wrap(err, "newBotSession")
}
return &APIClient{ return &APIClient{
client: &http.Client{Timeout: 30 * time.Second}, client: &http.Client{Timeout: 30 * time.Second},
logger: logger, logger: logger,
buckets: make(map[string]*RateLimitState), buckets: make(map[string]*RateLimitState),
cfg: cfg, cfg: cfg,
trustedHost: trustedhost, trustedHost: trustedhost,
bot: bot,
}, nil }, nil
} }
func (api *APIClient) Ping() (*discordgo.Application, error) {
return api.bot.Application("@me")
}

22
internal/discord/bot.go Normal file
View File

@@ -0,0 +1,22 @@
package discord
import (
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
)
type BotSession struct {
*discordgo.Session
}
func newBotSession(cfg *Config) (*BotSession, error) {
session, err := discordgo.New("Bot " + cfg.BotToken)
if err != nil {
return nil, errors.Wrap(err, "discordgo.New")
}
return &BotSession{Session: session}, nil
}
func (api *APIClient) Bot() *BotSession {
return api.bot
}

View File

@@ -12,6 +12,7 @@ type Config struct {
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required) ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
OAuthScopes string // Authorisation scopes for OAuth OAuthScopes string // Authorisation scopes for OAuth
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required) RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
BotToken string // ENV DISCORD_BOT_TOKEN: Token for the discord bot (required)
} }
func ConfigFromEnv() (any, error) { func ConfigFromEnv() (any, error) {
@@ -20,6 +21,7 @@ func ConfigFromEnv() (any, error) {
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""), ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
OAuthScopes: getOAuthScopes(), OAuthScopes: getOAuthScopes(),
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""), RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
BotToken: env.String("DISCORD_BOT_TOKEN", ""),
} }
// Check required fields // Check required fields
@@ -32,6 +34,9 @@ func ConfigFromEnv() (any, error) {
if cfg.RedirectPath == "" { if cfg.RedirectPath == "" {
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH") return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
} }
if cfg.BotToken == "" {
return nil, errors.New("Envar not set: DISCORD_BOT_TOKEN")
}
return cfg, nil return cfg, nil
} }

View File

@@ -8,9 +8,30 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type OAuthSession struct {
*discordgo.Session
}
func NewOAuthSession(token *Token) (*OAuthSession, error) {
session, err := discordgo.New("Bearer " + token.AccessToken)
if err != nil {
return nil, errors.Wrap(err, "discordgo.New")
}
return &OAuthSession{Session: session}, nil
}
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
user, err := s.User("@me")
if err != nil {
return nil, errors.Wrap(err, "s.User")
}
return user, nil
}
// Token represents a response from the Discord OAuth API after a successful authorization request // Token represents a response from the Discord OAuth API after a successful authorization request
type Token struct { type Token struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`

View File

@@ -0,0 +1,21 @@
// Package embedfs creates an embedded filesystem with the static web assets
package embedfs
import (
"embed"
"io/fs"
"github.com/pkg/errors"
)
//go:embed web/*
var embeddedFiles embed.FS
// GetEmbeddedFS gets the embedded files
func GetEmbeddedFS() (*fs.FS, error) {
subFS, err := fs.Sub(embeddedFiles, "web")
if err != nil {
return nil, errors.Wrap(err, "fs.Sub")
}
return &subFS, nil
}

View File

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

View File

@@ -0,0 +1,151 @@
/* Flatpickr Catppuccin Mocha Theme */
/* Override flatpickr colors to match our custom theme */
.flatpickr-calendar {
background: #1e1e2e; /* mantle */
border: 1px solid #45475a; /* surface1 */
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.3);
}
.flatpickr-months {
background: #181825; /* base */
border-bottom: 1px solid #45475a; /* surface1 */
}
.flatpickr-month {
color: #cdd6f4; /* text */
}
.flatpickr-current-month .flatpickr-monthDropdown-months {
background: #1e1e2e; /* mantle */
color: #cdd6f4; /* text */
border: 1px solid #45475a; /* surface1 */
}
.flatpickr-current-month .flatpickr-monthDropdown-months:hover {
background: #313244; /* surface0 */
}
.flatpickr-current-month input.cur-year {
color: #cdd6f4; /* text */
background: #1e1e2e; /* mantle */
}
.flatpickr-current-month input.cur-year:hover {
background: #313244; /* surface0 */
}
.flatpickr-prev-month,
.flatpickr-next-month {
color: #cdd6f4; /* text */
}
.flatpickr-prev-month:hover,
.flatpickr-next-month:hover {
color: #89b4fa; /* blue */
}
.flatpickr-weekdays {
background: #181825; /* base */
border-bottom: 1px solid #45475a; /* surface1 */
}
span.flatpickr-weekday {
color: #bac2de; /* subtext0 */
font-weight: 600;
}
.flatpickr-days {
background: #1e1e2e; /* mantle */
}
.flatpickr-day {
color: #cdd6f4; /* text */
border: 1px solid transparent;
}
.flatpickr-day.today {
border-color: #89b4fa; /* blue */
background: #89b4fa20; /* blue with transparency */
color: #89b4fa; /* blue */
}
.flatpickr-day.today:hover {
background: #89b4fa40; /* blue with more transparency */
border-color: #89b4fa; /* blue */
color: #89b4fa; /* blue */
}
.flatpickr-day.selected,
.flatpickr-day.startRange,
.flatpickr-day.endRange {
background: #89b4fa; /* blue */
border-color: #89b4fa; /* blue */
color: #181825; /* base */
}
.flatpickr-day.selected:hover,
.flatpickr-day.startRange:hover,
.flatpickr-day.endRange:hover {
background: #74a7f9; /* slightly lighter blue */
border-color: #74a7f9;
}
.flatpickr-day:hover {
background: #313244; /* surface0 */
border-color: #45475a; /* surface1 */
}
.flatpickr-day.prevMonthDay,
.flatpickr-day.nextMonthDay {
color: #585b70; /* surface2 */
}
.flatpickr-day.flatpickr-disabled,
.flatpickr-day.flatpickr-disabled:hover {
color: #585b70; /* surface2 */
cursor: not-allowed;
}
.flatpickr-day.inRange {
background: #89b4fa30; /* blue with light transparency */
border-color: transparent;
box-shadow: -5px 0 0 #89b4fa30, 5px 0 0 #89b4fa30;
}
.flatpickr-time {
background: #181825; /* base */
border-top: 1px solid #45475a; /* surface1 */
}
.flatpickr-time input {
color: #cdd6f4; /* text */
background: #1e1e2e; /* mantle */
}
.flatpickr-time input:hover,
.flatpickr-time input:focus {
background: #313244; /* surface0 */
}
.flatpickr-time .flatpickr-time-separator,
.flatpickr-time .flatpickr-am-pm {
color: #cdd6f4; /* text */
}
.flatpickr-time .flatpickr-am-pm:hover,
.flatpickr-time .flatpickr-am-pm:focus {
background: #313244; /* surface0 */
}
.flatpickr-time .numInputWrapper span.arrowUp:after {
border-bottom-color: #cdd6f4; /* text */
}
.flatpickr-time .numInputWrapper span.arrowDown:after {
border-top-color: #cdd6f4; /* text */
}
.flatpickr-time .numInputWrapper span:hover {
background: #313244; /* surface0 */
}

View File

@@ -15,11 +15,14 @@
--color-maroon: var(--maroon); --color-maroon: var(--maroon);
--color-peach: var(--peach); --color-peach: var(--peach);
--color-yellow: var(--yellow); --color-yellow: var(--yellow);
--color-dark-yellow: var(--dark-yellow);
--color-green: var(--green); --color-green: var(--green);
--color-dark-green: var(--dark-green);
--color-teal: var(--teal); --color-teal: var(--teal);
--color-sky: var(--sky); --color-sky: var(--sky);
--color-sapphire: var(--sapphire); --color-sapphire: var(--sapphire);
--color-blue: var(--blue); --color-blue: var(--blue);
--color-dark-blue: var(--dark-blue);
--color-lavender: var(--lavender); --color-lavender: var(--lavender);
--color-text: var(--text); --color-text: var(--text);
--color-subtext1: var(--subtext1); --color-subtext1: var(--subtext1);
@@ -45,11 +48,14 @@
--maroon: hsl(355, 76%, 59%); --maroon: hsl(355, 76%, 59%);
--peach: hsl(22, 99%, 52%); --peach: hsl(22, 99%, 52%);
--yellow: hsl(35, 77%, 49%); --yellow: hsl(35, 77%, 49%);
--dark-yellow: hsl(35, 50%, 85%);
--green: hsl(109, 58%, 40%); --green: hsl(109, 58%, 40%);
--dark-green: hsl(109, 35%, 85%);
--teal: hsl(183, 74%, 35%); --teal: hsl(183, 74%, 35%);
--sky: hsl(197, 97%, 46%); --sky: hsl(197, 97%, 46%);
--sapphire: hsl(189, 70%, 42%); --sapphire: hsl(189, 70%, 42%);
--blue: hsl(220, 91%, 54%); --blue: hsl(220, 91%, 54%);
--dark-blue: hsl(220, 50%, 85%);
--lavender: hsl(231, 97%, 72%); --lavender: hsl(231, 97%, 72%);
--text: hsl(234, 16%, 35%); --text: hsl(234, 16%, 35%);
--subtext1: hsl(233, 13%, 41%); --subtext1: hsl(233, 13%, 41%);
@@ -75,11 +81,14 @@
--maroon: hsl(350, 65%, 77%); --maroon: hsl(350, 65%, 77%);
--peach: hsl(23, 92%, 75%); --peach: hsl(23, 92%, 75%);
--yellow: hsl(41, 86%, 83%); --yellow: hsl(41, 86%, 83%);
--dark-yellow: hsl(41, 30%, 25%);
--green: hsl(115, 54%, 76%); --green: hsl(115, 54%, 76%);
--dark-green: hsl(115, 25%, 22%);
--teal: hsl(170, 57%, 73%); --teal: hsl(170, 57%, 73%);
--sky: hsl(189, 71%, 73%); --sky: hsl(189, 71%, 73%);
--sapphire: hsl(199, 76%, 69%); --sapphire: hsl(199, 76%, 69%);
--blue: hsl(217, 92%, 76%); --blue: hsl(217, 92%, 76%);
--dark-blue: hsl(217, 30%, 25%);
--lavender: hsl(232, 97%, 85%); --lavender: hsl(232, 97%, 85%);
--text: hsl(226, 64%, 88%); --text: hsl(226, 64%, 88%);
--subtext1: hsl(227, 35%, 80%); --subtext1: hsl(227, 35%, 80%);
@@ -118,3 +127,74 @@
font-weight: 700; font-weight: 700;
font-style: italic; font-style: italic;
} }
/* Custom Scrollbar Styles - Catppuccin Theme */
/* Firefox */
* {
scrollbar-width: thin;
scrollbar-color: var(--surface1) var(--mantle);
}
/* Webkit browsers (Chrome, Safari, Edge) */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: var(--mantle);
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: var(--surface1);
border-radius: 4px;
border: 2px solid var(--mantle);
}
::-webkit-scrollbar-thumb:hover {
background: var(--surface2);
}
::-webkit-scrollbar-thumb:active {
background: var(--overlay0);
}
/* Specific styling for multi-select dropdowns */
.multi-select-dropdown::-webkit-scrollbar {
width: 6px;
}
.multi-select-dropdown::-webkit-scrollbar-track {
background: var(--base);
border-radius: 3px;
}
.multi-select-dropdown::-webkit-scrollbar-thumb {
background: var(--surface2);
border-radius: 3px;
border: 1px solid var(--base);
}
.multi-select-dropdown::-webkit-scrollbar-thumb:hover {
background: var(--overlay0);
}
/* Specific styling for modal content */
.modal-scrollable::-webkit-scrollbar {
width: 8px;
}
.modal-scrollable::-webkit-scrollbar-track {
background: var(--base);
}
.modal-scrollable::-webkit-scrollbar-thumb {
background: var(--surface1);
border-radius: 4px;
}
.modal-scrollable::-webkit-scrollbar-thumb:hover {
background: var(--surface2);
}

View File

@@ -0,0 +1,263 @@
// Admin dashboard utilities
// Format JSON for display in modals
function formatJSON(json) {
try {
const parsed = typeof json === "string" ? JSON.parse(json) : json;
return JSON.stringify(parsed, null, 2);
} catch (e) {
return json;
}
}
// Initialize flatpickr for all date inputs
function initFlatpickr() {
document.querySelectorAll(".flatpickr-date").forEach(function (input) {
if (!input._flatpickr) {
flatpickr(input, {
dateFormat: "d/m/Y",
allowInput: true,
});
}
});
}
// Submit the audit filter form with specific page/perPage/order params
function submitAuditFilter(page, perPage, order, orderBy) {
const form = document.getElementById("audit-filters-form");
if (!form) return;
// Create hidden inputs for pagination/sorting if they don't exist
let pageInput = form.querySelector('input[name="page"]');
if (!pageInput) {
pageInput = document.createElement("input");
pageInput.type = "hidden";
pageInput.name = "page";
form.appendChild(pageInput);
}
pageInput.value = page;
let perPageInput = form.querySelector('input[name="per_page"]');
if (!perPageInput) {
perPageInput = document.createElement("input");
perPageInput.type = "hidden";
perPageInput.name = "per_page";
form.appendChild(perPageInput);
}
perPageInput.value = perPage;
let orderInput = form.querySelector('input[name="order"]');
if (!orderInput) {
orderInput = document.createElement("input");
orderInput.type = "hidden";
orderInput.name = "order";
form.appendChild(orderInput);
}
orderInput.value = order;
let orderByInput = form.querySelector('input[name="order_by"]');
if (!orderByInput) {
orderByInput = document.createElement("input");
orderByInput.type = "hidden";
orderByInput.name = "order_by";
form.appendChild(orderByInput);
}
orderByInput.value = orderBy;
htmx.trigger(form, "submit");
}
// Sort by column - toggle direction if same column
function sortAuditColumn(field, currentOrder, currentOrderBy) {
const page = 1; // Reset to first page when sorting
const perPageSelect = document.getElementById("per-page-select");
const perPage = perPageSelect ? parseInt(perPageSelect.value) || 25 : 25;
let newOrder, newOrderBy;
if (currentOrderBy === field) {
// Toggle order
newOrder = currentOrder === "ASC" ? "DESC" : "ASC";
newOrderBy = field;
} else {
// New column, default to DESC
newOrder = "DESC";
newOrderBy = field;
}
submitAuditFilter(page, perPage, newOrder, newOrderBy);
}
// Clear all audit filters
function clearAuditFilters() {
const form = document.getElementById("audit-filters-form");
if (!form) return;
form.reset();
// Clear flatpickr instances
document.querySelectorAll(".flatpickr-date").forEach(function (input) {
var fp = input._flatpickr;
if (fp) fp.clear();
});
// Clear multi-select dropdowns
document.querySelectorAll(".multi-select-container").forEach(function (container) {
var hiddenInput = container.querySelector('input[type="hidden"]');
if (hiddenInput) hiddenInput.value = "";
var selectedDisplay = container.querySelector(".multi-select-selected");
if (selectedDisplay)
selectedDisplay.innerHTML = '<span class="text-subtext1">Select...</span>';
container.querySelectorAll(".multi-select-option").forEach(function (opt) {
opt.classList.remove("bg-blue", "text-mantle");
opt.classList.add("hover:bg-surface1");
});
});
// Trigger form submission with reset pagination
submitAuditFilter(1, 25, "DESC", "created_at");
}
// Toggle multi-select dropdown visibility
function toggleMultiSelect(containerId) {
var dropdown = document.getElementById(containerId + "-dropdown");
if (dropdown) {
dropdown.classList.toggle("hidden");
}
}
// Toggle multi-select option selection
function toggleMultiSelectOption(containerId, value, label) {
var container = document.getElementById(containerId);
var hiddenInput = container.querySelector('input[type="hidden"]');
var selectedDisplay = container.querySelector(".multi-select-selected");
var values = hiddenInput.value ? hiddenInput.value.split(",") : [];
var index = values.indexOf(value);
if (index > -1) {
values.splice(index, 1);
} else {
values.push(value);
}
hiddenInput.value = values.join(",");
var option = container.querySelector('[data-value="' + value + '"]');
if (option) {
if (index > -1) {
option.classList.remove("bg-blue", "text-mantle");
option.classList.add("hover:bg-surface1");
} else {
option.classList.add("bg-blue", "text-mantle");
option.classList.remove("hover:bg-surface1");
}
}
if (values.length === 0) {
selectedDisplay.innerHTML = '<span class="text-subtext1">Select...</span>';
} else if (values.length === 1) {
selectedDisplay.innerHTML = "<span>" + label + "</span>";
} else {
selectedDisplay.innerHTML = "<span>" + values.length + " selected</span>";
}
// Trigger form submission
document.getElementById("audit-filters-form").requestSubmit();
}
// Submit the users page with specific page/perPage/order params
function submitUsersPage(page, perPage, order, orderBy) {
const formData = new FormData();
formData.append("page", page);
formData.append("per_page", perPage);
formData.append("order", order);
formData.append("order_by", orderBy);
htmx.ajax("POST", "/admin/users", {
target: "#users-list-container",
swap: "outerHTML",
values: Object.fromEntries(formData),
});
}
// Sort users column - toggle direction if same column
function sortUsersColumn(field, currentOrder, currentOrderBy) {
const page = 1; // Reset to first page when sorting
const perPageSelect = document.getElementById("users-per-page-select");
const perPage = perPageSelect ? parseInt(perPageSelect.value) || 25 : 25;
let newOrder, newOrderBy;
if (currentOrderBy === field) {
// Toggle order
newOrder = currentOrder === "ASC" ? "DESC" : "ASC";
newOrderBy = field;
} else {
// New column, default to ASC
newOrder = "ASC";
newOrderBy = field;
}
submitUsersPage(page, perPage, newOrder, newOrderBy);
}
// Submit the roles page with specific page/perPage/order params
function submitRolesPage(page, perPage, order, orderBy) {
const formData = new FormData();
formData.append("page", page);
formData.append("per_page", perPage);
formData.append("order", order);
formData.append("order_by", orderBy);
htmx.ajax("POST", "/admin/roles", {
target: "#roles-list-container",
swap: "outerHTML",
values: Object.fromEntries(formData),
});
}
// Sort roles column - toggle direction if same column
function sortRolesColumn(field, currentOrder, currentOrderBy) {
const page = 1; // Reset to first page when sorting
const perPageSelect = document.getElementById("roles-per-page-select");
const perPage = perPageSelect ? parseInt(perPageSelect.value) || 25 : 25;
let newOrder, newOrderBy;
if (currentOrderBy === field) {
// Toggle order
newOrder = currentOrder === "ASC" ? "DESC" : "ASC";
newOrderBy = field;
} else {
// New column, default to ASC
newOrder = "ASC";
newOrderBy = field;
}
submitRolesPage(page, perPage, newOrder, newOrderBy);
}
// Handle HTMX navigation and initialization
// Tab navigation active state is handled by tabs.js (generic).
// This file only handles admin-specific concerns (flatpickr, multi-select).
document.addEventListener("DOMContentLoaded", function () {
// Initialize flatpickr on page load
initFlatpickr();
document.body.addEventListener("htmx:afterSwap", function (event) {
// Re-initialize flatpickr after admin content swap
if (
event.detail.target.id === "admin-content" ||
event.detail.target.id === "audit-results-container"
) {
initFlatpickr();
}
});
// Close multi-select dropdowns when clicking outside
document.addEventListener("click", function (evt) {
if (!evt.target.closest(".multi-select-container")) {
document.querySelectorAll(".multi-select-dropdown").forEach(function (d) {
d.classList.add("hidden");
});
}
});
});

View File

@@ -0,0 +1,17 @@
function copyToClipboard(elementId, buttonId) {
const element = document.getElementById(elementId);
const button = document.getElementById(buttonId);
navigator.clipboard.writeText(element.innerText)
.then(() => {
const originalText = button.innerText;
button.innerText = 'Copied!';
setTimeout(() => {
button.innerText = originalText;
}, 2000);
})
.catch(err => {
console.error('Failed to copy:', err);
button.innerText = 'Failed';
});
}

View File

@@ -0,0 +1,58 @@
function paginateData(
formID,
rootPath,
initPage,
initPerPage,
initOrder,
initOrderBy,
) {
return {
page: initPage,
perPage: initPerPage,
order: initOrder || "ASC",
orderBy: initOrderBy || "name",
goToPage(n) {
this.page = n;
this.submit();
},
handleSortChange(value) {
const [field, direction] = value.split("|");
this.orderBy = field;
this.order = direction;
this.page = 1; // Reset to first page when sorting
this.submit();
},
sortByColumn(field) {
if (this.orderBy === field) {
// Toggle order if same column
this.order = this.order === "ASC" ? "DESC" : "ASC";
} else {
// New column, default to DESC
this.orderBy = field;
this.order = "DESC";
}
this.page = 1; // Reset to first page when sorting
this.submit();
},
setPerPage(n) {
this.perPage = n;
this.page = 1; // Reset to first page when changing per page
this.submit();
},
submit() {
var url = `${rootPath}?page=${this.page}&per_page=${this.perPage}&order=${this.order}&order_by=${this.orderBy}`;
htmx.find("#pagination-page").value = this.page;
htmx.find("#pagination-per-page").value = this.perPage;
htmx.find("#sort-order").value = this.order;
htmx.find("#sort-order-by").value = this.orderBy;
htmx.find(`#${formID}`).setAttribute("hx-post", url);
htmx.process(`#${formID}`);
htmx.trigger(`#${formID}`, "submit");
},
};
}

View File

@@ -0,0 +1,50 @@
// Generic tab navigation handler
// Manages active tab styling after HTMX content swaps.
//
// Usage: Add data-tab-nav="<content-target-id>" to your <nav> element.
// Tab links inside the nav should have href attributes ending with the section name.
//
// Example:
// <nav data-tab-nav="admin-content">
// <a href="/admin/users" hx-post="/admin/users" hx-target="#admin-content">Users</a>
// </nav>
// <main id="admin-content">...</main>
(function () {
var activeClasses = ["border-blue", "text-blue", "font-semibold"];
var inactiveClasses = [
"border-transparent",
"text-subtext0",
"hover:text-text",
"hover:border-surface2",
];
function updateActiveTab(targetId) {
var nav = document.querySelector('[data-tab-nav="' + targetId + '"]');
if (!nav) return;
var path = window.location.pathname;
var section = path.split("/").pop() || "";
nav.querySelectorAll("a").forEach(function (link) {
var href = link.getAttribute("href");
var isActive = href && href.endsWith("/" + section);
activeClasses.forEach(function (cls) {
isActive ? link.classList.add(cls) : link.classList.remove(cls);
});
inactiveClasses.forEach(function (cls) {
isActive ? link.classList.remove(cls) : link.classList.add(cls);
});
});
}
document.addEventListener("DOMContentLoaded", function () {
document.body.addEventListener("htmx:afterSwap", function (event) {
var targetId = event.detail.target.id;
if (targetId) {
updateActiveTab(targetId);
}
});
});
})();

View File

@@ -0,0 +1,54 @@
function progressBar(toastId, duration) {
const progressBar = {
progress: 0,
paused: false,
animateToastProgress() {
const toast = document.getElementById(toastId);
if (!toast) return;
const bar = document.getElementById([toastId, "progress"].join("-"));
const startTime = performance.now();
let totalPausedTime = 0;
let pauseStartTime = null;
const animate = (currentTime) => {
const toast = document.getElementById(toastId);
if (!toast) return; // Toast was manually removed
if (this.paused) {
// Track when pause started
if (pauseStartTime === null) {
pauseStartTime = currentTime;
}
// Keep animating while paused
requestAnimationFrame(animate);
return;
} else {
// If we were paused, accumulate the paused time
if (pauseStartTime !== null) {
totalPausedTime += currentTime - pauseStartTime;
pauseStartTime = null;
}
}
// Calculate actual elapsed time (excluding paused time)
const elapsed = currentTime - startTime - totalPausedTime;
this.progress = Math.min((elapsed / duration) * 100, 100) + "%";
bar.style["width"] = this.progress;
if (elapsed >= duration) {
toast.remove()
} else {
requestAnimationFrame(animate);
}
};
requestAnimationFrame(animate);
}
};
progressBar.animateToastProgress();
return progressBar;
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,795 @@
.flatpickr-calendar {
background: transparent;
opacity: 0;
display: none;
text-align: center;
visibility: hidden;
padding: 0;
-webkit-animation: none;
animation: none;
direction: ltr;
border: 0;
font-size: 14px;
line-height: 24px;
border-radius: 5px;
position: absolute;
width: 307.875px;
-webkit-box-sizing: border-box;
box-sizing: border-box;
-ms-touch-action: manipulation;
touch-action: manipulation;
background: #3f4458;
-webkit-box-shadow: 1px 0 0 #20222c, -1px 0 0 #20222c, 0 1px 0 #20222c, 0 -1px 0 #20222c, 0 3px 13px rgba(0,0,0,0.08);
box-shadow: 1px 0 0 #20222c, -1px 0 0 #20222c, 0 1px 0 #20222c, 0 -1px 0 #20222c, 0 3px 13px rgba(0,0,0,0.08);
}
.flatpickr-calendar.open,
.flatpickr-calendar.inline {
opacity: 1;
max-height: 640px;
visibility: visible;
}
.flatpickr-calendar.open {
display: inline-block;
z-index: 99999;
}
.flatpickr-calendar.animate.open {
-webkit-animation: fpFadeInDown 300ms cubic-bezier(0.23, 1, 0.32, 1);
animation: fpFadeInDown 300ms cubic-bezier(0.23, 1, 0.32, 1);
}
.flatpickr-calendar.inline {
display: block;
position: relative;
top: 2px;
}
.flatpickr-calendar.static {
position: absolute;
top: calc(100% + 2px);
}
.flatpickr-calendar.static.open {
z-index: 999;
display: block;
}
.flatpickr-calendar.multiMonth .flatpickr-days .dayContainer:nth-child(n+1) .flatpickr-day.inRange:nth-child(7n+7) {
-webkit-box-shadow: none !important;
box-shadow: none !important;
}
.flatpickr-calendar.multiMonth .flatpickr-days .dayContainer:nth-child(n+2) .flatpickr-day.inRange:nth-child(7n+1) {
-webkit-box-shadow: -2px 0 0 #e6e6e6, 5px 0 0 #e6e6e6;
box-shadow: -2px 0 0 #e6e6e6, 5px 0 0 #e6e6e6;
}
.flatpickr-calendar .hasWeeks .dayContainer,
.flatpickr-calendar .hasTime .dayContainer {
border-bottom: 0;
border-bottom-right-radius: 0;
border-bottom-left-radius: 0;
}
.flatpickr-calendar .hasWeeks .dayContainer {
border-left: 0;
}
.flatpickr-calendar.hasTime .flatpickr-time {
height: 40px;
border-top: 1px solid #20222c;
}
.flatpickr-calendar.noCalendar.hasTime .flatpickr-time {
height: auto;
}
.flatpickr-calendar:before,
.flatpickr-calendar:after {
position: absolute;
display: block;
pointer-events: none;
border: solid transparent;
content: '';
height: 0;
width: 0;
left: 22px;
}
.flatpickr-calendar.rightMost:before,
.flatpickr-calendar.arrowRight:before,
.flatpickr-calendar.rightMost:after,
.flatpickr-calendar.arrowRight:after {
left: auto;
right: 22px;
}
.flatpickr-calendar.arrowCenter:before,
.flatpickr-calendar.arrowCenter:after {
left: 50%;
right: 50%;
}
.flatpickr-calendar:before {
border-width: 5px;
margin: 0 -5px;
}
.flatpickr-calendar:after {
border-width: 4px;
margin: 0 -4px;
}
.flatpickr-calendar.arrowTop:before,
.flatpickr-calendar.arrowTop:after {
bottom: 100%;
}
.flatpickr-calendar.arrowTop:before {
border-bottom-color: #20222c;
}
.flatpickr-calendar.arrowTop:after {
border-bottom-color: #3f4458;
}
.flatpickr-calendar.arrowBottom:before,
.flatpickr-calendar.arrowBottom:after {
top: 100%;
}
.flatpickr-calendar.arrowBottom:before {
border-top-color: #20222c;
}
.flatpickr-calendar.arrowBottom:after {
border-top-color: #3f4458;
}
.flatpickr-calendar:focus {
outline: 0;
}
.flatpickr-wrapper {
position: relative;
display: inline-block;
}
.flatpickr-months {
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
}
.flatpickr-months .flatpickr-month {
background: #3f4458;
color: #fff;
fill: #fff;
height: 34px;
line-height: 1;
text-align: center;
position: relative;
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
overflow: hidden;
-webkit-box-flex: 1;
-webkit-flex: 1;
-ms-flex: 1;
flex: 1;
}
.flatpickr-months .flatpickr-prev-month,
.flatpickr-months .flatpickr-next-month {
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
text-decoration: none;
cursor: pointer;
position: absolute;
top: 0;
height: 34px;
padding: 10px;
z-index: 3;
color: #fff;
fill: #fff;
}
.flatpickr-months .flatpickr-prev-month.flatpickr-disabled,
.flatpickr-months .flatpickr-next-month.flatpickr-disabled {
display: none;
}
.flatpickr-months .flatpickr-prev-month i,
.flatpickr-months .flatpickr-next-month i {
position: relative;
}
.flatpickr-months .flatpickr-prev-month.flatpickr-prev-month,
.flatpickr-months .flatpickr-next-month.flatpickr-prev-month {
/*
/*rtl:begin:ignore*/
/*
*/
left: 0;
/*
/*rtl:end:ignore*/
/*
*/
}
/*
/*rtl:begin:ignore*/
/*
/*rtl:end:ignore*/
.flatpickr-months .flatpickr-prev-month.flatpickr-next-month,
.flatpickr-months .flatpickr-next-month.flatpickr-next-month {
/*
/*rtl:begin:ignore*/
/*
*/
right: 0;
/*
/*rtl:end:ignore*/
/*
*/
}
/*
/*rtl:begin:ignore*/
/*
/*rtl:end:ignore*/
.flatpickr-months .flatpickr-prev-month:hover,
.flatpickr-months .flatpickr-next-month:hover {
color: #eee;
}
.flatpickr-months .flatpickr-prev-month:hover svg,
.flatpickr-months .flatpickr-next-month:hover svg {
fill: #f64747;
}
.flatpickr-months .flatpickr-prev-month svg,
.flatpickr-months .flatpickr-next-month svg {
width: 14px;
height: 14px;
}
.flatpickr-months .flatpickr-prev-month svg path,
.flatpickr-months .flatpickr-next-month svg path {
-webkit-transition: fill 0.1s;
transition: fill 0.1s;
fill: inherit;
}
.numInputWrapper {
position: relative;
height: auto;
}
.numInputWrapper input,
.numInputWrapper span {
display: inline-block;
}
.numInputWrapper input {
width: 100%;
}
.numInputWrapper input::-ms-clear {
display: none;
}
.numInputWrapper input::-webkit-outer-spin-button,
.numInputWrapper input::-webkit-inner-spin-button {
margin: 0;
-webkit-appearance: none;
}
.numInputWrapper span {
position: absolute;
right: 0;
width: 14px;
padding: 0 4px 0 2px;
height: 50%;
line-height: 50%;
opacity: 0;
cursor: pointer;
border: 1px solid rgba(255,255,255,0.15);
-webkit-box-sizing: border-box;
box-sizing: border-box;
}
.numInputWrapper span:hover {
background: rgba(192,187,167,0.1);
}
.numInputWrapper span:active {
background: rgba(192,187,167,0.2);
}
.numInputWrapper span:after {
display: block;
content: "";
position: absolute;
}
.numInputWrapper span.arrowUp {
top: 0;
border-bottom: 0;
}
.numInputWrapper span.arrowUp:after {
border-left: 4px solid transparent;
border-right: 4px solid transparent;
border-bottom: 4px solid rgba(255,255,255,0.6);
top: 26%;
}
.numInputWrapper span.arrowDown {
top: 50%;
}
.numInputWrapper span.arrowDown:after {
border-left: 4px solid transparent;
border-right: 4px solid transparent;
border-top: 4px solid rgba(255,255,255,0.6);
top: 40%;
}
.numInputWrapper span svg {
width: inherit;
height: auto;
}
.numInputWrapper span svg path {
fill: rgba(255,255,255,0.5);
}
.numInputWrapper:hover {
background: rgba(192,187,167,0.05);
}
.numInputWrapper:hover span {
opacity: 1;
}
.flatpickr-current-month {
font-size: 135%;
line-height: inherit;
font-weight: 300;
color: inherit;
position: absolute;
width: 75%;
left: 12.5%;
padding: 7.48px 0 0 0;
line-height: 1;
height: 34px;
display: inline-block;
text-align: center;
-webkit-transform: translate3d(0px, 0px, 0px);
transform: translate3d(0px, 0px, 0px);
}
.flatpickr-current-month span.cur-month {
font-family: inherit;
font-weight: 700;
color: inherit;
display: inline-block;
margin-left: 0.5ch;
padding: 0;
}
.flatpickr-current-month span.cur-month:hover {
background: rgba(192,187,167,0.05);
}
.flatpickr-current-month .numInputWrapper {
width: 6ch;
width: 7ch\0;
display: inline-block;
}
.flatpickr-current-month .numInputWrapper span.arrowUp:after {
border-bottom-color: #fff;
}
.flatpickr-current-month .numInputWrapper span.arrowDown:after {
border-top-color: #fff;
}
.flatpickr-current-month input.cur-year {
background: transparent;
-webkit-box-sizing: border-box;
box-sizing: border-box;
color: inherit;
cursor: text;
padding: 0 0 0 0.5ch;
margin: 0;
display: inline-block;
font-size: inherit;
font-family: inherit;
font-weight: 300;
line-height: inherit;
height: auto;
border: 0;
border-radius: 0;
vertical-align: initial;
-webkit-appearance: textfield;
-moz-appearance: textfield;
appearance: textfield;
}
.flatpickr-current-month input.cur-year:focus {
outline: 0;
}
.flatpickr-current-month input.cur-year[disabled],
.flatpickr-current-month input.cur-year[disabled]:hover {
font-size: 100%;
color: rgba(255,255,255,0.5);
background: transparent;
pointer-events: none;
}
.flatpickr-current-month .flatpickr-monthDropdown-months {
appearance: menulist;
background: #3f4458;
border: none;
border-radius: 0;
box-sizing: border-box;
color: inherit;
cursor: pointer;
font-size: inherit;
font-family: inherit;
font-weight: 300;
height: auto;
line-height: inherit;
margin: -1px 0 0 0;
outline: none;
padding: 0 0 0 0.5ch;
position: relative;
vertical-align: initial;
-webkit-box-sizing: border-box;
-webkit-appearance: menulist;
-moz-appearance: menulist;
width: auto;
}
.flatpickr-current-month .flatpickr-monthDropdown-months:focus,
.flatpickr-current-month .flatpickr-monthDropdown-months:active {
outline: none;
}
.flatpickr-current-month .flatpickr-monthDropdown-months:hover {
background: rgba(192,187,167,0.05);
}
.flatpickr-current-month .flatpickr-monthDropdown-months .flatpickr-monthDropdown-month {
background-color: #3f4458;
outline: none;
padding: 0;
}
.flatpickr-weekdays {
background: transparent;
text-align: center;
overflow: hidden;
width: 100%;
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
-webkit-box-align: center;
-webkit-align-items: center;
-ms-flex-align: center;
align-items: center;
height: 28px;
}
.flatpickr-weekdays .flatpickr-weekdaycontainer {
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
-webkit-box-flex: 1;
-webkit-flex: 1;
-ms-flex: 1;
flex: 1;
}
span.flatpickr-weekday {
cursor: default;
font-size: 90%;
background: #3f4458;
color: #fff;
line-height: 1;
margin: 0;
text-align: center;
display: block;
-webkit-box-flex: 1;
-webkit-flex: 1;
-ms-flex: 1;
flex: 1;
font-weight: bolder;
}
.dayContainer,
.flatpickr-weeks {
padding: 1px 0 0 0;
}
.flatpickr-days {
position: relative;
overflow: hidden;
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
-webkit-box-align: start;
-webkit-align-items: flex-start;
-ms-flex-align: start;
align-items: flex-start;
width: 307.875px;
}
.flatpickr-days:focus {
outline: 0;
}
.dayContainer {
padding: 0;
outline: 0;
text-align: left;
width: 307.875px;
min-width: 307.875px;
max-width: 307.875px;
-webkit-box-sizing: border-box;
box-sizing: border-box;
display: inline-block;
display: -ms-flexbox;
display: -webkit-box;
display: -webkit-flex;
display: flex;
-webkit-flex-wrap: wrap;
flex-wrap: wrap;
-ms-flex-wrap: wrap;
-ms-flex-pack: justify;
-webkit-justify-content: space-around;
justify-content: space-around;
-webkit-transform: translate3d(0px, 0px, 0px);
transform: translate3d(0px, 0px, 0px);
opacity: 1;
}
.dayContainer + .dayContainer {
-webkit-box-shadow: -1px 0 0 #20222c;
box-shadow: -1px 0 0 #20222c;
}
.flatpickr-day {
background: none;
border: 1px solid transparent;
border-radius: 150px;
-webkit-box-sizing: border-box;
box-sizing: border-box;
color: rgba(255,255,255,0.95);
cursor: pointer;
font-weight: 400;
width: 14.2857143%;
-webkit-flex-basis: 14.2857143%;
-ms-flex-preferred-size: 14.2857143%;
flex-basis: 14.2857143%;
max-width: 39px;
height: 39px;
line-height: 39px;
margin: 0;
display: inline-block;
position: relative;
-webkit-box-pack: center;
-webkit-justify-content: center;
-ms-flex-pack: center;
justify-content: center;
text-align: center;
}
.flatpickr-day.inRange,
.flatpickr-day.prevMonthDay.inRange,
.flatpickr-day.nextMonthDay.inRange,
.flatpickr-day.today.inRange,
.flatpickr-day.prevMonthDay.today.inRange,
.flatpickr-day.nextMonthDay.today.inRange,
.flatpickr-day:hover,
.flatpickr-day.prevMonthDay:hover,
.flatpickr-day.nextMonthDay:hover,
.flatpickr-day:focus,
.flatpickr-day.prevMonthDay:focus,
.flatpickr-day.nextMonthDay:focus {
cursor: pointer;
outline: 0;
background: #646c8c;
border-color: #646c8c;
}
.flatpickr-day.today {
border-color: #eee;
}
.flatpickr-day.today:hover,
.flatpickr-day.today:focus {
border-color: #eee;
background: #eee;
color: #3f4458;
}
.flatpickr-day.selected,
.flatpickr-day.startRange,
.flatpickr-day.endRange,
.flatpickr-day.selected.inRange,
.flatpickr-day.startRange.inRange,
.flatpickr-day.endRange.inRange,
.flatpickr-day.selected:focus,
.flatpickr-day.startRange:focus,
.flatpickr-day.endRange:focus,
.flatpickr-day.selected:hover,
.flatpickr-day.startRange:hover,
.flatpickr-day.endRange:hover,
.flatpickr-day.selected.prevMonthDay,
.flatpickr-day.startRange.prevMonthDay,
.flatpickr-day.endRange.prevMonthDay,
.flatpickr-day.selected.nextMonthDay,
.flatpickr-day.startRange.nextMonthDay,
.flatpickr-day.endRange.nextMonthDay {
background: #80cbc4;
-webkit-box-shadow: none;
box-shadow: none;
color: #fff;
border-color: #80cbc4;
}
.flatpickr-day.selected.startRange,
.flatpickr-day.startRange.startRange,
.flatpickr-day.endRange.startRange {
border-radius: 50px 0 0 50px;
}
.flatpickr-day.selected.endRange,
.flatpickr-day.startRange.endRange,
.flatpickr-day.endRange.endRange {
border-radius: 0 50px 50px 0;
}
.flatpickr-day.selected.startRange + .endRange:not(:nth-child(7n+1)),
.flatpickr-day.startRange.startRange + .endRange:not(:nth-child(7n+1)),
.flatpickr-day.endRange.startRange + .endRange:not(:nth-child(7n+1)) {
-webkit-box-shadow: -10px 0 0 #80cbc4;
box-shadow: -10px 0 0 #80cbc4;
}
.flatpickr-day.selected.startRange.endRange,
.flatpickr-day.startRange.startRange.endRange,
.flatpickr-day.endRange.startRange.endRange {
border-radius: 50px;
}
.flatpickr-day.inRange {
border-radius: 0;
-webkit-box-shadow: -5px 0 0 #646c8c, 5px 0 0 #646c8c;
box-shadow: -5px 0 0 #646c8c, 5px 0 0 #646c8c;
}
.flatpickr-day.flatpickr-disabled,
.flatpickr-day.flatpickr-disabled:hover,
.flatpickr-day.prevMonthDay,
.flatpickr-day.nextMonthDay,
.flatpickr-day.notAllowed,
.flatpickr-day.notAllowed.prevMonthDay,
.flatpickr-day.notAllowed.nextMonthDay {
color: rgba(255,255,255,0.3);
background: transparent;
border-color: transparent;
cursor: default;
}
.flatpickr-day.flatpickr-disabled,
.flatpickr-day.flatpickr-disabled:hover {
cursor: not-allowed;
color: rgba(255,255,255,0.1);
}
.flatpickr-day.week.selected {
border-radius: 0;
-webkit-box-shadow: -5px 0 0 #80cbc4, 5px 0 0 #80cbc4;
box-shadow: -5px 0 0 #80cbc4, 5px 0 0 #80cbc4;
}
.flatpickr-day.hidden {
visibility: hidden;
}
.rangeMode .flatpickr-day {
margin-top: 1px;
}
.flatpickr-weekwrapper {
float: left;
}
.flatpickr-weekwrapper .flatpickr-weeks {
padding: 0 12px;
-webkit-box-shadow: 1px 0 0 #20222c;
box-shadow: 1px 0 0 #20222c;
}
.flatpickr-weekwrapper .flatpickr-weekday {
float: none;
width: 100%;
line-height: 28px;
}
.flatpickr-weekwrapper span.flatpickr-day,
.flatpickr-weekwrapper span.flatpickr-day:hover {
display: block;
width: 100%;
max-width: none;
color: rgba(255,255,255,0.3);
background: transparent;
cursor: default;
border: none;
}
.flatpickr-innerContainer {
display: block;
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
-webkit-box-sizing: border-box;
box-sizing: border-box;
overflow: hidden;
}
.flatpickr-rContainer {
display: inline-block;
padding: 0;
-webkit-box-sizing: border-box;
box-sizing: border-box;
}
.flatpickr-time {
text-align: center;
outline: 0;
display: block;
height: 0;
line-height: 40px;
max-height: 40px;
-webkit-box-sizing: border-box;
box-sizing: border-box;
overflow: hidden;
display: -webkit-box;
display: -webkit-flex;
display: -ms-flexbox;
display: flex;
}
.flatpickr-time:after {
content: "";
display: table;
clear: both;
}
.flatpickr-time .numInputWrapper {
-webkit-box-flex: 1;
-webkit-flex: 1;
-ms-flex: 1;
flex: 1;
width: 40%;
height: 40px;
float: left;
}
.flatpickr-time .numInputWrapper span.arrowUp:after {
border-bottom-color: rgba(255,255,255,0.95);
}
.flatpickr-time .numInputWrapper span.arrowDown:after {
border-top-color: rgba(255,255,255,0.95);
}
.flatpickr-time.hasSeconds .numInputWrapper {
width: 26%;
}
.flatpickr-time.time24hr .numInputWrapper {
width: 49%;
}
.flatpickr-time input {
background: transparent;
-webkit-box-shadow: none;
box-shadow: none;
border: 0;
border-radius: 0;
text-align: center;
margin: 0;
padding: 0;
height: inherit;
line-height: inherit;
color: rgba(255,255,255,0.95);
font-size: 14px;
position: relative;
-webkit-box-sizing: border-box;
box-sizing: border-box;
-webkit-appearance: textfield;
-moz-appearance: textfield;
appearance: textfield;
}
.flatpickr-time input.flatpickr-hour {
font-weight: bold;
}
.flatpickr-time input.flatpickr-minute,
.flatpickr-time input.flatpickr-second {
font-weight: 400;
}
.flatpickr-time input:focus {
outline: 0;
border: 0;
}
.flatpickr-time .flatpickr-time-separator,
.flatpickr-time .flatpickr-am-pm {
height: inherit;
float: left;
line-height: inherit;
color: rgba(255,255,255,0.95);
font-weight: bold;
width: 2%;
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
-webkit-align-self: center;
-ms-flex-item-align: center;
align-self: center;
}
.flatpickr-time .flatpickr-am-pm {
outline: 0;
width: 18%;
cursor: pointer;
text-align: center;
font-weight: 400;
}
.flatpickr-time input:hover,
.flatpickr-time .flatpickr-am-pm:hover,
.flatpickr-time input:focus,
.flatpickr-time .flatpickr-am-pm:focus {
background: #6a7395;
}
.flatpickr-input[readonly] {
cursor: pointer;
}
@-webkit-keyframes fpFadeInDown {
from {
opacity: 0;
-webkit-transform: translate3d(0, -20px, 0);
transform: translate3d(0, -20px, 0);
}
to {
opacity: 1;
-webkit-transform: translate3d(0, 0, 0);
transform: translate3d(0, 0, 0);
}
}
@keyframes fpFadeInDown {
from {
opacity: 0;
-webkit-transform: translate3d(0, -20px, 0);
transform: translate3d(0, -20px, 0);
}
to {
opacity: 1;
-webkit-transform: translate3d(0, 0, 0);
transform: translate3d(0, 0, 0);
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,258 @@
package handlers
import (
"context"
"net/http"
"strconv"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/validation"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"git.haelnorr.com/h/timefmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminAuditLogsPage renders the full admin dashboard page with audit logs section (GET request)
func AdminAuditLogsPage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var logs *db.List[db.AuditLog]
var users []*db.User
var actions []string
var resourceTypes []string
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
// Get filters from query
filters, ok := getAuditFiltersFromQuery(s, w, r)
if !ok {
return false, nil
}
// Get audit logs
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, filters)
if err != nil {
return false, errors.Wrap(err, "db.GetAuditLogs")
}
// Get all users for filter dropdown
usersList, err := db.GetUsers(ctx, tx, nil)
if err != nil {
return false, errors.Wrap(err, "db.GetUsers")
}
users = usersList.Items
// Get unique actions
actions, err = db.GetUniqueActions(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.GetUniqueActions")
}
// Get unique resource types
resourceTypes, err = db.GetUniqueResourceTypes(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.GetUniqueResourceTypes")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.AuditLogsPage(logs, users, actions, resourceTypes), s, r, w)
})
}
// AdminAuditLogsList shows the full audit logs list with filters (POST request for HTMX)
func AdminAuditLogsList(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var logs *db.List[db.AuditLog]
var users []*db.User
var actions []string
var resourceTypes []string
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
// Get filters from form
filters, ok := getAuditFiltersFromForm(s, w, r)
if !ok {
return false, nil
}
// Get audit logs
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, filters)
if err != nil {
return false, errors.Wrap(err, "db.GetAuditLogs")
}
// Get all users for filter dropdown
usersList, err := db.GetUsers(ctx, tx, nil)
if err != nil {
return false, errors.Wrap(err, "db.GetUsers")
}
users = usersList.Items
// Get unique actions
actions, err = db.GetUniqueActions(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.GetUniqueActions")
}
// Get unique resource types
resourceTypes, err = db.GetUniqueResourceTypes(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.GetUniqueResourceTypes")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.AuditLogsList(logs, users, actions, resourceTypes), s, r, w)
})
}
// AdminAuditLogsFilter returns only the results container (table + pagination) for HTMX updates
func AdminAuditLogsFilter(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var logs *db.List[db.AuditLog]
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
// Get filters from form
filters, ok := getAuditFiltersFromForm(s, w, r)
if !ok {
return false, nil
}
// Get audit logs
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, filters)
if err != nil {
return false, errors.Wrap(err, "db.GetAuditLogs")
}
return true, nil
}); !ok {
return
}
// Return only the results container, not the full page with filters
renderSafely(adminview.AuditLogsResults(logs), s, r, w)
})
}
// AdminAuditLogDetail shows details for a single audit log entry
func AdminAuditLogDetail(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get ID from path
idStr := r.PathValue("id")
if idStr == "" {
throw.BadRequest(s, w, r, "Missing audit log ID", nil)
return
}
id, err := strconv.Atoi(idStr)
if err != nil {
throw.BadRequest(s, w, r, "Invalid audit log ID", err)
return
}
var log *db.AuditLog
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
log, err = db.GetAuditLogByID(ctx, tx, id)
if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetAuditLogByID")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.AuditLogDetail(log), s, r, w)
})
}
// getAuditFiltersFromQuery extracts audit log filters from query string
func getAuditFiltersFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) (*db.AuditLogFilter, bool) {
g := validation.NewQueryGetter(r)
filters, ok := buildAuditFilters(g, s, w, r)
return filters, ok
}
// getAuditFiltersFromForm extracts audit log filters from form data
func getAuditFiltersFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) (*db.AuditLogFilter, bool) {
g, ok := validation.ParseFormOrError(s, w, r)
if !ok {
return nil, false
}
return buildAuditFilters(g, s, w, r)
}
// buildAuditFilters builds audit log filters from a validation.Getter
func buildAuditFilters(g validation.Getter, s *hws.Server, w http.ResponseWriter, r *http.Request) (*db.AuditLogFilter, bool) {
filters := db.NewAuditLogFilter()
userIDs := g.IntList("user_id").Values()
actions := g.StringList("action").Values()
resourceTypes := g.StringList("resource_type").Values()
results := g.StringList("result").Values()
format := timefmt.NewBuilder().DayNumeric2().Slash().
MonthNumeric2().Slash().Year4().Build()
startDate := g.Time("start_date", format).Optional().Value
endDate := g.Time("end_date", format).Optional().Value
if !g.ValidateAndError(s, w, r) {
return nil, false
}
if len(userIDs) > 0 {
filters.UserIDs(userIDs)
}
if len(actions) > 0 {
filters.Actions(actions)
}
if len(resourceTypes) > 0 {
filters.ResourceTypes(resourceTypes)
}
if len(results) > 0 {
filters.Results(results)
}
if !startDate.IsZero() {
filters.DateRange(startDate.Unix(), 0)
}
if !endDate.IsZero() {
endOfDay := endDate.Add(23*time.Hour + 59*time.Minute + 59*time.Second)
filters.DateRange(0, endOfDay.Unix())
}
return filters, true
}

View File

@@ -0,0 +1,30 @@
package handlers
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminDashboard renders the full admin dashboard page (defaults to users section)
func AdminDashboard(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var users *db.List[db.User]
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
users, err = db.GetUsersWithRoles(ctx, tx, nil)
if err != nil {
return false, errors.Wrap(err, "db.GetUsersWithRoles")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.DashboardPage(users), s, r, w)
})
}

View File

@@ -0,0 +1,79 @@
package handlers
import (
"context"
"net/http"
"strconv"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminPreviewRoleStart starts preview mode for a specific role
func AdminPreviewRoleStart(s *hws.Server, conn *db.DB, ssl bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get role ID from URL
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
throw.BadRequest(s, w, r, "Invalid role ID", err)
return
}
// Verify role exists and is not admin
var role *db.Role
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, "Role not found")
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
// Cannot preview admin role
if role.Name == roles.Admin {
throw.BadRequest(s, w, r, "Cannot preview admin role", nil)
return false, nil
}
return true, nil
}); !ok {
return
}
// Set preview role cookie
rbac.SetPreviewRoleCookie(w, roleID, ssl)
// Redirect to home page
http.Redirect(w, r, "/", http.StatusSeeOther)
})
}
// AdminPreviewRoleStop stops preview mode and returns to normal view
func AdminPreviewRoleStop(s *hws.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Clear preview role cookie
rbac.ClearPreviewRoleCookie(w)
// Check if we should stay on current page or redirect to admin
stay := r.URL.Query().Get("stay")
if stay == "true" {
// Get referer to redirect back to current page
referer := r.Header.Get("Referer")
if referer == "" {
referer = "/"
}
http.Redirect(w, r, referer, http.StatusSeeOther)
} else {
// Redirect to admin roles page
http.Redirect(w, r, "/admin/roles", http.StatusSeeOther)
}
})
}

View File

@@ -0,0 +1,341 @@
package handlers
import (
"context"
"net/http"
"sort"
"strconv"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/internal/validation"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminRoles renders the full admin dashboard page with roles section
func AdminRoles(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var rolesList *db.List[db.Role]
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetRoles")
}
return true, nil
}); !ok {
return
}
if r.Method == "GET" {
renderSafely(adminview.RolesPage(rolesList), s, r, w)
} else {
renderSafely(adminview.RolesList(rolesList), s, r, w)
}
})
}
// AdminRoleCreateForm shows the create role form modal
func AdminRoleCreateForm(s *hws.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
renderSafely(adminview.RoleCreateForm(), s, r, w)
})
}
// AdminRoleCreate creates a new role
func AdminRoleCreate(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
name := getter.String("name").Required().Value
displayName := getter.String("display_name").Required().Value
description := getter.String("description").Value
if !getter.ValidateAndNotify(s, w, r) {
return
}
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var rolesList *db.List[db.Role]
var newRole *db.Role
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
newRole = &db.Role{
Name: roles.Role(name),
DisplayName: displayName,
Description: description,
IsSystem: false,
CreatedAt: time.Now().Unix(),
}
err := db.CreateRole(ctx, tx, newRole, db.NewAudit(r, nil))
if err != nil {
return false, errors.Wrap(err, "db.CreateRole")
}
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetRoles")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.RolesList(rolesList), s, r, w)
})
}
// AdminRoleManage shows the role management modal with details and actions
func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
respond.BadRequest(w, err)
return
}
var role *db.Role
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.RoleManageModal(role), s, r, w)
})
}
// AdminRoleDeleteConfirm shows the delete confirmation dialog
func AdminRoleDeleteConfirm(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
var role *db.Role
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.ConfirmDeleteRole(roleID, role.DisplayName), s, r, w)
})
}
// AdminRoleDelete deletes a role
func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
respond.BadRequest(w, err)
return
}
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var rolesList *db.List[db.Role]
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// First check if role exists and get its details
role, err := db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
// Check if it's a system role
if role.IsSystem {
return false, errors.New("cannot delete system roles")
}
// Delete the role with audit logging
err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil))
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.DeleteRole")
}
// Reload roles
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetRoles")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.RolesList(rolesList), s, r, w)
})
}
// AdminRolePermissionsModal shows the permissions management modal for a role
func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
respond.BadRequest(w, err)
return
}
var role *db.Role
var allPermissions []*db.Permission
var groupedPerms []adminview.PermissionsByResource
var rolePermIDs map[int]bool
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// Load role with permissions
var err error
role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
// Load all permissions
allPermissions, err = db.ListAllPermissions(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.ListAllPermissions")
}
return true, nil
}); !ok {
return
}
// Group permissions by resource
permsByResource := make(map[string][]*db.Permission)
for _, perm := range allPermissions {
permsByResource[perm.Resource] = append(permsByResource[perm.Resource], perm)
}
// Convert to sorted slice
for resource, perms := range permsByResource {
groupedPerms = append(groupedPerms, adminview.PermissionsByResource{
Resource: resource,
Permissions: perms,
})
}
sort.Slice(groupedPerms, func(i, j int) bool {
return groupedPerms[i].Resource < groupedPerms[j].Resource
})
// Create map of current role permissions for checkbox state
rolePermIDs = make(map[int]bool)
for _, perm := range role.Permissions {
rolePermIDs[perm.ID] = true
}
renderSafely(adminview.RolePermissionsModal(role, groupedPerms, rolePermIDs), s, r, w)
})
}
// AdminRolePermissionsUpdate updates the permissions for a role
func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr)
if err != nil {
respond.BadRequest(w, err)
return
}
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
// Get selected permission IDs from form
permissionIDs := getter.IntList("permission_ids").Values()
if !getter.ValidateAndNotify(s, w, r) {
return
}
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var rolesList *db.List[db.Role]
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
role, err := db.GetRoleByID(ctx, tx, roleID)
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID")
}
err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil))
if err != nil {
return false, errors.Wrap(err, "role.UpdatePermissions")
}
// Reload roles
rolesList, err = db.GetRoles(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetRoles")
}
return true, nil
}); !ok {
return
}
renderSafely(adminview.RolesList(rolesList), s, r, w)
})
}

View File

@@ -0,0 +1,39 @@
package handlers
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// AdminUsersPage renders the full admin dashboard page with users section
func AdminUsersPage(s *hws.Server, conn *db.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var users *db.List[db.User]
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
users, err = db.GetUsersWithRoles(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetUsersWithRoles")
}
return true, nil
}); !ok {
return
}
if r.Method == "GET" {
renderSafely(adminview.DashboardPage(users), s, r, w)
} else {
renderSafely(adminview.UserList(users), s, r, w)
}
})
}

View File

@@ -0,0 +1,53 @@
package handlers
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// shouldGrantAdmin checks if user's Discord ID is in admin list
func shouldGrantAdmin(user *db.User, cfg *rbac.Config) bool {
if cfg == nil || user == nil {
return false
}
if user.DiscordID == cfg.AdminDiscordID {
return true
}
return false
}
// ensureUserHasAdminRole grants admin role if not already granted
func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error {
if user == nil {
return errors.New("user cannot be nil")
}
// Check if user already has admin role
hasAdmin, err := user.HasRole(ctx, tx, roles.Admin)
if err != nil {
return errors.Wrap(err, "user.HasRole")
}
if hasAdmin {
return nil // Already admin
}
// Get admin role
adminRole, err := db.GetRoleByName(ctx, tx, roles.Admin)
if err != nil {
return errors.Wrap(err, "db.GetRoleByName")
}
// Grant admin role
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)
if err != nil {
return errors.Wrap(err, "db.AssignRole")
}
return nil
}

View File

@@ -3,7 +3,6 @@ package handlers
import ( import (
"context" "context"
"net/http" "net/http"
"time"
"git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
@@ -15,13 +14,15 @@ import (
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/validation"
"git.haelnorr.com/h/oslstats/pkg/oauth" "git.haelnorr.com/h/oslstats/pkg/oauth"
) )
func Callback( func Callback(
server *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
cfg *config.Config, cfg *config.Config,
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
@@ -31,45 +32,36 @@ func Callback(
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
if exceeded { if exceeded {
err := errors.Errorf( err := track.Error(attempts)
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
store.ClearRedirectTrack(r, "/callback") store.ClearRedirectTrack(r, "/callback")
throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err)
throwError(
server,
w,
r,
http.StatusBadRequest,
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
err,
"warn",
)
return return
} }
getter := validation.NewQueryGetter(r)
state := r.URL.Query().Get("state") state := getter.String("state").Required().Value
code := r.URL.Query().Get("code") code := getter.String("code").Required().Value
if state == "" && code == "" { if !getter.Validate() {
http.Redirect(w, r, "/", http.StatusBadRequest) store.ClearRedirectTrack(r, "/callback")
apiErr := getter.String("error").Value
errDesc := getter.String("error_description").Value
if apiErr == "access_denied" {
throw.Unauthorized(s, w, r, "OAuth login failed or cancelled", errors.New(errDesc))
return
}
throw.BadRequest(s, w, r, "OAuth login failed", errors.New("state or code parameters missing"))
return return
} }
data, err := verifyState(cfg.OAuth, w, r, state) data, err := verifyState(cfg.OAuth, w, r, state)
if err != nil { if err != nil {
store.ClearRedirectTrack(r, "/callback")
if vsErr, ok := err.(*verifyStateError); ok { if vsErr, ok := err.(*verifyStateError); ok {
if vsErr.IsCookieError() { if vsErr.IsCookieError() {
throwUnauthorized(server, w, r, "OAuth session not found or expired", err) throw.Unauthorized(s, w, r, "OAuth session not found or expired", err)
} else { } else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
} }
} else { } else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
} }
return return
} }
@@ -77,20 +69,17 @@ func Callback(
switch data { switch data {
case "login": case "login":
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) var redirect func()
defer cancel() if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
tx, err := conn.BeginTx(ctx, nil) redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err) throw.InternalServiceError(s, w, r, "OAuth login failed", err)
return false, nil
}
return true, nil
}); !ok {
return return
} }
defer tx.Rollback()
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil {
throwInternalServiceError(server, w, r, "OAuth login failed", err)
return
}
tx.Commit()
redirect() redirect()
return return
} }
@@ -169,7 +158,7 @@ func login(
} }
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID) user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
if err != nil { if err != nil && !db.IsBadRequest(err) {
return nil, errors.Wrap(err, "db.GetUserByDiscordID") return nil, errors.Wrap(err, "db.GetUserByDiscordID")
} }
var redirect string var redirect string
@@ -193,6 +182,15 @@ func login(
if err != nil { if err != nil {
return nil, errors.Wrap(err, "user.UpdateDiscordToken") return nil, errors.Wrap(err, "user.UpdateDiscordToken")
} }
// Check if user should be granted admin role (environment-based)
if shouldGrantAdmin(user, cfg.RBAC) {
err := ensureUserHasAdminRole(ctx, tx, user)
if err != nil {
return nil, errors.Wrap(err, "ensureUserHasAdminRole")
}
}
err := auth.Login(w, r, user, true) err := auth.Login(w, r, user, true)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "auth.Login") return nil, errors.Wrap(err, "auth.Login")

2
internal/handlers/doc.go Normal file
View File

@@ -0,0 +1,2 @@
// Package handlers contains all the functions for handling http requests and serving content
package handlers

View File

@@ -4,28 +4,10 @@ import (
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/view/page" "git.haelnorr.com/h/oslstats/internal/notify"
baseview "git.haelnorr.com/h/oslstats/internal/view/baseview"
) )
// func ErrorPage(
// error hws.HWSError,
// ) (hws.ErrorPage, error) {
// messages := map[int]string{
// 400: "The request you made was malformed or unexpected.",
// 401: "You need to login to view this page.",
// 403: "You do not have permission to view this page.",
// 404: "The page or resource you have requested does not exist.",
// 500: `An error occured on the server. Please try again, and if this
// continues to happen contact an administrator.`,
// 503: "The server is currently down for maintenance and should be back soon. =)",
// }
// msg, exists := messages[error.StatusCode]
// if !exists {
// return nil, errors.New("No valid message for the given code")
// }
// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil
// }
func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) { func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
// Determine if this status code should show technical details // Determine if this status code should show technical details
showDetails := shouldShowDetails(hwsError.StatusCode) showDetails := shouldShowDetails(hwsError.StatusCode)
@@ -40,12 +22,12 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
// Get technical details if applicable // Get technical details if applicable
var details string var details string
if showDetails && hwsError.Error != nil { if showDetails && hwsError.Error != nil {
details = hwsError.Error.Error() details = notify.FormatErrorDetails(hwsError.Error)
} }
// Render appropriate template // Render appropriate template
if details != "" { if details != "" {
return page.ErrorWithDetails( return baseview.ErrorPageWithDetails(
hwsError.StatusCode, hwsError.StatusCode,
http.StatusText(hwsError.StatusCode), http.StatusText(hwsError.StatusCode),
message, message,
@@ -53,7 +35,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
), nil ), nil
} }
return page.Error( return baseview.ErrorPage(
hwsError.StatusCode, hwsError.StatusCode,
http.StatusText(hwsError.StatusCode), http.StatusText(hwsError.StatusCode),
message, message,
@@ -63,7 +45,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
// shouldShowDetails determines if a status code should display technical details // shouldShowDetails determines if a status code should display technical details
func shouldShowDetails(statusCode int) bool { func shouldShowDetails(statusCode int) bool {
switch statusCode { switch statusCode {
case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable case 400, 418, 500, 503: // Bad Request, Internal Server Error, Service Unavailable
return true return true
case 401, 403, 404: // Unauthorized, Forbidden, Not Found case 401, 403, 404: // Unauthorized, Forbidden, Not Found
return false return false
@@ -80,6 +62,7 @@ func getDefaultMessage(statusCode int) string {
401: "You need to login to view this page.", 401: "You need to login to view this page.",
403: "You do not have permission to view this page.", 403: "You do not have permission to view this page.",
404: "The page or resource you have requested does not exist.", 404: "The page or resource you have requested does not exist.",
418: "I'm a teapot!",
500: `An error occurred on the server. Please try again, and if this 500: `An error occurred on the server. Please try again, and if this
continues to happen contact an administrator.`, continues to happen contact an administrator.`,
503: "The server is currently down for maintenance and should be back soon. =)", 503: "The server is currently down for maintenance and should be back soon. =)",

View File

@@ -1,109 +1,45 @@
package handlers package handlers
import ( import (
"fmt" "encoding/json"
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/a-h/templ"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// throwError is a generic helper that all throw* functions use internally // parseErrorDetails extracts code and stacktrace from JSON Details field
func throwError( // Returns (code, stacktrace). If parsing fails, returns (500, original details string)
s *hws.Server, func parseErrorDetails(details string) (int, string) {
w http.ResponseWriter, if details == "" {
r *http.Request, return 500, ""
statusCode int, }
msg string,
err error, var errDetails notify.ErrorDetails
level string, err := json.Unmarshal([]byte(details), &errDetails)
) {
err = s.ThrowError(w, r, hws.HWSError{
StatusCode: statusCode,
Message: msg,
Error: err,
Level: hws.ErrorLevel(level),
RenderErrorPage: true, // throw* family always renders error pages
})
if err != nil { if err != nil {
s.ThrowFatal(w, err) // Not JSON or malformed - treat as plain stacktrace with default code
return 500, details
}
return errDetails.Code, errDetails.Stacktrace
}
func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) {
err := page.Render(r.Context(), w)
if err != nil {
throw.InternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
} }
} }
// throwInternalServiceError handles 500 errors (server failures) func logError(s *hws.Server, msg string, err error) {
func throwInternalServiceError( s.LogError(hws.HWSError{
s *hws.Server, Message: msg,
w http.ResponseWriter, Error: err,
r *http.Request, Level: hws.ErrorERROR,
msg string, StatusCode: http.StatusInternalServerError,
err error, })
) {
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
}
// throwBadRequest handles 400 errors (malformed requests)
func throwBadRequest(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusBadRequest, msg, err, "debug")
}
// throwForbidden handles 403 errors (normal permission denials)
func throwForbidden(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, "debug")
}
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
func throwForbiddenSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, "warn")
}
// throwUnauthorized handles 401 errors (not authenticated)
func throwUnauthorized(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug")
}
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
func throwUnauthorizedSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn")
}
// throwNotFound handles 404 errors
func throwNotFound(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
path string,
) {
msg := fmt.Sprintf("The requested resource was not found: %s", path)
err := errors.New("Resource not found")
throwError(s, w, r, http.StatusNotFound, msg, err, "debug")
} }

View File

@@ -0,0 +1,173 @@
package handlers
import (
"context"
"net/http"
"strconv"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation"
"git.haelnorr.com/h/oslstats/internal/view/seasonsview"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func GenerateFixtures(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
seasonShortName := getter.String("season_short_name").TrimSpace().Required().Value
leagueShortName := getter.String("league_short_name").TrimSpace().Required().Value
round := getter.Int("round").Required().Value
if !getter.ValidateAndNotify(s, w, r) {
return
}
var season *db.Season
var league *db.League
var fixtures []*db.Fixture
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
_, err := db.NewRound(ctx, tx, seasonShortName, leagueShortName, round, db.NewAudit(r, nil))
if err != nil {
if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.NewRound"))
return false, nil
}
return false, errors.Wrap(err, "db.NewRound")
}
season, league, fixtures, err = db.GetFixtures(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
return false, errors.Wrap(err, "db.GetFixtures")
}
return true, nil
}); !ok {
return
}
renderSafely(seasonsview.SeasonLeagueManageFixtures(season, league, fixtures), s, r, w)
})
}
func UpdateFixtures(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
seasonShortName := getter.String("season_short_name").TrimSpace().Required().Value
leagueShortName := getter.String("league_short_name").TrimSpace().Required().Value
allocations := getter.GetMaps("allocations")
if !getter.ValidateAndNotify(s, w, r) {
return
}
updates, err := mapUpdates(allocations)
if err != nil {
respond.BadRequest(w, errors.Wrap(err, "strconv.Atoi"))
return
}
var fixtures []*db.Fixture
if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
_, _, fixtures, err = db.GetFixtures(ctx, tx, seasonShortName, leagueShortName)
if err != nil {
if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.NewRound"))
return false, nil
}
return false, errors.Wrap(err, "db.GetFixtures")
}
var valid bool
fixtures, valid = updateFixtures(fixtures, updates)
if !valid {
notify.Warn(s, w, r, "Invalid game weeks", "A game week is missing or has no games", nil)
return false, nil
}
err = db.UpdateFixtureGameWeeks(ctx, tx, fixtures, db.NewAudit(r, nil))
if err != nil {
if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.UpdateFixtureGameWeeks"))
}
return false, errors.Wrap(err, "db.UpdateFixtureGameWeeks")
}
return true, nil
}) {
return
}
notify.Success(s, w, r, "Fixtures Updated", "Fixtures successfully updated", nil)
})
}
func DeleteFixture(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fixtureIDstr := r.PathValue("fixture_id")
fixtureID, err := strconv.Atoi(fixtureIDstr)
if err != nil {
respond.BadRequest(w, errors.Wrap(err, "strconv.Atoi"))
return
}
if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
err := db.DeleteFixture(ctx, tx, fixtureID, db.NewAudit(r, nil))
if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, errors.Wrap(err, "db.DeleteFixture"))
return false, nil
}
return false, errors.Wrap(err, "db.DeleteFixture")
}
return true, nil
}) {
return
}
})
}
func mapUpdates(allocations []map[string]string) (map[int]int, error) {
updates := map[int]int{}
for _, v := range allocations {
id, err := strconv.Atoi(v["id"])
if err != nil {
return nil, errors.Wrap(err, "strconv.Atoi")
}
gameWeek, err := strconv.Atoi(v["game_week"])
if err != nil {
return nil, errors.Wrap(err, "strconv.Atoi")
}
updates[id] = gameWeek
}
return updates, nil
}
func updateFixtures(fixtures []*db.Fixture, updates map[int]int) ([]*db.Fixture, bool) {
updated := []*db.Fixture{}
gameWeeks := map[int]int{}
for _, fixture := range fixtures {
if gameWeek, exists := updates[fixture.ID]; exists {
fixture.GameWeek = &gameWeek
updated = append(updated, fixture)
}
gameWeeks[*fixture.GameWeek]++
}
for i := range len(gameWeeks) {
count, exists := gameWeeks[i+1]
if !exists || count < 1 {
return nil, false
}
}
return updated, true
}

View File

@@ -3,20 +3,20 @@ package handlers
import ( import (
"net/http" "net/http"
"git.haelnorr.com/h/oslstats/internal/view/page"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/throw"
homeview "git.haelnorr.com/h/oslstats/internal/view/homeview"
) )
// Handles responses to the / path. Also serves a 404 Page for paths that // Index handles responses to the / path. Also serves a 404 Page for paths that
// don't have explicit handlers // don't have explicit handlers
func Index(server *hws.Server) http.Handler { func Index(s *hws.Server) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" { if r.URL.Path != "/" {
throwNotFound(server, w, r, r.URL.Path) throw.NotFound(s, w, r, r.URL.Path)
} }
page.Index().Render(r.Context(), w) renderSafely(homeview.IndexPage(), s, r, w)
}, },
) )
} }

View File

@@ -0,0 +1,52 @@
package handlers
import (
"context"
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// IsUnique creates a handler that checks field uniqueness
// Returns 200 OK if unique, 409 Conflict if not unique
func IsUnique(
s *hws.Server,
conn *db.DB,
model any,
field string,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, err := validation.ParseForm(r)
if err != nil {
respond.BadRequest(w, err)
return
}
value := getter.String(field).TrimSpace().Required().Value
if !getter.Validate() {
respond.BadRequest(w, err)
return
}
unique := false
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, model, field, value)
if err != nil {
return false, errors.Wrap(err, "db.IsUnique")
}
return true, nil
}); !ok {
return
}
if unique {
respond.OK(w)
} else {
err := fmt.Errorf("'%s' is not unique for field '%s'", value, field)
respond.Conflict(w, err)
}
})
}

View File

@@ -1,45 +0,0 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/store"
"github.com/uptrace/bun"
)
func IsUsernameUnique(
server *hws.Server,
conn *bun.DB,
cfg *config.Config,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database transaction failed", err)
return
}
defer tx.Rollback()
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
throwInternalServiceError(server, w, r, "Database query failed", err)
return
}
tx.Commit()
if !unique {
w.WriteHeader(http.StatusConflict)
} else {
w.WriteHeader(http.StatusOK)
}
},
)
}

View File

@@ -0,0 +1,34 @@
package handlers
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/view/leaguesview"
)
func LeaguesList(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var leagues []*db.League
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
leagues, err = db.GetLeagues(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "db.GetLeagues")
}
return true, nil
}); !ok {
return
}
renderSafely(leaguesview.ListPage(leagues), s, r, w)
})
}

View File

@@ -0,0 +1,85 @@
package handlers
import (
"context"
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation"
leaguesview "git.haelnorr.com/h/oslstats/internal/view/leaguesview"
)
func NewLeague(
s *hws.Server,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
renderSafely(leaguesview.NewPage(), s, r, w)
})
}
func NewLeagueSubmit(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
name := getter.String("name").
TrimSpace().Required().
MaxLength(50).MinLength(3).Value
shortname := getter.String("short_name").
TrimSpace().Required().
MaxLength(10).MinLength(2).Value
description := getter.String("description").
TrimSpace().MaxLength(500).Value
if !getter.ValidateAndNotify(s, w, r) {
return
}
nameUnique := false
shortNameUnique := false
var league *db.League
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
nameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "name", name)
if err != nil {
return false, errors.Wrap(err, "db.IsLeagueNameUnique")
}
shortNameUnique, err = db.IsUnique(ctx, tx, (*db.League)(nil), "short_name", shortname)
if err != nil {
return false, errors.Wrap(err, "db.IsLeagueShortNameUnique")
}
if !nameUnique || !shortNameUnique {
return true, nil
}
league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAudit(r, nil))
if err != nil {
return false, errors.Wrap(err, "db.NewLeague")
}
return true, nil
}); !ok {
return
}
if !nameUnique {
notify.Warn(s, w, r, "Duplicate Name", "This league name is already taken.", nil)
return
}
if !shortNameUnique {
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
return
}
respond.HXRedirect(w, "/leagues/%s", league.ShortName)
notify.SuccessWithDelay(s, w, r, "League Created", fmt.Sprintf("Successfully created league: %s", name), nil)
})
}

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
stderrors "errors"
"net/http" "net/http"
"git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/cookies"
@@ -8,51 +9,63 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/pkg/oauth" "git.haelnorr.com/h/oslstats/pkg/oauth"
) )
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { func Login(
s *hws.Server,
conn *db.DB,
cfg *config.Config,
st *store.Store,
discordAPI *discord.APIClient,
) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
errDB := conn.Ping()
_, errDisc := discordAPI.Ping()
err := stderrors.Join(errors.Wrap(errDB, "conn.Ping"), errors.Wrap(errDisc, "discordAPI.Ping"))
err = errors.Wrap(err, "login error")
if r.Method == "POST" {
if err != nil {
notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
respond.OK(w)
return
}
respond.HXRedirect(w, "/login")
return
}
if err != nil {
throw.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
return
}
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
if exceeded { if exceeded {
err := errors.Errorf( err = track.Error(attempts)
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
st.ClearRedirectTrack(r, "/login") st.ClearRedirectTrack(r, "/login")
throw.BadRequest(s, w, r, "Too many redirects. Please clear your browser cookies and try again", err)
throwError(
server,
w,
r,
http.StatusBadRequest,
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
err,
"warn",
)
return return
} }
state, uak, err := oauth.GenerateState(cfg.OAuth, "login") state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "Failed to generate state token", err) throw.InternalServiceError(s, w, r, "Failed to generate state token", err)
return return
} }
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
link, err := discordAPI.GetOAuthLink(state) link, err := discordAPI.GetOAuthLink(state)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) throw.InternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
return return
} }
st.ClearRedirectTrack(r, "/login") st.ClearRedirectTrack(r, "/login")

View File

@@ -3,57 +3,48 @@ package handlers
import ( import (
"context" "context"
"net/http" "net/http"
"time"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
func Logout( func Logout(
server *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer tx.Rollback()
user := db.CurrentUser(r.Context()) user := db.CurrentUser(r.Context())
if user == nil { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
// JIC - should be impossible to get here if route is protected by LoginReq token, err := user.DeleteDiscordTokens(ctx, tx)
w.Header().Set("HX-Redirect", "/") if err != nil {
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
}
if token != nil {
err = discordAPI.RevokeToken(token.Convert())
if err != nil {
throw.InternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return false, nil
}
}
err = auth.Logout(tx, w, r)
if err != nil {
throw.InternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
return false, nil
}
return true, nil
}); !ok {
return return
} }
token, err := user.DeleteDiscordTokens(ctx, tx) respond.HXRedirect(w, "/")
if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
return
}
err = discordAPI.RevokeToken(token.Convert())
if err != nil {
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return
}
err = auth.Logout(tx, w, r)
if err != nil {
throwInternalServiceError(server, w, r, "Logout failed", err)
return
}
tx.Commit()
w.Header().Set("HX-Redirect", "/")
}, },
) )
} }

View File

@@ -0,0 +1,112 @@
package handlers
import (
"context"
"net/http"
"strconv"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/notify"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/view/popup"
"github.com/coder/websocket"
"github.com/pkg/errors"
)
func NotificationWS(
s *hws.Server,
cfg *config.Config,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "websocket" {
throw.NotFound(s, w, r, r.URL.Path)
return
}
nc, err := setupClient(s, w, r)
if err != nil {
logError(s, "Failed to get notification client", errors.Wrap(err, "setupClient"))
return
}
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{cfg.HWSAuth.TrustedHost},
})
if err != nil {
logError(s, "Failed to open websocket", errors.Wrap(err, "websocket.Accept"))
return
}
ctx := ws.CloseRead(r.Context())
err = notifyLoop(ctx, nc, ws)
if err != nil {
logError(s, "Notification error", errors.Wrap(err, "notifyLoop"))
}
_ = ws.CloseNow()
},
)
}
func setupClient(s *hws.Server, w http.ResponseWriter, r *http.Request) (*hws.Client, error) {
user := db.CurrentUser(r.Context())
altID := ""
if user != nil {
altID = strconv.Itoa(user.ID)
}
subCookie, err := r.Cookie("ws_sub_id")
subID := ""
if err == nil {
subID = subCookie.Value
}
nc, err := s.GetClient(subID, altID)
if err != nil {
return nil, errors.Wrap(err, "s.GetClient")
}
cookies.SetCookie(w, "ws_sub_id", "/", nc.ID(), 0)
return nc, nil
}
func notifyLoop(ctx context.Context, c *hws.Client, ws *websocket.Conn) error {
notifs, stop := c.Listen()
defer close(stop)
count := 0
for {
select {
case <-ctx.Done():
return nil
case nt, ok := <-notifs:
count++
if !ok {
return nil
}
w, err := ws.Writer(ctx, websocket.MessageText)
if err != nil {
return errors.Wrap(err, "ws.Writer")
}
switch nt.Level {
case hws.LevelShutdown:
err = popup.Toast(nt, count, 30000).Render(ctx, w)
case notify.LevelWarn:
err = popup.Toast(nt, count, 10000).Render(ctx, w)
case notify.LevelError:
// Parse error code and stacktrace from Details field
code, stacktrace := parseErrorDetails(nt.Details)
err = popup.ErrorModalWS(code, stacktrace, nt, count).Render(ctx, w)
case notify.LevelInfo:
err = popup.Toast(nt, count, 6000).Render(ctx, w)
case notify.LevelSuccess:
err = popup.Toast(nt, count, 3000).Render(ctx, w)
default:
err = popup.Toast(nt, count, 6000).Render(ctx, w)
}
if err != nil {
return errors.Wrap(err, "popup.Toast")
}
err = w.Close()
if err != nil {
return errors.Wrap(err, "w.Close")
}
}
}
}

View File

@@ -0,0 +1,41 @@
package handlers
import (
"net/http"
"github.com/pkg/errors"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
testview "git.haelnorr.com/h/oslstats/internal/view/testview"
)
// NotifyTester handles responses to the / path. Also serves a 404 Page for paths that
// don't have explicit handlers
func NotifyTester(s *hws.Server) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks")
if r.Method == "GET" {
renderSafely(testview.NotificationTestPage(), s, r, w)
} else {
_ = r.ParseForm()
// target := r.Form.Get("target")
title := r.Form.Get("title")
level := r.Form.Get("type")
message := r.Form.Get("message")
switch level {
case "success":
notify.Success(s, w, r, title, message, nil)
case "info":
notify.Info(s, w, r, title, message, nil)
case "warn":
notify.Warn(s, w, r, title, message, nil)
case "error":
notify.InternalServiceError(s, w, r, message, testErr)
}
}
},
)
}

View File

@@ -3,7 +3,6 @@ package handlers
import ( import (
"context" "context"
"net/http" "net/http"
"time"
"git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
@@ -13,14 +12,16 @@ import (
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/view/page" "git.haelnorr.com/h/oslstats/internal/throw"
authview "git.haelnorr.com/h/oslstats/internal/view/authview"
) )
func Register( func Register(
server *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *db.DB,
cfg *config.Config, cfg *config.Config,
store *store.Store, store *store.Store,
) http.Handler { ) http.Handler {
@@ -29,27 +30,9 @@ func Register(
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3) attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
if exceeded { if exceeded {
err := errors.Errorf( err := track.Error(attempts)
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
cfg.HWSAuth.SSL,
)
store.ClearRedirectTrack(r, "/register") store.ClearRedirectTrack(r, "/register")
throw.BadRequest(s, w, r, "Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again", err)
throwError(
server,
w,
r,
http.StatusBadRequest,
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
err,
"warn",
)
return return
} }
@@ -65,65 +48,51 @@ func Register(
} }
store.ClearRedirectTrack(r, "/register") store.ClearRedirectTrack(r, "/register")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel() if r.Method == "GET" {
tx, err := conn.BeginTx(ctx, nil) renderSafely(authview.RegisterPage(details.DiscordUser.Username), s, r, w)
if err != nil {
throwInternalServiceError(server, w, r, "Database transaction failed", err)
return return
} }
defer tx.Rollback() username := r.FormValue("username")
method := r.Method unique := false
if method == "GET" { var user *db.User
tx.Commit() if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
page.Register(details.DiscordUser.Username).Render(r.Context(), w) unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
return
}
if method == "POST" {
username := r.FormValue("username")
user, err := registerUser(ctx, tx, username, details)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "Registration failed", err) return false, errors.Wrap(err, "db.IsUsernameUnique")
}
if !unique {
return true, nil
}
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAudit(r, nil))
if err != nil {
return false, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return false, errors.Wrap(err, "db.UpdateDiscordToken")
}
if shouldGrantAdmin(user, cfg.RBAC) {
err := ensureUserHasAdminRole(ctx, tx, user)
if err != nil {
return false, errors.Wrap(err, "ensureUserHasAdminRole")
}
}
return true, nil
}); !ok {
return
}
if !unique {
respond.Conflict(w, errors.New("username is taken"))
} else {
err = auth.Login(w, r, user, true)
if err != nil {
throw.InternalServiceError(s, w, r, "Login failed", err)
return return
} }
tx.Commit() pageFrom := cookies.CheckPageFrom(w, r)
if user == nil { respond.HXRedirect(w, "%s", pageFrom)
w.WriteHeader(http.StatusConflict)
} else {
err = auth.Login(w, r, user, true)
if err != nil {
throwInternalServiceError(server, w, r, "Login failed", err)
return
}
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
}
return
} }
}, },
) )
} }
func registerUser(
ctx context.Context,
tx bun.Tx,
username string,
details *store.RegistrationSession,
) (*db.User, error) {
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "db.IsUsernameUnique")
}
if !unique {
return nil, nil
}
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
if err != nil {
return nil, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
}
return user, nil
}

Some files were not shown because too many files have changed in this diff Show More