diff --git a/.gitignore b/.gitignore index 33692cc..9d37f15 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.db* .logs/ server.log +keys/ bin/ tmp/ static/css/output.css diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..a888b37 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,327 @@ +# AGENTS.md - Developer Guide for oslstats + +This document provides guidelines for AI coding agents and developers working on the oslstats codebase. + +## Project Overview + +**Module**: `git.haelnorr.com/h/oslstats` +**Language**: Go 1.25.5 +**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates +**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries + +## Build, Test, and Development Commands + +### Building +```bash +# Full production build (tailwind → templ → go generate → go build) +make build + +# Build and run +make run + +# Clean build artifacts +make clean +``` + +### Development Mode +```bash +# Watch mode with hot reload (templ, air, tailwindcss in parallel) +make dev + +# Development server runs on: +# - Proxy: http://localhost:3000 (use this) +# - App: http://localhost:3333 (internal) +``` + +### Testing +```bash +# Run all tests +go test ./... + +# Run tests for a specific package +go test ./pkg/oauth + +# Run a single test function +go test ./pkg/oauth -run TestGenerateState_Success + +# Run tests with verbose output +go test -v ./pkg/oauth + +# Run tests with coverage +go test -cover ./... +go test -coverprofile=coverage.out ./... +go tool cover -html=coverage.out +``` + +### Database +```bash +# Run migrations +make migrate +# OR +./bin/oslstats --migrate +``` + +### Configuration Management +```bash +# Generate .env template file +make genenv +# OR with custom output: make genenv OUT=.env.example + +# Show environment variable documentation +make envdoc + +# Show current environment values +make showenv +``` + +## Code Style Guidelines + +### Import Organization +Organize imports in **3 groups** separated by blank lines: + +```go +import ( + // 1. Standard library + "context" + "net/http" + "fmt" + + // 2. External dependencies + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + // 3. Internal packages + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) +``` + +### Naming Conventions + +**Variables**: +- Local: `camelCase` (userAgentKey, httpServer, dbConn) +- Exported: `PascalCase` (Config, User, Token) +- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r` + +**Functions**: +- Exported: `PascalCase` (GetConfig, NewStore, GenerateState) +- Private: `camelCase` (throwError, shouldShowDetails, loadModels) +- HTTP handlers: Return `http.Handler`, use dependency injection pattern +- Database functions: Use `bun.Tx` as parameter for transactions + +**Types**: +- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession) +- Use `-er` suffix for interfaces (implied from usage) + +**Files**: +- Prefer single word: `config.go`, `oauth.go`, `errors.go` +- Use snake_case if needed: `discord_tokens.go`, `state_test.go` +- Test files: `*_test.go` alongside source files + +### Error Handling + +**Always wrap errors** with context using `github.com/pkg/errors`: + +```go +if err != nil { + return errors.Wrap(err, "operation_name") +} +``` + +**Validate inputs at function start**: +```go +func DoSomething(cfg *Config, data string) error { + if cfg == nil { + return errors.New("cfg cannot be nil") + } + if data == "" { + return errors.New("data cannot be empty") + } + // ... rest of function +} +``` + +**HTTP error helpers** (in handlers package): +- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors +- `throwBadRequest(s, w, r, msg, err)` - 400 errors +- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal) +- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level) +- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal) +- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level) +- `throwNotFound(s, w, r, path)` - 404 errors + +### Common Patterns + +**HTTP Handler Pattern**: +```go +func HandlerName(server *hws.Server, deps ...) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Handler logic here + }, + ) +} +``` + +**Database Operation Pattern**: +```go +func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) { + result := new(Result) + err := tx.NewSelect(). + Model(result). + Where("id = ?", id). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil // Return nil, nil for not found + } + return nil, errors.Wrap(err, "tx.Select") + } + return result, nil +} +``` + +**Setup Function Pattern** (returns instance, cleanup func, error): +```go +func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) { + instance := newInstance() + + err := configure(instance) + if err != nil { + return nil, nil, errors.Wrap(err, "configure") + } + + return instance, instance.Close, nil +} +``` + +**Configuration Pattern** (using ezconf): +```go +type Config struct { + Field string // ENV FIELD_NAME: Description (required/default: value) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + Field: env.String("FIELD_NAME", "default"), + } + // Validation here + return cfg, nil +} +``` + +### Formatting & Types + +**Formatting**: +- Use `gofmt` (standard Go formatting) +- No tabs vs spaces debate - Go uses tabs + +**Types**: +- Prefer explicit types over inference when it improves clarity +- Use struct tags for ORM and JSON marshaling: + ```go + type User struct { + bun.BaseModel `bun:"table:users,alias:u"` + ID int `bun:"id,pk,autoincrement"` + Username string `bun:"username,unique"` + AccessToken string `json:"access_token"` + } + ``` + +**Comments**: +- Document exported functions and types +- Use inline comments for ENV var documentation in Config structs +- Explain security-critical code flows + +### Testing + +**Test File Location**: Place `*_test.go` files alongside source files + +**Test Naming**: +```go +func TestFunctionName_Scenario(t *testing.T) +func TestGenerateState_Success(t *testing.T) +func TestVerifyState_WrongUserAgentKey(t *testing.T) +``` + +**Test Structure**: +- Use subtests with `t.Run()` for related scenarios +- Use table-driven tests for multiple similar cases +- Create helper functions for common setup (e.g., `testConfig()`) +- Test happy paths, error cases, edge cases, and security properties + +**Test Categories** (from pkg/oauth/state_test.go example): +1. Happy path tests +2. Error handling (nil params, empty fields, malformed input) +3. Security tests (MITM, CSRF, replay attacks, tampering) +4. Edge cases (concurrency, constant-time comparison) +5. Integration tests (round-trip verification) + +### Security + +**Critical Practices**: +- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons +- Implement CSRF protection via state tokens +- Store sensitive cookies as HttpOnly +- Use separate logging levels for security violations (WARN) +- Validate all inputs at function boundaries +- Use parameterized queries (Bun ORM handles this) +- Never commit secrets (.env, keys/ are gitignored) + +## Project Structure + +``` +oslstats/ +├── cmd/oslstats/ # Application entry point +│ ├── main.go # Entry point with flag parsing +│ ├── run.go # Server initialization & graceful shutdown +│ ├── httpserver.go # HTTP server setup +│ ├── routes.go # Route registration +│ ├── middleware.go # Middleware registration +│ ├── auth.go # Authentication setup +│ └── db.go # Database connection & migrations +├── internal/ # Private application code +│ ├── config/ # Configuration aggregation +│ ├── db/ # Database models & queries (Bun ORM) +│ ├── discord/ # Discord OAuth integration +│ ├── handlers/ # HTTP request handlers +│ ├── session/ # Session store (in-memory) +│ └── view/ # Templ templates +│ ├── component/ # Reusable UI components +│ ├── layout/ # Page layouts +│ └── page/ # Full pages +├── pkg/ # Reusable packages +│ ├── contexts/ # Context key definitions +│ ├── embedfs/ # Embedded static files +│ └── oauth/ # OAuth state management +├── bin/ # Compiled binaries (gitignored) +├── keys/ # Private keys (gitignored) +├── tmp/ # Air hot reload temp files (gitignored) +├── Makefile # Build automation +├── .air.toml # Hot reload configuration +└── go.mod # Go module definition +``` + +## Key Dependencies + +- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt) +- **github.com/a-h/templ** - Type-safe HTML templating +- **github.com/uptrace/bun** - PostgreSQL ORM +- **github.com/bwmarrin/discordgo** - Discord API client +- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf) +- **github.com/joho/godotenv** - .env file loading + +## Notes for AI Agents + +1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css) +2. **Database operations** should use `bun.Tx` for transaction safety +3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes +4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/ +5. **Error messages** should be descriptive and use errors.Wrap for context +6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples) +7. **Air proxy** runs on port 3000 during development; app runs on 3333 +8. **Test coverage** is currently limited - prioritize testing security-critical code +9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples +10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern +11. When in plan mode, always use the interactive question tool if available diff --git a/Makefile b/Makefile index f3f2a8b..8f59266 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,6 @@ BINARY_NAME=oslstats build: - ./scripts/generate-css-sources.sh && \ tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \ go mod tidy && \ templ generate && \ @@ -16,7 +15,6 @@ run: ./bin/${BINARY_NAME}${SUFFIX} dev: - ./scripts/generate-css-sources.sh && \ templ generate --watch &\ air &\ tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch @@ -36,3 +34,6 @@ showenv: make build ./bin/${BINARY_NAME} --showenv +migrate: + make build + ./bin/${BINARY_NAME}${SUFFIX} --migrate diff --git a/cmd/oslstats/auth.go b/cmd/oslstats/auth.go index 4f57359..d2b6ff5 100644 --- a/cmd/oslstats/auth.go +++ b/cmd/oslstats/auth.go @@ -8,7 +8,6 @@ import ( "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/oslstats/pkg/contexts" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -31,6 +30,7 @@ func setupAuth( beginTx, logger, handlers.ErrorPage, + conn.DB, ) if err != nil { return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") @@ -38,7 +38,9 @@ func setupAuth( auth.IgnorePaths(ignoredPaths...) - contexts.CurrentUser = auth.CurrentModel + db.CurrentUser = auth.CurrentModel return auth, nil } + +// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index 1df8f57..922b928 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -20,7 +20,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func conn = bun.NewDB(sqldb, pgdialect.New()) close = sqldb.Close - err = loadModels(ctx, conn, cfg.Flags.ResetDB) + err = loadModels(ctx, conn, cfg.Flags.MigrateDB) if err != nil { return nil, nil, errors.Wrap(err, "loadModels") } @@ -31,6 +31,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error { models := []any{ (*db.User)(nil), + (*db.DiscordToken)(nil), } for _, model := range models { diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index fa0461d..bd2763f 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -4,13 +4,15 @@ import ( "io/fs" "net/http" - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/golib/hlog" + "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func setupHttpServer( @@ -18,6 +20,8 @@ func setupHttpServer( config *config.Config, logger *hlog.Logger, bun *bun.DB, + store *store.Store, + discordAPI *discord.APIClient, ) (server *hws.Server, err error) { if staticFS == nil { return nil, errors.New("No filesystem provided") @@ -53,7 +57,7 @@ func setupHttpServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, config, logger, bun, auth) + err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI) if err != nil { return nil, errors.Wrap(err, "addRoutes") } diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index e92d126..38e0eea 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -29,6 +29,16 @@ func main() { return } + if flags.MigrateDB { + _, closedb, err := setupBun(ctx, cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + } + closedb() + return + } + if err := run(ctx, os.Stdout, cfg); err != nil { fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index efea343..c19f6d0 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -5,22 +5,24 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/handlers" - - "git.haelnorr.com/h/golib/hlog" "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func addRoutes( server *hws.Server, staticFS *http.FileSystem, - config *config.Config, - logger *hlog.Logger, + cfg *config.Config, conn *bun.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], + store *store.Store, + discordAPI *discord.APIClient, ) error { // Create the routes routes := []hws.Route{ @@ -34,10 +36,38 @@ func addRoutes( Method: hws.MethodGET, Handler: handlers.Index(server), }, + { + Path: "/login", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)), + }, + { + Path: "/auth/callback", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)), + }, + { + Path: "/register", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), + }, + { + Path: "/logout", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)), + }, + } + + htmxRoutes := []hws.Route{ + { + Path: "/htmx/isusernameunique", + Method: hws.MethodPOST, + Handler: handlers.IsUsernameUnique(server, conn, cfg, store), + }, } // Register the routes with the server - err := server.AddRoutes(routes...) + err := server.AddRoutes(append(routes, htmxRoutes...)...) if err != nil { return errors.Wrap(err, "server.AddRoutes") } diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 7906c26..2d6f34f 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -9,18 +9,21 @@ import ( "time" "git.haelnorr.com/h/golib/hlog" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/pkg/embedfs" "github.com/pkg/errors" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/embedfs" ) // Initializes and runs the server -func run(ctx context.Context, w io.Writer, config *config.Config) error { +func run(ctx context.Context, w io.Writer, cfg *config.Config) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() // Setup the logger - logger, err := hlog.NewLogger(config.HLOG, w) + logger, err := hlog.NewLogger(cfg.HLOG, w) if err != nil { return errors.Wrap(err, "hlog.NewLogger") } @@ -28,7 +31,7 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - bun, closedb, err := setupBun(ctx, config) + bun, closedb, err := setupBun(ctx, cfg) if err != nil { return errors.Wrap(err, "setupDBConn") } @@ -41,8 +44,19 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { return errors.Wrap(err, "getStaticFiles") } + // Setup session store + logger.Debug().Msg("Setting up session store") + store := store.NewStore() + + // Setup Discord API client + logger.Debug().Msg("Setting up Discord API client") + discordAPI, err := discord.NewAPIClient(cfg.Discord, logger, cfg.HWSAuth.TrustedHost) + if err != nil { + return errors.Wrap(err, "discord.NewAPIClient") + } + logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHttpServer(&staticFS, config, logger, bun) + httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI) if err != nil { return errors.Wrap(err, "setupHttpServer") } diff --git a/go.mod b/go.mod index 26e14e9..e909fb8 100644 --- a/go.mod +++ b/go.mod @@ -6,20 +6,25 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.2.3 - git.haelnorr.com/h/golib/hwsauth v0.3.4 + git.haelnorr.com/h/golib/hws v0.3.1 + git.haelnorr.com/h/golib/hwsauth v0.5.2 github.com/a-h/templ v0.3.977 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/uptrace/bun v1.2.16 github.com/uptrace/bun/dialect/pgdialect v1.2.16 github.com/uptrace/bun/driver/pgdriver v1.2.16 - golang.org/x/crypto v0.45.0 ) require ( - git.haelnorr.com/h/golib/cookies v0.9.0 // indirect - git.haelnorr.com/h/golib/jwt v0.10.0 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + golang.org/x/crypto v0.45.0 // indirect +) + +require ( + git.haelnorr.com/h/golib/cookies v0.9.0 + git.haelnorr.com/h/golib/jwt v0.10.1 // indirect + github.com/bwmarrin/discordgo v0.29.0 github.com/go-logr/logr v1.4.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 78c8ff7..afdf0fa 100644 --- a/go.sum +++ b/go.sum @@ -6,16 +6,18 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.2.3 h1:gZQkBciXKh3jYw05vZncSR2lvIqi0H2MVfIWySySsmw= -git.haelnorr.com/h/golib/hws v0.2.3/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/hwsauth v0.3.4 h1:wwYBb6cQQ+x9hxmYuZBF4mVmCv/n4PjJV//e1+SgPOo= -git.haelnorr.com/h/golib/hwsauth v0.3.4/go.mod h1:LI7Qz68GPNIW8732Zwptb//ybjiFJOoXf4tgUuUEqHI= -git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E= -git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= +git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI= +git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag= +git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= +git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= +git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg= github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo= +github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= +github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -28,6 +30,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -66,13 +70,19 @@ go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= diff --git a/internal/config/config.go b/internal/config/config.go index 87b9f15..112752a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,8 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/pkg/oauth" "github.com/joho/godotenv" "github.com/pkg/errors" ) @@ -15,6 +17,8 @@ type Config struct { HWS *hws.Config HWSAuth *hwsauth.Config HLOG *hlog.Config + Discord *discord.Config + OAuth *oauth.Config Flags *Flags } @@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { hws.NewEZConfIntegration(), hwsauth.NewEZConfIntegration(), db.NewEZConfIntegration(), + discord.NewEZConfIntegration(), + oauth.NewEZConfIntegration(), ) if err := loader.ParseEnvVars(); err != nil { return nil, nil, errors.Wrap(err, "loader.ParseEnvVars") @@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { return nil, nil, errors.New("DB Config not loaded") } + discordcfg, ok := loader.GetConfig("discord") + if !ok { + return nil, nil, errors.New("Dicord Config not loaded") + } + + oauthcfg, ok := loader.GetConfig("oauth") + if !ok { + return nil, nil, errors.New("OAuth Config not loaded") + } + config := &Config{ DB: dbcfg.(*db.Config), HWS: hwscfg.(*hws.Config), HWSAuth: hwsauthcfg.(*hwsauth.Config), HLOG: hlogcfg.(*hlog.Config), + Discord: discordcfg.(*discord.Config), + OAuth: oauthcfg.(*oauth.Config), Flags: flags, } diff --git a/internal/config/flags.go b/internal/config/flags.go index 7d89108..ba87131 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -5,16 +5,16 @@ import ( ) type Flags struct { - ResetDB bool - EnvDoc bool - ShowEnv bool - GenEnv string - EnvFile string + MigrateDB bool + EnvDoc bool + ShowEnv bool + GenEnv string + EnvFile string } func SetupFlags() *Flags { // Parse commandline args - resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models") + migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models") envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation") showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation") genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)") @@ -22,11 +22,11 @@ func SetupFlags() *Flags { flag.Parse() flags := &Flags{ - ResetDB: *resetDB, - EnvDoc: *envDoc, - ShowEnv: *showEnv, - GenEnv: *genEnv, - EnvFile: *envfile, + MigrateDB: *migrateDB, + EnvDoc: *envDoc, + ShowEnv: *showEnv, + GenEnv: *genEnv, + EnvFile: *envfile, } return flags } diff --git a/internal/db/discord_tokens.go b/internal/db/discord_tokens.go new file mode 100644 index 0000000..7cc09d9 --- /dev/null +++ b/internal/db/discord_tokens.go @@ -0,0 +1,95 @@ +package db + +import ( + "context" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type DiscordToken struct { + bun.BaseModel `bun:"table:discord_tokens,alias:dt"` + + DiscordID string `bun:"discord_id,pk,notnull"` + AccessToken string `bun:"access_token,notnull"` + RefreshToken string `bun:"refresh_token,notnull"` + ExpiresAt int64 `bun:"expires_at,notnull"` + Scope string `bun:"scope,notnull"` + TokenType string `bun:"token_type,notnull"` +} + +// UpdateDiscordToken adds the provided discord token to the database. +// If the user already has a token stored, it will replace that token instead. +func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { + if token == nil { + return errors.New("token cannot be nil") + } + expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() + + discordToken := &DiscordToken{ + DiscordID: user.DiscordID, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: expiresAt, + Scope: token.Scope, + TokenType: token.TokenType, + } + + _, err := tx.NewInsert(). + Model(discordToken). + On("CONFLICT (discord_id) DO UPDATE"). + Set("access_token = EXCLUDED.access_token"). + Set("refresh_token = EXCLUDED.refresh_token"). + Set("expires_at = EXCLUDED.expires_at"). + Exec(ctx) + + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + return nil +} + +// DeleteDiscordTokens deletes a users discord OAuth tokens from the database. +// It returns the DiscordToken so that it can be revoked via the discord API +func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token, err := user.GetDiscordToken(ctx, tx) + if err != nil { + return nil, errors.Wrap(err, "user.GetDiscordToken") + } + _, err = tx.NewDelete(). + Model((*DiscordToken)(nil)). + Where("discord_id = ?", user.DiscordID). + Exec(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewDelete") + } + return token, nil +} + +// GetDiscordToken retrieves the users discord token from the database +func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token := new(DiscordToken) + err := tx.NewSelect(). + Model(token). + Where("discord_id = ?", user.DiscordID). + Limit(1). + Scan(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return token, nil +} + +// Convert reverts the token back into a *discord.Token +func (t *DiscordToken) Convert() *discord.Token { + token := &discord.Token{ + AccessToken: t.AccessToken, + RefreshToken: t.RefreshToken, + ExpiresIn: int(t.ExpiresAt - time.Now().Unix()), + Scope: t.Scope, + TokenType: t.TokenType, + } + return token +} diff --git a/internal/db/ezconf.go b/internal/db/ezconf.go index db4eac9..1cb08c5 100644 --- a/internal/db/ezconf.go +++ b/internal/db/ezconf.go @@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string { // NewEZConfIntegration creates a new EZConf integration helper func NewEZConfIntegration() EZConfIntegration { - return EZConfIntegration{name: "db", configFunc: ConfigFromEnv} + return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv} } diff --git a/internal/db/user.go b/internal/db/user.go index 01764eb..40c5ca8 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,65 +2,30 @@ package db import ( "context" + "fmt" + "time" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/bwmarrin/discordgo" "github.com/pkg/errors" "github.com/uptrace/bun" - "golang.org/x/crypto/bcrypt" ) +var CurrentUser hwsauth.ContextLoader[*User] + type User struct { bun.BaseModel `bun:"table:users,alias:u"` - ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) - Username string `bun:"username,unique"` // Username (unique) - PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON) - CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database - Bio string `bun:"bio"` // Short byline set by the user + ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) + Username string `bun:"username,unique"` // Username (unique) + CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database + DiscordID string `bun:"discord_id,unique"` } func (user *User) GetID() int { return user.ID } -// Uses bcrypt to set the users password_hash from the given password -func (user *User) SetPassword(ctx context.Context, tx bun.Tx, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - newPassword := string(hashedPassword) - - _, err = tx.NewUpdate(). - Model(user). - Set("password_hash = ?", newPassword). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users password_hash -func (user *User) CheckPassword(ctx context.Context, tx bun.Tx, password string) error { - var hashedPassword string - err := tx.NewSelect(). - Table("users"). - Column("password_hash"). - Where("id = ?", user.ID). - Limit(1). - Scan(ctx, &hashedPassword) - if err != nil { - return errors.Wrap(err, "tx.Select") - } - - err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) - if err != nil { - return errors.Wrap(err, "Username or password incorrect") - } - return nil -} - // Change the user's username func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error { _, err := tx.NewUpdate(). @@ -75,35 +40,18 @@ func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername str return nil } -// Change the user's bio -func (user *User) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error { - _, err := tx.NewUpdate(). - Model(user). - Set("bio = ?", newBio). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - user.Bio = newBio - return nil -} - // CreateUser creates a new user with the given username and password -func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*User, error) { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword") +func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { + if discorduser == nil { + return nil, errors.New("user cannot be nil") } - user := &User{ - Username: username, - PasswordHash: string(hashedPassword), - CreatedAt: 0, // You may want to set this to time.Now().Unix() - Bio: "", + Username: username, + CreatedAt: time.Now().Unix(), + DiscordID: discorduser.ID, } - _, err = tx.NewInsert(). + _, err := tx.NewInsert(). Model(user). Exec(ctx) if err != nil { @@ -116,6 +64,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*Use // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { + fmt.Printf("user id requested: %v", id) user := new(User) err := tx.NewSelect(). Model(user). @@ -149,6 +98,24 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, return user, nil } +// GetUserByDiscordID queries the database for a user matching the given discord id +// Returns nil, nil if no user is found +func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { + user := new(User) + err := tx.NewSelect(). + Model(user). + Where("discord_id = ?", discordID). + Limit(1). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil + } + return nil, errors.Wrap(err, "tx.Select") + } + return user, nil +} + // IsUsernameUnique checks if the given username is unique (not already taken) // Returns true if the username is available, false if it's taken func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) { diff --git a/internal/discord/api.go b/internal/discord/api.go new file mode 100644 index 0000000..aa490b5 --- /dev/null +++ b/internal/discord/api.go @@ -0,0 +1,61 @@ +package discord + +import ( + "net/http" + "sync" + "time" + + "git.haelnorr.com/h/golib/hlog" + "github.com/bwmarrin/discordgo" + "github.com/pkg/errors" +) + +type OAuthSession struct { + *discordgo.Session +} + +func NewOAuthSession(token *Token) (*OAuthSession, error) { + session, err := discordgo.New("Bearer " + token.AccessToken) + if err != nil { + return nil, errors.Wrap(err, "discordgo.New") + } + return &OAuthSession{Session: session}, nil +} + +func (s *OAuthSession) GetUser() (*discordgo.User, error) { + user, err := s.User("@me") + if err != nil { + return nil, errors.Wrap(err, "s.User") + } + return user, nil +} + +// APIClient is an HTTP client wrapper that handles Discord API rate limits +type APIClient struct { + cfg *Config + client *http.Client + logger *hlog.Logger + mu sync.RWMutex + buckets map[string]*RateLimitState + trustedHost string +} + +// NewAPIClient creates a new Discord API client with rate limit handling +func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) { + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + if logger == nil { + return nil, errors.New("logger cannot be nil") + } + if trustedhost == "" { + return nil, errors.New("trustedhost cannot be empty") + } + return &APIClient{ + client: &http.Client{Timeout: 30 * time.Second}, + logger: logger, + buckets: make(map[string]*RateLimitState), + cfg: cfg, + trustedHost: trustedhost, + }, nil +} diff --git a/internal/discord/config.go b/internal/discord/config.go new file mode 100644 index 0000000..cc086ff --- /dev/null +++ b/internal/discord/config.go @@ -0,0 +1,50 @@ +package discord + +import ( + "strings" + + "git.haelnorr.com/h/golib/env" + "github.com/pkg/errors" +) + +type Config struct { + ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required) + ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required) + OAuthScopes string // Authorisation scopes for OAuth + RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + ClientID: env.String("DISCORD_CLIENT_ID", ""), + ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""), + OAuthScopes: getOAuthScopes(), + RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""), + } + + // Check required fields + if cfg.ClientID == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_ID") + } + if cfg.ClientSecret == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET") + } + if cfg.RedirectPath == "" { + return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH") + } + + return cfg, nil +} + +func getOAuthScopes() string { + list := []string{ + "connections", + "email", + "guilds", + "gdm.join", + "guilds.members.read", + "identify", + } + scopes := strings.Join(list, "+") + return scopes +} diff --git a/internal/discord/ezconf.go b/internal/discord/ezconf.go new file mode 100644 index 0000000..8442714 --- /dev/null +++ b/internal/discord/ezconf.go @@ -0,0 +1,41 @@ +package discord + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv} +} diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go new file mode 100644 index 0000000..faff919 --- /dev/null +++ b/internal/discord/oauth.go @@ -0,0 +1,148 @@ +package discord + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/pkg/errors" +) + +// Token represents a response from the Discord OAuth API after a successful authorization request +type Token struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` +} + +const oauthurl string = "https://discord.com/oauth2/authorize" +const apiurl string = "https://discord.com/api/v10" + +// GetOAuthLink generates a new Discord OAuth2 link for user authentication +func (api *APIClient) GetOAuthLink(state string) (string, error) { + if state == "" { + return "", errors.New("state cannot be empty") + } + values := url.Values{} + values.Add("response_type", "code") + values.Add("client_id", api.cfg.ClientID) + values.Add("scope", api.cfg.OAuthScopes) + values.Add("state", state) + values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) + values.Add("prompt", "none") + + return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil +} + +// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for +// making requests to the API on behalf of the user +func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) { + if code == "" { + return nil, errors.New("code cannot be empty") + } + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +// RefreshToken uses the refresh token to generate a new token pair +func (api *APIClient) RefreshToken(token *Token) (*Token, error) { + if token == nil { + return nil, errors.New("token cannot be nil") + } + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", token.RefreshToken) + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +// RevokeToken sends a request to the Discord API to revoke the token pair +func (api *APIClient) RevokeToken(token *Token) error { + if token == nil { + return errors.New("token cannot be nil") + } + data := url.Values{} + data.Set("token", token.AccessToken) + data.Set("token_type_hint", "access_token") + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token/revoke", + strings.NewReader(data.Encode()), + ) + if err != nil { + return errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.Errorf("discord API returned status %d", resp.StatusCode) + } + return nil +} diff --git a/internal/discord/ratelimit.go b/internal/discord/ratelimit.go new file mode 100644 index 0000000..21f3888 --- /dev/null +++ b/internal/discord/ratelimit.go @@ -0,0 +1,216 @@ +package discord + +import ( + "net" + "net/http" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// RateLimitState tracks rate limit information for a specific bucket +type RateLimitState struct { + Remaining int // Requests remaining in current window + Limit int // Total requests allowed in window + Reset time.Time // When the rate limit resets + Bucket string // Discord's bucket identifier +} + +// Do executes an HTTP request with automatic rate limit handling +// It will wait if rate limits are about to be exceeded and retry once if a 429 is received +func (c *APIClient) Do(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + + // Step 1: Check if we need to wait before making request + bucket := c.getBucketFromRequest(req) + if err := c.waitIfNeeded(bucket); err != nil { + return nil, err + } + + // Step 2: Execute request + resp, err := c.client.Do(req) + if err != nil { + // Check if it's a network timeout + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "request timed out") + } + return nil, errors.Wrap(err, "http request failed") + } + + // Step 3: Update rate limit state from response headers + c.updateRateLimit(resp.Header) + + // Step 4: Handle 429 (rate limited) + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() // Close original response + + retryAfter := c.parseRetryAfter(resp.Header) + + // No Retry-After header, can't retry safely + if retryAfter == 0 { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Rate limited but no Retry-After header provided") + return nil, errors.New("discord API rate limited but no Retry-After header provided") + } + + // Retry-After exceeds 30 second cap + if retryAfter > 30*time.Second { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited with Retry-After exceeding 30s cap, not retrying") + return nil, errors.Errorf( + "discord API rate limited (retry after %s exceeds 30s cap)", + retryAfter, + ) + } + + // Wait and retry + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited, waiting before retry") + + time.Sleep(retryAfter) + + // Retry the request + resp, err = c.client.Do(req) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "retry request timed out") + } + return nil, errors.Wrap(err, "retry request failed") + } + + // Update rate limit again after retry + c.updateRateLimit(resp.Header) + + // If STILL rate limited after retry, return error + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() + c.logger.Error(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Still rate limited after retry, Discord may be experiencing issues") + return nil, errors.Errorf( + "discord API still rate limited after retry (waited %s), Discord may be experiencing issues", + retryAfter, + ) + } + } + + return resp, nil +} + +// getBucketFromRequest extracts or generates bucket ID from request +// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers +func (c *APIClient) getBucketFromRequest(req *http.Request) string { + return req.Method + ":" + req.URL.Path +} + +// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits +func (c *APIClient) waitIfNeeded(bucket string) error { + c.mu.RLock() + state, exists := c.buckets[bucket] + c.mu.RUnlock() + + if !exists { + return nil // No state yet, proceed + } + + now := time.Now() + + // If we have no remaining requests and reset hasn't occurred, wait + if state.Remaining == 0 && now.Before(state.Reset) { + waitDuration := time.Until(state.Reset) + // Add small buffer (100ms) to ensure reset has occurred + waitDuration += 100 * time.Millisecond + + if waitDuration > 0 { + c.logger.Debug(). + Str("bucket", bucket). + Dur("wait_duration", waitDuration). + Msg("Proactively waiting for rate limit reset") + time.Sleep(waitDuration) + } + } + + return nil +} + +// updateRateLimit parses response headers and updates bucket state +func (c *APIClient) updateRateLimit(headers http.Header) { + bucket := headers.Get("X-RateLimit-Bucket") + if bucket == "" { + return // No bucket info, can't track + } + + // Parse headers + limit := c.parseInt(headers.Get("X-RateLimit-Limit")) + remaining := c.parseInt(headers.Get("X-RateLimit-Remaining")) + resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After")) + + state := &RateLimitState{ + Bucket: bucket, + Limit: limit, + Remaining: remaining, + Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))), + } + + c.mu.Lock() + c.buckets[bucket] = state + c.mu.Unlock() + + // Log rate limit state for debugging + c.logger.Debug(). + Str("bucket", bucket). + Int("remaining", remaining). + Int("limit", limit). + Dur("reset_in", time.Until(state.Reset)). + Msg("Rate limit state updated") +} + +// parseRetryAfter extracts retry delay from Retry-After header +func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration { + retryAfter := headers.Get("Retry-After") + if retryAfter == "" { + return 0 + } + + // Discord returns seconds as float + seconds := c.parseFloat(retryAfter) + if seconds <= 0 { + return 0 + } + + return time.Duration(seconds * float64(time.Second)) +} + +// parseInt parses an integer from a header value, returns 0 on error +func (c *APIClient) parseInt(s string) int { + if s == "" { + return 0 + } + i, _ := strconv.Atoi(s) + return i +} + +// parseFloat parses a float from a header value, returns 0 on error +func (c *APIClient) parseFloat(s string) float64 { + if s == "" { + return 0 + } + f, _ := strconv.ParseFloat(s, 64) + return f +} diff --git a/internal/discord/ratelimit_test.go b/internal/discord/ratelimit_test.go new file mode 100644 index 0000000..38dfb04 --- /dev/null +++ b/internal/discord/ratelimit_test.go @@ -0,0 +1,517 @@ +package discord + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "git.haelnorr.com/h/golib/hlog" +) + +// testLogger creates a test logger for testing +func testLogger(t *testing.T) *hlog.Logger { + level, _ := hlog.LogLevel("debug") + cfg := &hlog.Config{ + LogLevel: level, + LogOutput: "console", + } + logger, err := hlog.NewLogger(cfg, io.Discard) + if err != nil { + t.Fatalf("failed to create test logger: %v", err) + } + return logger +} + +// testConfig creates a test config for testing +func testConfig() *Config { + return &Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + OAuthScopes: "identify+email", + RedirectPath: "/oauth/callback", + } +} + +func TestNewRateLimitedClient(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + if client == nil { + t.Fatal("NewAPIClient returned nil") + } + if client.client == nil { + t.Error("client.client is nil") + } + if client.logger == nil { + t.Error("client.logger is nil") + } + if client.buckets == nil { + t.Error("client.buckets map is nil") + } + if client.cfg == nil { + t.Error("client.cfg is nil") + } + if client.trustedHost != "trusted-host.example.com" { + t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost) + } +} + +func TestAPIClient_Do_Success(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Mock server that returns success with rate limit headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "3") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + + // Check that rate limit state was updated + client.mu.RLock() + state, exists := client.buckets["test-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("rate limit state not stored") + } + if state.Remaining != 3 { + t.Errorf("expected remaining=3, got %d", state.Remaining) + } + if state.Limit != 5 { + t.Errorf("expected limit=5, got %d", state.Limit) + } +} + +func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + if attemptCount == 1 { + // First request: return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.1") // 100ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + "error_description": "You are being rate limited", + }) + return + } + // Second request: success + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "4") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200 after retry, got %d", resp.StatusCode) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount) + } + + // Should have waited approximately 100ms + if elapsed < 100*time.Millisecond { + t.Errorf("expected delay of ~100ms, but took %v", elapsed) + } +} + +func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.05") // 50ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error after failed retry") + } + + if !strings.Contains(err.Error(), "still rate limited after retry") { + t.Errorf("expected 'still rate limited after retry' error, got: %v", err) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount) + } +} + +func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error for Retry-After > 30s") + } + + if !strings.Contains(err.Error(), "exceeds 30s cap") { + t.Errorf("expected 'exceeds 30s cap' error, got: %v", err) + } + + // Should NOT have waited (immediate error) + if elapsed > 1*time.Second { + t.Errorf("should return immediately, but took %v", elapsed) + } +} + +func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 429 but NO Retry-After header + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error when no Retry-After header") + } + + if !strings.Contains(err.Error(), "no Retry-After header") { + t.Errorf("expected 'no Retry-After header' error, got: %v", err) + } +} + +func TestAPIClient_UpdateRateLimit(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + headers := http.Header{} + headers.Set("X-RateLimit-Bucket", "global") + headers.Set("X-RateLimit-Limit", "10") + headers.Set("X-RateLimit-Remaining", "7") + headers.Set("X-RateLimit-Reset-After", "5.5") + + client.updateRateLimit(headers) + + client.mu.RLock() + state, exists := client.buckets["global"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not created") + } + + if state.Bucket != "global" { + t.Errorf("expected bucket='global', got '%s'", state.Bucket) + } + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d", state.Limit) + } + if state.Remaining != 7 { + t.Errorf("expected remaining=7, got %d", state.Remaining) + } + + // Check reset time is approximately 5.5 seconds from now + resetIn := time.Until(state.Reset) + if resetIn < 5*time.Second || resetIn > 6*time.Second { + t.Errorf("expected reset in ~5.5s, got %v", resetIn) + } +} + +func TestAPIClient_WaitIfNeeded(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Set up a bucket with 0 remaining and reset in future + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 0, + Reset: time.Now().Add(200 * time.Millisecond), + } + client.mu.Unlock() + + start := time.Now() + err = client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should have waited ~200ms + 100ms buffer + if elapsed < 200*time.Millisecond { + t.Errorf("expected wait of ~300ms, but took %v", elapsed) + } + if elapsed > 500*time.Millisecond { + t.Errorf("waited too long: %v", elapsed) + } +} + +func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Set up a bucket with remaining requests + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 3, + Reset: time.Now().Add(5 * time.Second), + } + client.mu.Unlock() + + start := time.Now() + err = client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should NOT wait (has remaining requests) + if elapsed > 10*time.Millisecond { + t.Errorf("should not wait when remaining > 0, but took %v", elapsed) + } +} + +func TestAPIClient_Do_Concurrent(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + requestCount := 0 + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestCount++ + count := requestCount + mu.Unlock() + + w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket") + w.Header().Set("X-RateLimit-Limit", "10") + w.Header().Set("X-RateLimit-Remaining", "5") + w.Header().Set("X-RateLimit-Reset-After", "1.0") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))}) + })) + defer server.Close() + + // Launch 10 concurrent requests + var wg sync.WaitGroup + errors := make(chan error, 10) + + for range 10 { + wg.Go( + func() { + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + errors <- err + return + } + + resp, err := client.Do(req) + if err != nil { + errors <- err + return + } + resp.Body.Close() + }) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + t.Errorf("concurrent request failed: %v", err) + } + + // All requests should have completed + mu.Lock() + finalCount := requestCount + mu.Unlock() + + if finalCount != 10 { + t.Errorf("expected 10 requests, got %d", finalCount) + } + + // Check rate limit state is consistent (no data races) + client.mu.RLock() + state, exists := client.buckets["concurrent-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not found after concurrent requests") + } + + // State should exist and be valid + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit) + } +} + +func TestAPIClient_ParseRetryAfter(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + tests := []struct { + name string + header string + expected time.Duration + }{ + {"integer seconds", "2", 2 * time.Second}, + {"float seconds", "2.5", 2500 * time.Millisecond}, + {"zero", "0", 0}, + {"empty", "", 0}, + {"invalid", "abc", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := http.Header{} + headers.Set("Retry-After", tt.header) + + result := client.parseRetryAfter(headers) + + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go new file mode 100644 index 0000000..12082a6 --- /dev/null +++ b/internal/handlers/callback.go @@ -0,0 +1,205 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) + +func Callback( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + cfg *config.Config, + store *store.Store, + discordAPI *discord.APIClient, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) + + if exceeded { + err := errors.Errorf( + "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + store.ClearRedirectTrack(r, "/callback") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "OAuth callback failed: Too many redirect attempts. Please try logging in again.", + err, + "warn", + ) + return + } + + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + if state == "" && code == "" { + http.Redirect(w, r, "/", http.StatusBadRequest) + return + } + data, err := verifyState(cfg.OAuth, w, r, state) + if err != nil { + if vsErr, ok := err.(*verifyStateError); ok { + if vsErr.IsCookieError() { + throwUnauthorized(server, w, r, "OAuth session not found or expired", err) + } else { + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + } + } else { + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + } + return + } + store.ClearRedirectTrack(r, "/callback") + + switch data { + case "login": + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "DB Transaction failed to start", err) + return + } + defer tx.Rollback() + redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) + if err != nil { + throwInternalServiceError(server, w, r, "OAuth login failed", err) + return + } + tx.Commit() + redirect() + return + } + }, + ) +} + +type verifyStateError struct { + err error + cookieError bool +} + +func (e *verifyStateError) Error() string { + return e.err.Error() +} + +func (e *verifyStateError) IsCookieError() bool { + return e.cookieError +} + +func verifyState( + cfg *oauth.Config, + w http.ResponseWriter, + r *http.Request, + state string, +) (string, error) { + if r == nil { + return "", errors.New("request cannot be nil") + } + if state == "" { + return "", errors.New("state param field is empty") + } + + uak, err := oauth.GetStateCookie(r) + if err != nil { + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.GetStateCookie"), + cookieError: true, + } + } + + data, err := oauth.VerifyState(cfg, state, uak) + if err != nil { + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.VerifyState"), + cookieError: false, + } + } + + oauth.DeleteStateCookie(w) + return data, nil +} + +func login( + ctx context.Context, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + tx bun.Tx, + cfg *config.Config, + w http.ResponseWriter, + r *http.Request, + code string, + store *store.Store, + discordAPI *discord.APIClient, +) (func(), error) { + token, err := discordAPI.AuthorizeWithCode(code) + if err != nil { + return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode") + } + session, err := discord.NewOAuthSession(token) + if err != nil { + return nil, errors.Wrap(err, "discord.NewOAuthSession") + } + discorduser, err := session.GetUser() + if err != nil { + return nil, errors.Wrap(err, "session.GetUser") + } + + user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID) + if err != nil { + return nil, errors.Wrap(err, "db.GetUserByDiscordID") + } + var redirect string + if user == nil { + sessionID, err := store.CreateRegistrationSession(discorduser, token) + if err != nil { + return nil, errors.Wrap(err, "store.CreateRegistrationSession") + } + http.SetCookie(w, &http.Cookie{ + Name: "registration_session", + Path: "/", + Value: sessionID, + MaxAge: 300, // 5 minutes + HttpOnly: true, + Secure: cfg.HWSAuth.SSL, + SameSite: http.SameSiteLaxMode, + }) + redirect = "/register" + } else { + err = user.UpdateDiscordToken(ctx, tx, token) + if err != nil { + return nil, errors.Wrap(err, "user.UpdateDiscordToken") + } + err := auth.Login(w, r, user, true) + if err != nil { + return nil, errors.Wrap(err, "auth.Login") + } + redirect = cookies.CheckPageFrom(w, r) + } + return func() { + http.Redirect(w, r, redirect, http.StatusSeeOther) + }, nil +} diff --git a/internal/handlers/errorpage.go b/internal/handlers/errorpage.go index 2fa7699..d6e4a25 100644 --- a/internal/handlers/errorpage.go +++ b/internal/handlers/errorpage.go @@ -5,24 +5,93 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/view/page" - "github.com/pkg/errors" ) -func ErrorPage( - errorCode int, -) (hws.ErrorPage, error) { +// func ErrorPage( +// error hws.HWSError, +// ) (hws.ErrorPage, error) { +// messages := map[int]string{ +// 400: "The request you made was malformed or unexpected.", +// 401: "You need to login to view this page.", +// 403: "You do not have permission to view this page.", +// 404: "The page or resource you have requested does not exist.", +// 500: `An error occured on the server. Please try again, and if this +// continues to happen contact an administrator.`, +// 503: "The server is currently down for maintenance and should be back soon. =)", +// } +// msg, exists := messages[error.StatusCode] +// if !exists { +// return nil, errors.New("No valid message for the given code") +// } +// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil +// } + +func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) { + // Determine if this status code should show technical details + showDetails := shouldShowDetails(hwsError.StatusCode) + + // Get the user-friendly message + message := hwsError.Message + if message == "" { + // Fallback to default messages if no custom message provided + message = getDefaultMessage(hwsError.StatusCode) + } + + // Get technical details if applicable + var details string + if showDetails && hwsError.Error != nil { + details = hwsError.Error.Error() + } + + // Render appropriate template + if details != "" { + return page.ErrorWithDetails( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + details, + ), nil + } + + return page.Error( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + ), nil +} + +// shouldShowDetails determines if a status code should display technical details +func shouldShowDetails(statusCode int) bool { + switch statusCode { + case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable + return true + case 401, 403, 404: // Unauthorized, Forbidden, Not Found + return false + default: + // For unknown codes, show details for 5xx errors + return statusCode >= 500 + } +} + +// getDefaultMessage provides fallback messages for status codes +func getDefaultMessage(statusCode int) string { messages := map[int]string{ 400: "The request you made was malformed or unexpected.", 401: "You need to login to view this page.", 403: "You do not have permission to view this page.", 404: "The page or resource you have requested does not exist.", - 500: `An error occured on the server. Please try again, and if this - continues to happen contact an administrator.`, + 500: `An error occurred on the server. Please try again, and if this + continues to happen contact an administrator.`, 503: "The server is currently down for maintenance and should be back soon. =)", } - msg, exists := messages[errorCode] + + msg, exists := messages[statusCode] if !exists { - return nil, errors.New("No valid message for the given code") + if statusCode >= 500 { + return "A server error occurred. Please try again later." + } + return "An error occurred while processing your request." } - return page.Error(errorCode, http.StatusText(errorCode), msg), nil + + return msg } diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go new file mode 100644 index 0000000..4e26611 --- /dev/null +++ b/internal/handlers/errors.go @@ -0,0 +1,109 @@ +package handlers + +import ( + "fmt" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" +) + +// throwError is a generic helper that all throw* functions use internally +func throwError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + statusCode int, + msg string, + err error, + level string, +) { + err = s.ThrowError(w, r, hws.HWSError{ + StatusCode: statusCode, + Message: msg, + Error: err, + Level: hws.ErrorLevel(level), + RenderErrorPage: true, // throw* family always renders error pages + }) + if err != nil { + s.ThrowFatal(w, err) + } +} + +// throwInternalServiceError handles 500 errors (server failures) +func throwInternalServiceError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusInternalServerError, msg, err, "error") +} + +// throwBadRequest handles 400 errors (malformed requests) +func throwBadRequest( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusBadRequest, msg, err, "debug") +} + +// throwForbidden handles 403 errors (normal permission denials) +func throwForbidden( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "debug") +} + +// throwForbiddenSecurity handles 403 errors for security events (uses WARN level) +func throwForbiddenSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "warn") +} + +// throwUnauthorized handles 401 errors (not authenticated) +func throwUnauthorized( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug") +} + +// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level) +func throwUnauthorizedSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn") +} + +// throwNotFound handles 404 errors +func throwNotFound( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + path string, +) { + msg := fmt.Sprintf("The requested resource was not found: %s", path) + err := errors.New("Resource not found") + throwError(s, w, r, http.StatusNotFound, msg, err, "debug") +} diff --git a/internal/handlers/index.go b/internal/handlers/index.go index 3274302..25e7ac5 100644 --- a/internal/handlers/index.go +++ b/internal/handlers/index.go @@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - page, err := ErrorPage(http.StatusNotFound) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to generate the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } - err = page.Render(r.Context(), w) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to render the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } + throwNotFound(server, w, r, r.URL.Path) } page.Index().Render(r.Context(), w) }, diff --git a/internal/handlers/isusernameunique.go b/internal/handlers/isusernameunique.go new file mode 100644 index 0000000..f441784 --- /dev/null +++ b/internal/handlers/isusernameunique.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/store" + "github.com/uptrace/bun" +) + +func IsUsernameUnique( + server *hws.Server, + conn *bun.DB, + cfg *config.Config, + store *store.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + username := r.FormValue("username") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + throwInternalServiceError(server, w, r, "Database query failed", err) + return + } + tx.Commit() + if !unique { + w.WriteHeader(http.StatusConflict) + } else { + w.WriteHeader(http.StatusOK) + } + }, + ) +} diff --git a/internal/handlers/login.go b/internal/handlers/login.go new file mode 100644 index 0000000..93bd6c8 --- /dev/null +++ b/internal/handlers/login.go @@ -0,0 +1,63 @@ +package handlers + +import ( + "net/http" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) + +func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) + attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) + + if exceeded { + err := errors.Errorf( + "login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + st.ClearRedirectTrack(r, "/login") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "Login failed: Too many redirect attempts. Please clear your browser cookies and try again.", + err, + "warn", + ) + return + } + + state, uak, err := oauth.GenerateState(cfg.OAuth, "login") + if err != nil { + throwInternalServiceError(server, w, r, "Failed to generate state token", err) + return + } + oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) + + link, err := discordAPI.GetOAuthLink(state) + if err != nil { + throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) + return + } + st.ClearRedirectTrack(r, "/login") + + http.Redirect(w, r, link, http.StatusSeeOther) + }, + ) +} diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go new file mode 100644 index 0000000..02b9a9d --- /dev/null +++ b/internal/handlers/logout.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func Logout( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + discordAPI *discord.APIClient, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + return + } + defer tx.Rollback() + + user := db.CurrentUser(r.Context()) + if user == nil { + // JIC - should be impossible to get here if route is protected by LoginReq + w.Header().Set("HX-Redirect", "/") + return + } + token, err := user.DeleteDiscordTokens(ctx, tx) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) + return + } + err = discordAPI.RevokeToken(token.Convert()) + if err != nil { + throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) + return + } + err = auth.Logout(tx, w, r) + if err != nil { + throwInternalServiceError(server, w, r, "Logout failed", err) + return + } + tx.Commit() + w.Header().Set("HX-Redirect", "/") + }, + ) +} diff --git a/internal/handlers/register.go b/internal/handlers/register.go new file mode 100644 index 0000000..25ded6d --- /dev/null +++ b/internal/handlers/register.go @@ -0,0 +1,129 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/internal/view/page" +) + +func Register( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + cfg *config.Config, + store *store.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + attempts, exceeded, track := store.TrackRedirect(r, "/register", 3) + + if exceeded { + err := errors.Errorf( + "registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + cfg.HWSAuth.SSL, + ) + + store.ClearRedirectTrack(r, "/register") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.", + err, + "warn", + ) + return + } + + sessionCookie, err := r.Cookie("registration_session") + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + details, ok := store.GetRegistrationSession(sessionCookie.Value) + if !ok { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + store.ClearRedirectTrack(r, "/register") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + method := r.Method + if method == "GET" { + tx.Commit() + page.Register(details.DiscordUser.Username).Render(r.Context(), w) + return + } + if method == "POST" { + username := r.FormValue("username") + user, err := registerUser(ctx, tx, username, details) + if err != nil { + throwInternalServiceError(server, w, r, "Registration failed", err) + return + } + tx.Commit() + if user == nil { + w.WriteHeader(http.StatusConflict) + } else { + err = auth.Login(w, r, user, true) + if err != nil { + throwInternalServiceError(server, w, r, "Login failed", err) + return + } + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + } + return + } + }, + ) +} + +func registerUser( + ctx context.Context, + tx bun.Tx, + username string, + details *store.RegistrationSession, +) (*db.User, error) { + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + return nil, errors.Wrap(err, "db.IsUsernameUnique") + } + if !unique { + return nil, nil + } + user, err := db.CreateUser(ctx, tx, username, details.DiscordUser) + if err != nil { + return nil, errors.Wrap(err, "db.CreateUser") + } + err = user.UpdateDiscordToken(ctx, tx, details.Token) + if err != nil { + return nil, errors.Wrap(err, "db.UpdateDiscordToken") + } + return user, nil +} diff --git a/internal/handlers/static.go b/internal/handlers/static.go index fb444c9..3176129 100644 --- a/internal/handlers/static.go +++ b/internal/handlers/static.go @@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler { if err != nil { // If we can't create the file server, return a handler that always errors return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to load the file system", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) - } + throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err) }) } diff --git a/internal/store/newlogin.go b/internal/store/newlogin.go new file mode 100644 index 0000000..0e62529 --- /dev/null +++ b/internal/store/newlogin.go @@ -0,0 +1,46 @@ +package store + +import ( + "errors" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/bwmarrin/discordgo" +) + +type RegistrationSession struct { + DiscordUser *discordgo.User + Token *discord.Token + ExpiresAt time.Time +} + +func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) { + if user == nil { + return "", errors.New("user cannot be nil") + } + if token == nil { + return "", errors.New("token cannot be nil") + } + id := generateID() + s.sessions.Store(id, &RegistrationSession{ + DiscordUser: user, + Token: token, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + return id, nil +} + +func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) { + val, ok := s.sessions.Load(id) + if !ok { + return nil, false + } + + session := val.(*RegistrationSession) + if time.Now().After(session.ExpiresAt) { + s.sessions.Delete(id) + return nil, false + } + + return session, true +} diff --git a/internal/store/redirects.go b/internal/store/redirects.go new file mode 100644 index 0000000..e2e7852 --- /dev/null +++ b/internal/store/redirects.go @@ -0,0 +1,95 @@ +package store + +import ( + "net" + "net/http" + "strings" + "time" +) + +// getClientIP extracts the client IP address, checking X-Forwarded-For first +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (comma-separated list, first is client) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the list + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port") + // Use net.SplitHostPort to properly handle both IPv4 and IPv6 + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr) + return r.RemoteAddr + } + return host +} + +// TrackRedirect increments the redirect counter for this IP+UA+Path combination +// Returns the current attempt count, whether limit was exceeded, and the track details +func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) { + if r == nil { + return 0, false, nil + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + + now := time.Now() + expiresAt := now.Add(5 * time.Minute) + + // Try to load existing track + val, exists := s.redirectTracks.Load(key) + if exists { + track = val.(*RedirectTrack) + + // Check if expired + if now.After(track.ExpiresAt) { + // Expired, start fresh + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track + } + + // Increment existing + track.Attempts++ + track.ExpiresAt = expiresAt // Extend expiry + exceeded = track.Attempts >= maxAttempts + return track.Attempts, exceeded, track + } + + // Create new track + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track +} + +// ClearRedirectTrack removes a redirect tracking entry (called after successful completion) +func (s *Store) ClearRedirectTrack(r *http.Request, path string) { + if r == nil { + return + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + s.redirectTracks.Delete(key) +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..5620e58 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,80 @@ +package store + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "sync" + "time" +) + +// RedirectTrack represents a single redirect attempt tracking entry +type RedirectTrack struct { + IP string // Client IP (X-Forwarded-For aware) + UserAgent string // Full User-Agent string for debugging + Path string // Request path (without query params) + Attempts int // Number of redirect attempts + FirstSeen time.Time // When first redirect was tracked + ExpiresAt time.Time // When to clean up this entry +} + +type Store struct { + sessions sync.Map // key: string, value: *RegistrationSession + redirectTracks sync.Map // key: string, value: *RedirectTrack + cleanup *time.Ticker +} + +func NewStore() *Store { + s := &Store{ + cleanup: time.NewTicker(1 * time.Minute), + } + + // Background cleanup of expired sessions + go func() { + for range s.cleanup.C { + s.cleanupExpired() + } + }() + + return s +} + +func (s *Store) Delete(id string) { + s.sessions.Delete(id) +} +func (s *Store) cleanupExpired() { + now := time.Now() + + // Clean up expired registration sessions + s.sessions.Range(func(key, value any) bool { + session := value.(*RegistrationSession) + if now.After(session.ExpiresAt) { + s.sessions.Delete(key) + } + return true + }) + + // Clean up expired redirect tracks + s.redirectTracks.Range(func(key, value any) bool { + track := value.(*RedirectTrack) + if now.After(track.ExpiresAt) { + s.redirectTracks.Delete(key) + } + return true + }) +} +func generateID() string { + b := make([]byte, 32) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +// redirectKey generates a unique key for tracking redirects +// Uses IP + first 100 chars of UA + path as key (not hashed for debugging) +func redirectKey(ip, userAgent, path string) string { + ua := userAgent + if len(ua) > 100 { + ua = ua[:100] + } + return fmt.Sprintf("%s:%s:%s", ip, ua, path) +} diff --git a/internal/view/component/form/register.templ b/internal/view/component/form/register.templ new file mode 100644 index 0000000..00af173 --- /dev/null +++ b/internal/view/component/form/register.templ @@ -0,0 +1,89 @@ +package form + +templ RegisterForm(username string) { +
+} diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ index 5f1fc22..88a5531 100644 --- a/internal/view/component/nav/navbarright.templ +++ b/internal/view/component/nav/navbarright.templ @@ -1,6 +1,6 @@ package nav -import "git.haelnorr.com/h/oslstats/pkg/contexts" +import "git.haelnorr.com/h/oslstats/internal/db" type ProfileItem struct { name string // Label to display @@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem { // Returns the right portion of the navbar templ navRight() { - {{ user := contexts.CurrentUser(ctx) }} + {{ user := db.CurrentUser(ctx) }} {{ items := getProfileItems() }}{ err }
-{ message }
- Go to homepage +{ err }
+ // Always show the message from hws.HWSError.Message +{ message }
+ // Conditionally show technical details in dropdown + if details != "" { +{ details }
+ + Select your display name. This must be unique, and cannot be changed. +
+