diff --git a/.gitignore b/.gitignore index 33692cc..9d37f15 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.db* .logs/ server.log +keys/ bin/ tmp/ static/css/output.css diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..a888b37 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,327 @@ +# AGENTS.md - Developer Guide for oslstats + +This document provides guidelines for AI coding agents and developers working on the oslstats codebase. + +## Project Overview + +**Module**: `git.haelnorr.com/h/oslstats` +**Language**: Go 1.25.5 +**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates +**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries + +## Build, Test, and Development Commands + +### Building +```bash +# Full production build (tailwind → templ → go generate → go build) +make build + +# Build and run +make run + +# Clean build artifacts +make clean +``` + +### Development Mode +```bash +# Watch mode with hot reload (templ, air, tailwindcss in parallel) +make dev + +# Development server runs on: +# - Proxy: http://localhost:3000 (use this) +# - App: http://localhost:3333 (internal) +``` + +### Testing +```bash +# Run all tests +go test ./... + +# Run tests for a specific package +go test ./pkg/oauth + +# Run a single test function +go test ./pkg/oauth -run TestGenerateState_Success + +# Run tests with verbose output +go test -v ./pkg/oauth + +# Run tests with coverage +go test -cover ./... +go test -coverprofile=coverage.out ./... +go tool cover -html=coverage.out +``` + +### Database +```bash +# Run migrations +make migrate +# OR +./bin/oslstats --migrate +``` + +### Configuration Management +```bash +# Generate .env template file +make genenv +# OR with custom output: make genenv OUT=.env.example + +# Show environment variable documentation +make envdoc + +# Show current environment values +make showenv +``` + +## Code Style Guidelines + +### Import Organization +Organize imports in **3 groups** separated by blank lines: + +```go +import ( + // 1. Standard library + "context" + "net/http" + "fmt" + + // 2. External dependencies + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + // 3. Internal packages + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) +``` + +### Naming Conventions + +**Variables**: +- Local: `camelCase` (userAgentKey, httpServer, dbConn) +- Exported: `PascalCase` (Config, User, Token) +- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r` + +**Functions**: +- Exported: `PascalCase` (GetConfig, NewStore, GenerateState) +- Private: `camelCase` (throwError, shouldShowDetails, loadModels) +- HTTP handlers: Return `http.Handler`, use dependency injection pattern +- Database functions: Use `bun.Tx` as parameter for transactions + +**Types**: +- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession) +- Use `-er` suffix for interfaces (implied from usage) + +**Files**: +- Prefer single word: `config.go`, `oauth.go`, `errors.go` +- Use snake_case if needed: `discord_tokens.go`, `state_test.go` +- Test files: `*_test.go` alongside source files + +### Error Handling + +**Always wrap errors** with context using `github.com/pkg/errors`: + +```go +if err != nil { + return errors.Wrap(err, "operation_name") +} +``` + +**Validate inputs at function start**: +```go +func DoSomething(cfg *Config, data string) error { + if cfg == nil { + return errors.New("cfg cannot be nil") + } + if data == "" { + return errors.New("data cannot be empty") + } + // ... rest of function +} +``` + +**HTTP error helpers** (in handlers package): +- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors +- `throwBadRequest(s, w, r, msg, err)` - 400 errors +- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal) +- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level) +- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal) +- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level) +- `throwNotFound(s, w, r, path)` - 404 errors + +### Common Patterns + +**HTTP Handler Pattern**: +```go +func HandlerName(server *hws.Server, deps ...) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Handler logic here + }, + ) +} +``` + +**Database Operation Pattern**: +```go +func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) { + result := new(Result) + err := tx.NewSelect(). + Model(result). + Where("id = ?", id). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil // Return nil, nil for not found + } + return nil, errors.Wrap(err, "tx.Select") + } + return result, nil +} +``` + +**Setup Function Pattern** (returns instance, cleanup func, error): +```go +func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) { + instance := newInstance() + + err := configure(instance) + if err != nil { + return nil, nil, errors.Wrap(err, "configure") + } + + return instance, instance.Close, nil +} +``` + +**Configuration Pattern** (using ezconf): +```go +type Config struct { + Field string // ENV FIELD_NAME: Description (required/default: value) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + Field: env.String("FIELD_NAME", "default"), + } + // Validation here + return cfg, nil +} +``` + +### Formatting & Types + +**Formatting**: +- Use `gofmt` (standard Go formatting) +- No tabs vs spaces debate - Go uses tabs + +**Types**: +- Prefer explicit types over inference when it improves clarity +- Use struct tags for ORM and JSON marshaling: + ```go + type User struct { + bun.BaseModel `bun:"table:users,alias:u"` + ID int `bun:"id,pk,autoincrement"` + Username string `bun:"username,unique"` + AccessToken string `json:"access_token"` + } + ``` + +**Comments**: +- Document exported functions and types +- Use inline comments for ENV var documentation in Config structs +- Explain security-critical code flows + +### Testing + +**Test File Location**: Place `*_test.go` files alongside source files + +**Test Naming**: +```go +func TestFunctionName_Scenario(t *testing.T) +func TestGenerateState_Success(t *testing.T) +func TestVerifyState_WrongUserAgentKey(t *testing.T) +``` + +**Test Structure**: +- Use subtests with `t.Run()` for related scenarios +- Use table-driven tests for multiple similar cases +- Create helper functions for common setup (e.g., `testConfig()`) +- Test happy paths, error cases, edge cases, and security properties + +**Test Categories** (from pkg/oauth/state_test.go example): +1. Happy path tests +2. Error handling (nil params, empty fields, malformed input) +3. Security tests (MITM, CSRF, replay attacks, tampering) +4. Edge cases (concurrency, constant-time comparison) +5. Integration tests (round-trip verification) + +### Security + +**Critical Practices**: +- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons +- Implement CSRF protection via state tokens +- Store sensitive cookies as HttpOnly +- Use separate logging levels for security violations (WARN) +- Validate all inputs at function boundaries +- Use parameterized queries (Bun ORM handles this) +- Never commit secrets (.env, keys/ are gitignored) + +## Project Structure + +``` +oslstats/ +├── cmd/oslstats/ # Application entry point +│ ├── main.go # Entry point with flag parsing +│ ├── run.go # Server initialization & graceful shutdown +│ ├── httpserver.go # HTTP server setup +│ ├── routes.go # Route registration +│ ├── middleware.go # Middleware registration +│ ├── auth.go # Authentication setup +│ └── db.go # Database connection & migrations +├── internal/ # Private application code +│ ├── config/ # Configuration aggregation +│ ├── db/ # Database models & queries (Bun ORM) +│ ├── discord/ # Discord OAuth integration +│ ├── handlers/ # HTTP request handlers +│ ├── session/ # Session store (in-memory) +│ └── view/ # Templ templates +│ ├── component/ # Reusable UI components +│ ├── layout/ # Page layouts +│ └── page/ # Full pages +├── pkg/ # Reusable packages +│ ├── contexts/ # Context key definitions +│ ├── embedfs/ # Embedded static files +│ └── oauth/ # OAuth state management +├── bin/ # Compiled binaries (gitignored) +├── keys/ # Private keys (gitignored) +├── tmp/ # Air hot reload temp files (gitignored) +├── Makefile # Build automation +├── .air.toml # Hot reload configuration +└── go.mod # Go module definition +``` + +## Key Dependencies + +- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt) +- **github.com/a-h/templ** - Type-safe HTML templating +- **github.com/uptrace/bun** - PostgreSQL ORM +- **github.com/bwmarrin/discordgo** - Discord API client +- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf) +- **github.com/joho/godotenv** - .env file loading + +## Notes for AI Agents + +1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css) +2. **Database operations** should use `bun.Tx` for transaction safety +3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes +4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/ +5. **Error messages** should be descriptive and use errors.Wrap for context +6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples) +7. **Air proxy** runs on port 3000 during development; app runs on 3333 +8. **Test coverage** is currently limited - prioritize testing security-critical code +9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples +10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern +11. When in plan mode, always use the interactive question tool if available diff --git a/Makefile b/Makefile index f3f2a8b..8f59266 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,6 @@ BINARY_NAME=oslstats build: - ./scripts/generate-css-sources.sh && \ tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \ go mod tidy && \ templ generate && \ @@ -16,7 +15,6 @@ run: ./bin/${BINARY_NAME}${SUFFIX} dev: - ./scripts/generate-css-sources.sh && \ templ generate --watch &\ air &\ tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch @@ -36,3 +34,6 @@ showenv: make build ./bin/${BINARY_NAME} --showenv +migrate: + make build + ./bin/${BINARY_NAME}${SUFFIX} --migrate diff --git a/cmd/oslstats/auth.go b/cmd/oslstats/auth.go index 4f57359..d2b6ff5 100644 --- a/cmd/oslstats/auth.go +++ b/cmd/oslstats/auth.go @@ -8,7 +8,6 @@ import ( "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/oslstats/pkg/contexts" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -31,6 +30,7 @@ func setupAuth( beginTx, logger, handlers.ErrorPage, + conn.DB, ) if err != nil { return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") @@ -38,7 +38,9 @@ func setupAuth( auth.IgnorePaths(ignoredPaths...) - contexts.CurrentUser = auth.CurrentModel + db.CurrentUser = auth.CurrentModel return auth, nil } + +// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index 1df8f57..922b928 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -20,7 +20,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func conn = bun.NewDB(sqldb, pgdialect.New()) close = sqldb.Close - err = loadModels(ctx, conn, cfg.Flags.ResetDB) + err = loadModels(ctx, conn, cfg.Flags.MigrateDB) if err != nil { return nil, nil, errors.Wrap(err, "loadModels") } @@ -31,6 +31,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error { models := []any{ (*db.User)(nil), + (*db.DiscordToken)(nil), } for _, model := range models { diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index fa0461d..bd2763f 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -4,13 +4,15 @@ import ( "io/fs" "net/http" - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/golib/hlog" + "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func setupHttpServer( @@ -18,6 +20,8 @@ func setupHttpServer( config *config.Config, logger *hlog.Logger, bun *bun.DB, + store *store.Store, + discordAPI *discord.APIClient, ) (server *hws.Server, err error) { if staticFS == nil { return nil, errors.New("No filesystem provided") @@ -53,7 +57,7 @@ func setupHttpServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, config, logger, bun, auth) + err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI) if err != nil { return nil, errors.Wrap(err, "addRoutes") } diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index e92d126..38e0eea 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -29,6 +29,16 @@ func main() { return } + if flags.MigrateDB { + _, closedb, err := setupBun(ctx, cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + } + closedb() + return + } + if err := run(ctx, os.Stdout, cfg); err != nil { fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index efea343..c19f6d0 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -5,22 +5,24 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/handlers" - - "git.haelnorr.com/h/golib/hlog" "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func addRoutes( server *hws.Server, staticFS *http.FileSystem, - config *config.Config, - logger *hlog.Logger, + cfg *config.Config, conn *bun.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], + store *store.Store, + discordAPI *discord.APIClient, ) error { // Create the routes routes := []hws.Route{ @@ -34,10 +36,38 @@ func addRoutes( Method: hws.MethodGET, Handler: handlers.Index(server), }, + { + Path: "/login", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)), + }, + { + Path: "/auth/callback", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)), + }, + { + Path: "/register", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), + }, + { + Path: "/logout", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)), + }, + } + + htmxRoutes := []hws.Route{ + { + Path: "/htmx/isusernameunique", + Method: hws.MethodPOST, + Handler: handlers.IsUsernameUnique(server, conn, cfg, store), + }, } // Register the routes with the server - err := server.AddRoutes(routes...) + err := server.AddRoutes(append(routes, htmxRoutes...)...) if err != nil { return errors.Wrap(err, "server.AddRoutes") } diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 7906c26..2d6f34f 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -9,18 +9,21 @@ import ( "time" "git.haelnorr.com/h/golib/hlog" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/pkg/embedfs" "github.com/pkg/errors" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/embedfs" ) // Initializes and runs the server -func run(ctx context.Context, w io.Writer, config *config.Config) error { +func run(ctx context.Context, w io.Writer, cfg *config.Config) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() // Setup the logger - logger, err := hlog.NewLogger(config.HLOG, w) + logger, err := hlog.NewLogger(cfg.HLOG, w) if err != nil { return errors.Wrap(err, "hlog.NewLogger") } @@ -28,7 +31,7 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - bun, closedb, err := setupBun(ctx, config) + bun, closedb, err := setupBun(ctx, cfg) if err != nil { return errors.Wrap(err, "setupDBConn") } @@ -41,8 +44,19 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { return errors.Wrap(err, "getStaticFiles") } + // Setup session store + logger.Debug().Msg("Setting up session store") + store := store.NewStore() + + // Setup Discord API client + logger.Debug().Msg("Setting up Discord API client") + discordAPI, err := discord.NewAPIClient(cfg.Discord, logger, cfg.HWSAuth.TrustedHost) + if err != nil { + return errors.Wrap(err, "discord.NewAPIClient") + } + logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHttpServer(&staticFS, config, logger, bun) + httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI) if err != nil { return errors.Wrap(err, "setupHttpServer") } diff --git a/go.mod b/go.mod index 26e14e9..e909fb8 100644 --- a/go.mod +++ b/go.mod @@ -6,20 +6,25 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.2.3 - git.haelnorr.com/h/golib/hwsauth v0.3.4 + git.haelnorr.com/h/golib/hws v0.3.1 + git.haelnorr.com/h/golib/hwsauth v0.5.2 github.com/a-h/templ v0.3.977 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/uptrace/bun v1.2.16 github.com/uptrace/bun/dialect/pgdialect v1.2.16 github.com/uptrace/bun/driver/pgdriver v1.2.16 - golang.org/x/crypto v0.45.0 ) require ( - git.haelnorr.com/h/golib/cookies v0.9.0 // indirect - git.haelnorr.com/h/golib/jwt v0.10.0 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + golang.org/x/crypto v0.45.0 // indirect +) + +require ( + git.haelnorr.com/h/golib/cookies v0.9.0 + git.haelnorr.com/h/golib/jwt v0.10.1 // indirect + github.com/bwmarrin/discordgo v0.29.0 github.com/go-logr/logr v1.4.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 78c8ff7..afdf0fa 100644 --- a/go.sum +++ b/go.sum @@ -6,16 +6,18 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.2.3 h1:gZQkBciXKh3jYw05vZncSR2lvIqi0H2MVfIWySySsmw= -git.haelnorr.com/h/golib/hws v0.2.3/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/hwsauth v0.3.4 h1:wwYBb6cQQ+x9hxmYuZBF4mVmCv/n4PjJV//e1+SgPOo= -git.haelnorr.com/h/golib/hwsauth v0.3.4/go.mod h1:LI7Qz68GPNIW8732Zwptb//ybjiFJOoXf4tgUuUEqHI= -git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E= -git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= +git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI= +git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag= +git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= +git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= +git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg= github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo= +github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= +github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -28,6 +30,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -66,13 +70,19 @@ go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= diff --git a/internal/config/config.go b/internal/config/config.go index 87b9f15..112752a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,8 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/pkg/oauth" "github.com/joho/godotenv" "github.com/pkg/errors" ) @@ -15,6 +17,8 @@ type Config struct { HWS *hws.Config HWSAuth *hwsauth.Config HLOG *hlog.Config + Discord *discord.Config + OAuth *oauth.Config Flags *Flags } @@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { hws.NewEZConfIntegration(), hwsauth.NewEZConfIntegration(), db.NewEZConfIntegration(), + discord.NewEZConfIntegration(), + oauth.NewEZConfIntegration(), ) if err := loader.ParseEnvVars(); err != nil { return nil, nil, errors.Wrap(err, "loader.ParseEnvVars") @@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { return nil, nil, errors.New("DB Config not loaded") } + discordcfg, ok := loader.GetConfig("discord") + if !ok { + return nil, nil, errors.New("Dicord Config not loaded") + } + + oauthcfg, ok := loader.GetConfig("oauth") + if !ok { + return nil, nil, errors.New("OAuth Config not loaded") + } + config := &Config{ DB: dbcfg.(*db.Config), HWS: hwscfg.(*hws.Config), HWSAuth: hwsauthcfg.(*hwsauth.Config), HLOG: hlogcfg.(*hlog.Config), + Discord: discordcfg.(*discord.Config), + OAuth: oauthcfg.(*oauth.Config), Flags: flags, } diff --git a/internal/config/flags.go b/internal/config/flags.go index 7d89108..ba87131 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -5,16 +5,16 @@ import ( ) type Flags struct { - ResetDB bool - EnvDoc bool - ShowEnv bool - GenEnv string - EnvFile string + MigrateDB bool + EnvDoc bool + ShowEnv bool + GenEnv string + EnvFile string } func SetupFlags() *Flags { // Parse commandline args - resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models") + migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models") envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation") showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation") genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)") @@ -22,11 +22,11 @@ func SetupFlags() *Flags { flag.Parse() flags := &Flags{ - ResetDB: *resetDB, - EnvDoc: *envDoc, - ShowEnv: *showEnv, - GenEnv: *genEnv, - EnvFile: *envfile, + MigrateDB: *migrateDB, + EnvDoc: *envDoc, + ShowEnv: *showEnv, + GenEnv: *genEnv, + EnvFile: *envfile, } return flags } diff --git a/internal/db/discord_tokens.go b/internal/db/discord_tokens.go new file mode 100644 index 0000000..7cc09d9 --- /dev/null +++ b/internal/db/discord_tokens.go @@ -0,0 +1,95 @@ +package db + +import ( + "context" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type DiscordToken struct { + bun.BaseModel `bun:"table:discord_tokens,alias:dt"` + + DiscordID string `bun:"discord_id,pk,notnull"` + AccessToken string `bun:"access_token,notnull"` + RefreshToken string `bun:"refresh_token,notnull"` + ExpiresAt int64 `bun:"expires_at,notnull"` + Scope string `bun:"scope,notnull"` + TokenType string `bun:"token_type,notnull"` +} + +// UpdateDiscordToken adds the provided discord token to the database. +// If the user already has a token stored, it will replace that token instead. +func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { + if token == nil { + return errors.New("token cannot be nil") + } + expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() + + discordToken := &DiscordToken{ + DiscordID: user.DiscordID, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: expiresAt, + Scope: token.Scope, + TokenType: token.TokenType, + } + + _, err := tx.NewInsert(). + Model(discordToken). + On("CONFLICT (discord_id) DO UPDATE"). + Set("access_token = EXCLUDED.access_token"). + Set("refresh_token = EXCLUDED.refresh_token"). + Set("expires_at = EXCLUDED.expires_at"). + Exec(ctx) + + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + return nil +} + +// DeleteDiscordTokens deletes a users discord OAuth tokens from the database. +// It returns the DiscordToken so that it can be revoked via the discord API +func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token, err := user.GetDiscordToken(ctx, tx) + if err != nil { + return nil, errors.Wrap(err, "user.GetDiscordToken") + } + _, err = tx.NewDelete(). + Model((*DiscordToken)(nil)). + Where("discord_id = ?", user.DiscordID). + Exec(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewDelete") + } + return token, nil +} + +// GetDiscordToken retrieves the users discord token from the database +func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token := new(DiscordToken) + err := tx.NewSelect(). + Model(token). + Where("discord_id = ?", user.DiscordID). + Limit(1). + Scan(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return token, nil +} + +// Convert reverts the token back into a *discord.Token +func (t *DiscordToken) Convert() *discord.Token { + token := &discord.Token{ + AccessToken: t.AccessToken, + RefreshToken: t.RefreshToken, + ExpiresIn: int(t.ExpiresAt - time.Now().Unix()), + Scope: t.Scope, + TokenType: t.TokenType, + } + return token +} diff --git a/internal/db/ezconf.go b/internal/db/ezconf.go index db4eac9..1cb08c5 100644 --- a/internal/db/ezconf.go +++ b/internal/db/ezconf.go @@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string { // NewEZConfIntegration creates a new EZConf integration helper func NewEZConfIntegration() EZConfIntegration { - return EZConfIntegration{name: "db", configFunc: ConfigFromEnv} + return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv} } diff --git a/internal/db/user.go b/internal/db/user.go index 01764eb..40c5ca8 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,65 +2,30 @@ package db import ( "context" + "fmt" + "time" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/bwmarrin/discordgo" "github.com/pkg/errors" "github.com/uptrace/bun" - "golang.org/x/crypto/bcrypt" ) +var CurrentUser hwsauth.ContextLoader[*User] + type User struct { bun.BaseModel `bun:"table:users,alias:u"` - ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) - Username string `bun:"username,unique"` // Username (unique) - PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON) - CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database - Bio string `bun:"bio"` // Short byline set by the user + ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) + Username string `bun:"username,unique"` // Username (unique) + CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database + DiscordID string `bun:"discord_id,unique"` } func (user *User) GetID() int { return user.ID } -// Uses bcrypt to set the users password_hash from the given password -func (user *User) SetPassword(ctx context.Context, tx bun.Tx, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - newPassword := string(hashedPassword) - - _, err = tx.NewUpdate(). - Model(user). - Set("password_hash = ?", newPassword). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users password_hash -func (user *User) CheckPassword(ctx context.Context, tx bun.Tx, password string) error { - var hashedPassword string - err := tx.NewSelect(). - Table("users"). - Column("password_hash"). - Where("id = ?", user.ID). - Limit(1). - Scan(ctx, &hashedPassword) - if err != nil { - return errors.Wrap(err, "tx.Select") - } - - err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) - if err != nil { - return errors.Wrap(err, "Username or password incorrect") - } - return nil -} - // Change the user's username func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error { _, err := tx.NewUpdate(). @@ -75,35 +40,18 @@ func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername str return nil } -// Change the user's bio -func (user *User) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error { - _, err := tx.NewUpdate(). - Model(user). - Set("bio = ?", newBio). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - user.Bio = newBio - return nil -} - // CreateUser creates a new user with the given username and password -func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*User, error) { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword") +func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { + if discorduser == nil { + return nil, errors.New("user cannot be nil") } - user := &User{ - Username: username, - PasswordHash: string(hashedPassword), - CreatedAt: 0, // You may want to set this to time.Now().Unix() - Bio: "", + Username: username, + CreatedAt: time.Now().Unix(), + DiscordID: discorduser.ID, } - _, err = tx.NewInsert(). + _, err := tx.NewInsert(). Model(user). Exec(ctx) if err != nil { @@ -116,6 +64,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*Use // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { + fmt.Printf("user id requested: %v", id) user := new(User) err := tx.NewSelect(). Model(user). @@ -149,6 +98,24 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, return user, nil } +// GetUserByDiscordID queries the database for a user matching the given discord id +// Returns nil, nil if no user is found +func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { + user := new(User) + err := tx.NewSelect(). + Model(user). + Where("discord_id = ?", discordID). + Limit(1). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil + } + return nil, errors.Wrap(err, "tx.Select") + } + return user, nil +} + // IsUsernameUnique checks if the given username is unique (not already taken) // Returns true if the username is available, false if it's taken func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) { diff --git a/internal/discord/api.go b/internal/discord/api.go new file mode 100644 index 0000000..aa490b5 --- /dev/null +++ b/internal/discord/api.go @@ -0,0 +1,61 @@ +package discord + +import ( + "net/http" + "sync" + "time" + + "git.haelnorr.com/h/golib/hlog" + "github.com/bwmarrin/discordgo" + "github.com/pkg/errors" +) + +type OAuthSession struct { + *discordgo.Session +} + +func NewOAuthSession(token *Token) (*OAuthSession, error) { + session, err := discordgo.New("Bearer " + token.AccessToken) + if err != nil { + return nil, errors.Wrap(err, "discordgo.New") + } + return &OAuthSession{Session: session}, nil +} + +func (s *OAuthSession) GetUser() (*discordgo.User, error) { + user, err := s.User("@me") + if err != nil { + return nil, errors.Wrap(err, "s.User") + } + return user, nil +} + +// APIClient is an HTTP client wrapper that handles Discord API rate limits +type APIClient struct { + cfg *Config + client *http.Client + logger *hlog.Logger + mu sync.RWMutex + buckets map[string]*RateLimitState + trustedHost string +} + +// NewAPIClient creates a new Discord API client with rate limit handling +func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) { + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + if logger == nil { + return nil, errors.New("logger cannot be nil") + } + if trustedhost == "" { + return nil, errors.New("trustedhost cannot be empty") + } + return &APIClient{ + client: &http.Client{Timeout: 30 * time.Second}, + logger: logger, + buckets: make(map[string]*RateLimitState), + cfg: cfg, + trustedHost: trustedhost, + }, nil +} diff --git a/internal/discord/config.go b/internal/discord/config.go new file mode 100644 index 0000000..cc086ff --- /dev/null +++ b/internal/discord/config.go @@ -0,0 +1,50 @@ +package discord + +import ( + "strings" + + "git.haelnorr.com/h/golib/env" + "github.com/pkg/errors" +) + +type Config struct { + ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required) + ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required) + OAuthScopes string // Authorisation scopes for OAuth + RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + ClientID: env.String("DISCORD_CLIENT_ID", ""), + ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""), + OAuthScopes: getOAuthScopes(), + RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""), + } + + // Check required fields + if cfg.ClientID == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_ID") + } + if cfg.ClientSecret == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET") + } + if cfg.RedirectPath == "" { + return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH") + } + + return cfg, nil +} + +func getOAuthScopes() string { + list := []string{ + "connections", + "email", + "guilds", + "gdm.join", + "guilds.members.read", + "identify", + } + scopes := strings.Join(list, "+") + return scopes +} diff --git a/internal/discord/ezconf.go b/internal/discord/ezconf.go new file mode 100644 index 0000000..8442714 --- /dev/null +++ b/internal/discord/ezconf.go @@ -0,0 +1,41 @@ +package discord + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv} +} diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go new file mode 100644 index 0000000..faff919 --- /dev/null +++ b/internal/discord/oauth.go @@ -0,0 +1,148 @@ +package discord + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/pkg/errors" +) + +// Token represents a response from the Discord OAuth API after a successful authorization request +type Token struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` +} + +const oauthurl string = "https://discord.com/oauth2/authorize" +const apiurl string = "https://discord.com/api/v10" + +// GetOAuthLink generates a new Discord OAuth2 link for user authentication +func (api *APIClient) GetOAuthLink(state string) (string, error) { + if state == "" { + return "", errors.New("state cannot be empty") + } + values := url.Values{} + values.Add("response_type", "code") + values.Add("client_id", api.cfg.ClientID) + values.Add("scope", api.cfg.OAuthScopes) + values.Add("state", state) + values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) + values.Add("prompt", "none") + + return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil +} + +// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for +// making requests to the API on behalf of the user +func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) { + if code == "" { + return nil, errors.New("code cannot be empty") + } + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +// RefreshToken uses the refresh token to generate a new token pair +func (api *APIClient) RefreshToken(token *Token) (*Token, error) { + if token == nil { + return nil, errors.New("token cannot be nil") + } + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", token.RefreshToken) + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +// RevokeToken sends a request to the Discord API to revoke the token pair +func (api *APIClient) RevokeToken(token *Token) error { + if token == nil { + return errors.New("token cannot be nil") + } + data := url.Values{} + data.Set("token", token.AccessToken) + data.Set("token_type_hint", "access_token") + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token/revoke", + strings.NewReader(data.Encode()), + ) + if err != nil { + return errors.Wrap(err, "failed to create request") + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) + if err != nil { + return errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.Errorf("discord API returned status %d", resp.StatusCode) + } + return nil +} diff --git a/internal/discord/ratelimit.go b/internal/discord/ratelimit.go new file mode 100644 index 0000000..21f3888 --- /dev/null +++ b/internal/discord/ratelimit.go @@ -0,0 +1,216 @@ +package discord + +import ( + "net" + "net/http" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// RateLimitState tracks rate limit information for a specific bucket +type RateLimitState struct { + Remaining int // Requests remaining in current window + Limit int // Total requests allowed in window + Reset time.Time // When the rate limit resets + Bucket string // Discord's bucket identifier +} + +// Do executes an HTTP request with automatic rate limit handling +// It will wait if rate limits are about to be exceeded and retry once if a 429 is received +func (c *APIClient) Do(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + + // Step 1: Check if we need to wait before making request + bucket := c.getBucketFromRequest(req) + if err := c.waitIfNeeded(bucket); err != nil { + return nil, err + } + + // Step 2: Execute request + resp, err := c.client.Do(req) + if err != nil { + // Check if it's a network timeout + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "request timed out") + } + return nil, errors.Wrap(err, "http request failed") + } + + // Step 3: Update rate limit state from response headers + c.updateRateLimit(resp.Header) + + // Step 4: Handle 429 (rate limited) + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() // Close original response + + retryAfter := c.parseRetryAfter(resp.Header) + + // No Retry-After header, can't retry safely + if retryAfter == 0 { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Rate limited but no Retry-After header provided") + return nil, errors.New("discord API rate limited but no Retry-After header provided") + } + + // Retry-After exceeds 30 second cap + if retryAfter > 30*time.Second { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited with Retry-After exceeding 30s cap, not retrying") + return nil, errors.Errorf( + "discord API rate limited (retry after %s exceeds 30s cap)", + retryAfter, + ) + } + + // Wait and retry + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited, waiting before retry") + + time.Sleep(retryAfter) + + // Retry the request + resp, err = c.client.Do(req) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "retry request timed out") + } + return nil, errors.Wrap(err, "retry request failed") + } + + // Update rate limit again after retry + c.updateRateLimit(resp.Header) + + // If STILL rate limited after retry, return error + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() + c.logger.Error(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Still rate limited after retry, Discord may be experiencing issues") + return nil, errors.Errorf( + "discord API still rate limited after retry (waited %s), Discord may be experiencing issues", + retryAfter, + ) + } + } + + return resp, nil +} + +// getBucketFromRequest extracts or generates bucket ID from request +// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers +func (c *APIClient) getBucketFromRequest(req *http.Request) string { + return req.Method + ":" + req.URL.Path +} + +// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits +func (c *APIClient) waitIfNeeded(bucket string) error { + c.mu.RLock() + state, exists := c.buckets[bucket] + c.mu.RUnlock() + + if !exists { + return nil // No state yet, proceed + } + + now := time.Now() + + // If we have no remaining requests and reset hasn't occurred, wait + if state.Remaining == 0 && now.Before(state.Reset) { + waitDuration := time.Until(state.Reset) + // Add small buffer (100ms) to ensure reset has occurred + waitDuration += 100 * time.Millisecond + + if waitDuration > 0 { + c.logger.Debug(). + Str("bucket", bucket). + Dur("wait_duration", waitDuration). + Msg("Proactively waiting for rate limit reset") + time.Sleep(waitDuration) + } + } + + return nil +} + +// updateRateLimit parses response headers and updates bucket state +func (c *APIClient) updateRateLimit(headers http.Header) { + bucket := headers.Get("X-RateLimit-Bucket") + if bucket == "" { + return // No bucket info, can't track + } + + // Parse headers + limit := c.parseInt(headers.Get("X-RateLimit-Limit")) + remaining := c.parseInt(headers.Get("X-RateLimit-Remaining")) + resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After")) + + state := &RateLimitState{ + Bucket: bucket, + Limit: limit, + Remaining: remaining, + Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))), + } + + c.mu.Lock() + c.buckets[bucket] = state + c.mu.Unlock() + + // Log rate limit state for debugging + c.logger.Debug(). + Str("bucket", bucket). + Int("remaining", remaining). + Int("limit", limit). + Dur("reset_in", time.Until(state.Reset)). + Msg("Rate limit state updated") +} + +// parseRetryAfter extracts retry delay from Retry-After header +func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration { + retryAfter := headers.Get("Retry-After") + if retryAfter == "" { + return 0 + } + + // Discord returns seconds as float + seconds := c.parseFloat(retryAfter) + if seconds <= 0 { + return 0 + } + + return time.Duration(seconds * float64(time.Second)) +} + +// parseInt parses an integer from a header value, returns 0 on error +func (c *APIClient) parseInt(s string) int { + if s == "" { + return 0 + } + i, _ := strconv.Atoi(s) + return i +} + +// parseFloat parses a float from a header value, returns 0 on error +func (c *APIClient) parseFloat(s string) float64 { + if s == "" { + return 0 + } + f, _ := strconv.ParseFloat(s, 64) + return f +} diff --git a/internal/discord/ratelimit_test.go b/internal/discord/ratelimit_test.go new file mode 100644 index 0000000..38dfb04 --- /dev/null +++ b/internal/discord/ratelimit_test.go @@ -0,0 +1,517 @@ +package discord + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "git.haelnorr.com/h/golib/hlog" +) + +// testLogger creates a test logger for testing +func testLogger(t *testing.T) *hlog.Logger { + level, _ := hlog.LogLevel("debug") + cfg := &hlog.Config{ + LogLevel: level, + LogOutput: "console", + } + logger, err := hlog.NewLogger(cfg, io.Discard) + if err != nil { + t.Fatalf("failed to create test logger: %v", err) + } + return logger +} + +// testConfig creates a test config for testing +func testConfig() *Config { + return &Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + OAuthScopes: "identify+email", + RedirectPath: "/oauth/callback", + } +} + +func TestNewRateLimitedClient(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + if client == nil { + t.Fatal("NewAPIClient returned nil") + } + if client.client == nil { + t.Error("client.client is nil") + } + if client.logger == nil { + t.Error("client.logger is nil") + } + if client.buckets == nil { + t.Error("client.buckets map is nil") + } + if client.cfg == nil { + t.Error("client.cfg is nil") + } + if client.trustedHost != "trusted-host.example.com" { + t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost) + } +} + +func TestAPIClient_Do_Success(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Mock server that returns success with rate limit headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "3") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + + // Check that rate limit state was updated + client.mu.RLock() + state, exists := client.buckets["test-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("rate limit state not stored") + } + if state.Remaining != 3 { + t.Errorf("expected remaining=3, got %d", state.Remaining) + } + if state.Limit != 5 { + t.Errorf("expected limit=5, got %d", state.Limit) + } +} + +func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + if attemptCount == 1 { + // First request: return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.1") // 100ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + "error_description": "You are being rate limited", + }) + return + } + // Second request: success + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "4") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200 after retry, got %d", resp.StatusCode) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount) + } + + // Should have waited approximately 100ms + if elapsed < 100*time.Millisecond { + t.Errorf("expected delay of ~100ms, but took %v", elapsed) + } +} + +func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.05") // 50ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error after failed retry") + } + + if !strings.Contains(err.Error(), "still rate limited after retry") { + t.Errorf("expected 'still rate limited after retry' error, got: %v", err) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount) + } +} + +func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error for Retry-After > 30s") + } + + if !strings.Contains(err.Error(), "exceeds 30s cap") { + t.Errorf("expected 'exceeds 30s cap' error, got: %v", err) + } + + // Should NOT have waited (immediate error) + if elapsed > 1*time.Second { + t.Errorf("should return immediately, but took %v", elapsed) + } +} + +func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 429 but NO Retry-After header + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error when no Retry-After header") + } + + if !strings.Contains(err.Error(), "no Retry-After header") { + t.Errorf("expected 'no Retry-After header' error, got: %v", err) + } +} + +func TestAPIClient_UpdateRateLimit(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + headers := http.Header{} + headers.Set("X-RateLimit-Bucket", "global") + headers.Set("X-RateLimit-Limit", "10") + headers.Set("X-RateLimit-Remaining", "7") + headers.Set("X-RateLimit-Reset-After", "5.5") + + client.updateRateLimit(headers) + + client.mu.RLock() + state, exists := client.buckets["global"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not created") + } + + if state.Bucket != "global" { + t.Errorf("expected bucket='global', got '%s'", state.Bucket) + } + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d", state.Limit) + } + if state.Remaining != 7 { + t.Errorf("expected remaining=7, got %d", state.Remaining) + } + + // Check reset time is approximately 5.5 seconds from now + resetIn := time.Until(state.Reset) + if resetIn < 5*time.Second || resetIn > 6*time.Second { + t.Errorf("expected reset in ~5.5s, got %v", resetIn) + } +} + +func TestAPIClient_WaitIfNeeded(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Set up a bucket with 0 remaining and reset in future + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 0, + Reset: time.Now().Add(200 * time.Millisecond), + } + client.mu.Unlock() + + start := time.Now() + err = client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should have waited ~200ms + 100ms buffer + if elapsed < 200*time.Millisecond { + t.Errorf("expected wait of ~300ms, but took %v", elapsed) + } + if elapsed > 500*time.Millisecond { + t.Errorf("waited too long: %v", elapsed) + } +} + +func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + // Set up a bucket with remaining requests + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 3, + Reset: time.Now().Add(5 * time.Second), + } + client.mu.Unlock() + + start := time.Now() + err = client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should NOT wait (has remaining requests) + if elapsed > 10*time.Millisecond { + t.Errorf("should not wait when remaining > 0, but took %v", elapsed) + } +} + +func TestAPIClient_Do_Concurrent(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + requestCount := 0 + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestCount++ + count := requestCount + mu.Unlock() + + w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket") + w.Header().Set("X-RateLimit-Limit", "10") + w.Header().Set("X-RateLimit-Remaining", "5") + w.Header().Set("X-RateLimit-Reset-After", "1.0") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))}) + })) + defer server.Close() + + // Launch 10 concurrent requests + var wg sync.WaitGroup + errors := make(chan error, 10) + + for range 10 { + wg.Go( + func() { + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + errors <- err + return + } + + resp, err := client.Do(req) + if err != nil { + errors <- err + return + } + resp.Body.Close() + }) + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + t.Errorf("concurrent request failed: %v", err) + } + + // All requests should have completed + mu.Lock() + finalCount := requestCount + mu.Unlock() + + if finalCount != 10 { + t.Errorf("expected 10 requests, got %d", finalCount) + } + + // Check rate limit state is consistent (no data races) + client.mu.RLock() + state, exists := client.buckets["concurrent-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not found after concurrent requests") + } + + // State should exist and be valid + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit) + } +} + +func TestAPIClient_ParseRetryAfter(t *testing.T) { + logger := testLogger(t) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } + + tests := []struct { + name string + header string + expected time.Duration + }{ + {"integer seconds", "2", 2 * time.Second}, + {"float seconds", "2.5", 2500 * time.Millisecond}, + {"zero", "0", 0}, + {"empty", "", 0}, + {"invalid", "abc", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := http.Header{} + headers.Set("Retry-After", tt.header) + + result := client.parseRetryAfter(headers) + + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go new file mode 100644 index 0000000..12082a6 --- /dev/null +++ b/internal/handlers/callback.go @@ -0,0 +1,205 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) + +func Callback( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + cfg *config.Config, + store *store.Store, + discordAPI *discord.APIClient, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) + + if exceeded { + err := errors.Errorf( + "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + store.ClearRedirectTrack(r, "/callback") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "OAuth callback failed: Too many redirect attempts. Please try logging in again.", + err, + "warn", + ) + return + } + + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + if state == "" && code == "" { + http.Redirect(w, r, "/", http.StatusBadRequest) + return + } + data, err := verifyState(cfg.OAuth, w, r, state) + if err != nil { + if vsErr, ok := err.(*verifyStateError); ok { + if vsErr.IsCookieError() { + throwUnauthorized(server, w, r, "OAuth session not found or expired", err) + } else { + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + } + } else { + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + } + return + } + store.ClearRedirectTrack(r, "/callback") + + switch data { + case "login": + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "DB Transaction failed to start", err) + return + } + defer tx.Rollback() + redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) + if err != nil { + throwInternalServiceError(server, w, r, "OAuth login failed", err) + return + } + tx.Commit() + redirect() + return + } + }, + ) +} + +type verifyStateError struct { + err error + cookieError bool +} + +func (e *verifyStateError) Error() string { + return e.err.Error() +} + +func (e *verifyStateError) IsCookieError() bool { + return e.cookieError +} + +func verifyState( + cfg *oauth.Config, + w http.ResponseWriter, + r *http.Request, + state string, +) (string, error) { + if r == nil { + return "", errors.New("request cannot be nil") + } + if state == "" { + return "", errors.New("state param field is empty") + } + + uak, err := oauth.GetStateCookie(r) + if err != nil { + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.GetStateCookie"), + cookieError: true, + } + } + + data, err := oauth.VerifyState(cfg, state, uak) + if err != nil { + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.VerifyState"), + cookieError: false, + } + } + + oauth.DeleteStateCookie(w) + return data, nil +} + +func login( + ctx context.Context, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + tx bun.Tx, + cfg *config.Config, + w http.ResponseWriter, + r *http.Request, + code string, + store *store.Store, + discordAPI *discord.APIClient, +) (func(), error) { + token, err := discordAPI.AuthorizeWithCode(code) + if err != nil { + return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode") + } + session, err := discord.NewOAuthSession(token) + if err != nil { + return nil, errors.Wrap(err, "discord.NewOAuthSession") + } + discorduser, err := session.GetUser() + if err != nil { + return nil, errors.Wrap(err, "session.GetUser") + } + + user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID) + if err != nil { + return nil, errors.Wrap(err, "db.GetUserByDiscordID") + } + var redirect string + if user == nil { + sessionID, err := store.CreateRegistrationSession(discorduser, token) + if err != nil { + return nil, errors.Wrap(err, "store.CreateRegistrationSession") + } + http.SetCookie(w, &http.Cookie{ + Name: "registration_session", + Path: "/", + Value: sessionID, + MaxAge: 300, // 5 minutes + HttpOnly: true, + Secure: cfg.HWSAuth.SSL, + SameSite: http.SameSiteLaxMode, + }) + redirect = "/register" + } else { + err = user.UpdateDiscordToken(ctx, tx, token) + if err != nil { + return nil, errors.Wrap(err, "user.UpdateDiscordToken") + } + err := auth.Login(w, r, user, true) + if err != nil { + return nil, errors.Wrap(err, "auth.Login") + } + redirect = cookies.CheckPageFrom(w, r) + } + return func() { + http.Redirect(w, r, redirect, http.StatusSeeOther) + }, nil +} diff --git a/internal/handlers/errorpage.go b/internal/handlers/errorpage.go index 2fa7699..d6e4a25 100644 --- a/internal/handlers/errorpage.go +++ b/internal/handlers/errorpage.go @@ -5,24 +5,93 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/view/page" - "github.com/pkg/errors" ) -func ErrorPage( - errorCode int, -) (hws.ErrorPage, error) { +// func ErrorPage( +// error hws.HWSError, +// ) (hws.ErrorPage, error) { +// messages := map[int]string{ +// 400: "The request you made was malformed or unexpected.", +// 401: "You need to login to view this page.", +// 403: "You do not have permission to view this page.", +// 404: "The page or resource you have requested does not exist.", +// 500: `An error occured on the server. Please try again, and if this +// continues to happen contact an administrator.`, +// 503: "The server is currently down for maintenance and should be back soon. =)", +// } +// msg, exists := messages[error.StatusCode] +// if !exists { +// return nil, errors.New("No valid message for the given code") +// } +// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil +// } + +func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) { + // Determine if this status code should show technical details + showDetails := shouldShowDetails(hwsError.StatusCode) + + // Get the user-friendly message + message := hwsError.Message + if message == "" { + // Fallback to default messages if no custom message provided + message = getDefaultMessage(hwsError.StatusCode) + } + + // Get technical details if applicable + var details string + if showDetails && hwsError.Error != nil { + details = hwsError.Error.Error() + } + + // Render appropriate template + if details != "" { + return page.ErrorWithDetails( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + details, + ), nil + } + + return page.Error( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + ), nil +} + +// shouldShowDetails determines if a status code should display technical details +func shouldShowDetails(statusCode int) bool { + switch statusCode { + case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable + return true + case 401, 403, 404: // Unauthorized, Forbidden, Not Found + return false + default: + // For unknown codes, show details for 5xx errors + return statusCode >= 500 + } +} + +// getDefaultMessage provides fallback messages for status codes +func getDefaultMessage(statusCode int) string { messages := map[int]string{ 400: "The request you made was malformed or unexpected.", 401: "You need to login to view this page.", 403: "You do not have permission to view this page.", 404: "The page or resource you have requested does not exist.", - 500: `An error occured on the server. Please try again, and if this - continues to happen contact an administrator.`, + 500: `An error occurred on the server. Please try again, and if this + continues to happen contact an administrator.`, 503: "The server is currently down for maintenance and should be back soon. =)", } - msg, exists := messages[errorCode] + + msg, exists := messages[statusCode] if !exists { - return nil, errors.New("No valid message for the given code") + if statusCode >= 500 { + return "A server error occurred. Please try again later." + } + return "An error occurred while processing your request." } - return page.Error(errorCode, http.StatusText(errorCode), msg), nil + + return msg } diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go new file mode 100644 index 0000000..4e26611 --- /dev/null +++ b/internal/handlers/errors.go @@ -0,0 +1,109 @@ +package handlers + +import ( + "fmt" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" +) + +// throwError is a generic helper that all throw* functions use internally +func throwError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + statusCode int, + msg string, + err error, + level string, +) { + err = s.ThrowError(w, r, hws.HWSError{ + StatusCode: statusCode, + Message: msg, + Error: err, + Level: hws.ErrorLevel(level), + RenderErrorPage: true, // throw* family always renders error pages + }) + if err != nil { + s.ThrowFatal(w, err) + } +} + +// throwInternalServiceError handles 500 errors (server failures) +func throwInternalServiceError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusInternalServerError, msg, err, "error") +} + +// throwBadRequest handles 400 errors (malformed requests) +func throwBadRequest( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusBadRequest, msg, err, "debug") +} + +// throwForbidden handles 403 errors (normal permission denials) +func throwForbidden( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "debug") +} + +// throwForbiddenSecurity handles 403 errors for security events (uses WARN level) +func throwForbiddenSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "warn") +} + +// throwUnauthorized handles 401 errors (not authenticated) +func throwUnauthorized( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug") +} + +// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level) +func throwUnauthorizedSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn") +} + +// throwNotFound handles 404 errors +func throwNotFound( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + path string, +) { + msg := fmt.Sprintf("The requested resource was not found: %s", path) + err := errors.New("Resource not found") + throwError(s, w, r, http.StatusNotFound, msg, err, "debug") +} diff --git a/internal/handlers/index.go b/internal/handlers/index.go index 3274302..25e7ac5 100644 --- a/internal/handlers/index.go +++ b/internal/handlers/index.go @@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - page, err := ErrorPage(http.StatusNotFound) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to generate the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } - err = page.Render(r.Context(), w) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to render the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } + throwNotFound(server, w, r, r.URL.Path) } page.Index().Render(r.Context(), w) }, diff --git a/internal/handlers/isusernameunique.go b/internal/handlers/isusernameunique.go new file mode 100644 index 0000000..f441784 --- /dev/null +++ b/internal/handlers/isusernameunique.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/store" + "github.com/uptrace/bun" +) + +func IsUsernameUnique( + server *hws.Server, + conn *bun.DB, + cfg *config.Config, + store *store.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + username := r.FormValue("username") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + throwInternalServiceError(server, w, r, "Database query failed", err) + return + } + tx.Commit() + if !unique { + w.WriteHeader(http.StatusConflict) + } else { + w.WriteHeader(http.StatusOK) + } + }, + ) +} diff --git a/internal/handlers/login.go b/internal/handlers/login.go new file mode 100644 index 0000000..93bd6c8 --- /dev/null +++ b/internal/handlers/login.go @@ -0,0 +1,63 @@ +package handlers + +import ( + "net/http" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) + +func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) + attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) + + if exceeded { + err := errors.Errorf( + "login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + st.ClearRedirectTrack(r, "/login") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "Login failed: Too many redirect attempts. Please clear your browser cookies and try again.", + err, + "warn", + ) + return + } + + state, uak, err := oauth.GenerateState(cfg.OAuth, "login") + if err != nil { + throwInternalServiceError(server, w, r, "Failed to generate state token", err) + return + } + oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) + + link, err := discordAPI.GetOAuthLink(state) + if err != nil { + throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) + return + } + st.ClearRedirectTrack(r, "/login") + + http.Redirect(w, r, link, http.StatusSeeOther) + }, + ) +} diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go new file mode 100644 index 0000000..02b9a9d --- /dev/null +++ b/internal/handlers/logout.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func Logout( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + discordAPI *discord.APIClient, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + return + } + defer tx.Rollback() + + user := db.CurrentUser(r.Context()) + if user == nil { + // JIC - should be impossible to get here if route is protected by LoginReq + w.Header().Set("HX-Redirect", "/") + return + } + token, err := user.DeleteDiscordTokens(ctx, tx) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) + return + } + err = discordAPI.RevokeToken(token.Convert()) + if err != nil { + throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) + return + } + err = auth.Logout(tx, w, r) + if err != nil { + throwInternalServiceError(server, w, r, "Logout failed", err) + return + } + tx.Commit() + w.Header().Set("HX-Redirect", "/") + }, + ) +} diff --git a/internal/handlers/register.go b/internal/handlers/register.go new file mode 100644 index 0000000..25ded6d --- /dev/null +++ b/internal/handlers/register.go @@ -0,0 +1,129 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/cookies" + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/internal/view/page" +) + +func Register( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + cfg *config.Config, + store *store.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + attempts, exceeded, track := store.TrackRedirect(r, "/register", 3) + + if exceeded { + err := errors.Errorf( + "registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + cfg.HWSAuth.SSL, + ) + + store.ClearRedirectTrack(r, "/register") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.", + err, + "warn", + ) + return + } + + sessionCookie, err := r.Cookie("registration_session") + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + details, ok := store.GetRegistrationSession(sessionCookie.Value) + if !ok { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + store.ClearRedirectTrack(r, "/register") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + method := r.Method + if method == "GET" { + tx.Commit() + page.Register(details.DiscordUser.Username).Render(r.Context(), w) + return + } + if method == "POST" { + username := r.FormValue("username") + user, err := registerUser(ctx, tx, username, details) + if err != nil { + throwInternalServiceError(server, w, r, "Registration failed", err) + return + } + tx.Commit() + if user == nil { + w.WriteHeader(http.StatusConflict) + } else { + err = auth.Login(w, r, user, true) + if err != nil { + throwInternalServiceError(server, w, r, "Login failed", err) + return + } + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + } + return + } + }, + ) +} + +func registerUser( + ctx context.Context, + tx bun.Tx, + username string, + details *store.RegistrationSession, +) (*db.User, error) { + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + return nil, errors.Wrap(err, "db.IsUsernameUnique") + } + if !unique { + return nil, nil + } + user, err := db.CreateUser(ctx, tx, username, details.DiscordUser) + if err != nil { + return nil, errors.Wrap(err, "db.CreateUser") + } + err = user.UpdateDiscordToken(ctx, tx, details.Token) + if err != nil { + return nil, errors.Wrap(err, "db.UpdateDiscordToken") + } + return user, nil +} diff --git a/internal/handlers/static.go b/internal/handlers/static.go index fb444c9..3176129 100644 --- a/internal/handlers/static.go +++ b/internal/handlers/static.go @@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler { if err != nil { // If we can't create the file server, return a handler that always errors return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to load the file system", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) - } + throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err) }) } diff --git a/internal/store/newlogin.go b/internal/store/newlogin.go new file mode 100644 index 0000000..0e62529 --- /dev/null +++ b/internal/store/newlogin.go @@ -0,0 +1,46 @@ +package store + +import ( + "errors" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/bwmarrin/discordgo" +) + +type RegistrationSession struct { + DiscordUser *discordgo.User + Token *discord.Token + ExpiresAt time.Time +} + +func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) { + if user == nil { + return "", errors.New("user cannot be nil") + } + if token == nil { + return "", errors.New("token cannot be nil") + } + id := generateID() + s.sessions.Store(id, &RegistrationSession{ + DiscordUser: user, + Token: token, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + return id, nil +} + +func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) { + val, ok := s.sessions.Load(id) + if !ok { + return nil, false + } + + session := val.(*RegistrationSession) + if time.Now().After(session.ExpiresAt) { + s.sessions.Delete(id) + return nil, false + } + + return session, true +} diff --git a/internal/store/redirects.go b/internal/store/redirects.go new file mode 100644 index 0000000..e2e7852 --- /dev/null +++ b/internal/store/redirects.go @@ -0,0 +1,95 @@ +package store + +import ( + "net" + "net/http" + "strings" + "time" +) + +// getClientIP extracts the client IP address, checking X-Forwarded-For first +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (comma-separated list, first is client) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the list + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port") + // Use net.SplitHostPort to properly handle both IPv4 and IPv6 + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr) + return r.RemoteAddr + } + return host +} + +// TrackRedirect increments the redirect counter for this IP+UA+Path combination +// Returns the current attempt count, whether limit was exceeded, and the track details +func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) { + if r == nil { + return 0, false, nil + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + + now := time.Now() + expiresAt := now.Add(5 * time.Minute) + + // Try to load existing track + val, exists := s.redirectTracks.Load(key) + if exists { + track = val.(*RedirectTrack) + + // Check if expired + if now.After(track.ExpiresAt) { + // Expired, start fresh + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track + } + + // Increment existing + track.Attempts++ + track.ExpiresAt = expiresAt // Extend expiry + exceeded = track.Attempts >= maxAttempts + return track.Attempts, exceeded, track + } + + // Create new track + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track +} + +// ClearRedirectTrack removes a redirect tracking entry (called after successful completion) +func (s *Store) ClearRedirectTrack(r *http.Request, path string) { + if r == nil { + return + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + s.redirectTracks.Delete(key) +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..5620e58 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,80 @@ +package store + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "sync" + "time" +) + +// RedirectTrack represents a single redirect attempt tracking entry +type RedirectTrack struct { + IP string // Client IP (X-Forwarded-For aware) + UserAgent string // Full User-Agent string for debugging + Path string // Request path (without query params) + Attempts int // Number of redirect attempts + FirstSeen time.Time // When first redirect was tracked + ExpiresAt time.Time // When to clean up this entry +} + +type Store struct { + sessions sync.Map // key: string, value: *RegistrationSession + redirectTracks sync.Map // key: string, value: *RedirectTrack + cleanup *time.Ticker +} + +func NewStore() *Store { + s := &Store{ + cleanup: time.NewTicker(1 * time.Minute), + } + + // Background cleanup of expired sessions + go func() { + for range s.cleanup.C { + s.cleanupExpired() + } + }() + + return s +} + +func (s *Store) Delete(id string) { + s.sessions.Delete(id) +} +func (s *Store) cleanupExpired() { + now := time.Now() + + // Clean up expired registration sessions + s.sessions.Range(func(key, value any) bool { + session := value.(*RegistrationSession) + if now.After(session.ExpiresAt) { + s.sessions.Delete(key) + } + return true + }) + + // Clean up expired redirect tracks + s.redirectTracks.Range(func(key, value any) bool { + track := value.(*RedirectTrack) + if now.After(track.ExpiresAt) { + s.redirectTracks.Delete(key) + } + return true + }) +} +func generateID() string { + b := make([]byte, 32) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +// redirectKey generates a unique key for tracking redirects +// Uses IP + first 100 chars of UA + path as key (not hashed for debugging) +func redirectKey(ip, userAgent, path string) string { + ua := userAgent + if len(ua) > 100 { + ua = ua[:100] + } + return fmt.Sprintf("%s:%s:%s", ip, ua, path) +} diff --git a/internal/view/component/form/register.templ b/internal/view/component/form/register.templ new file mode 100644 index 0000000..00af173 --- /dev/null +++ b/internal/view/component/form/register.templ @@ -0,0 +1,89 @@ +package form + +templ RegisterForm(username string) { +
+ +
+
+
+ + +

+
+
+ +
+
+} diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ index 5f1fc22..88a5531 100644 --- a/internal/view/component/nav/navbarright.templ +++ b/internal/view/component/nav/navbarright.templ @@ -1,6 +1,6 @@ package nav -import "git.haelnorr.com/h/oslstats/pkg/contexts" +import "git.haelnorr.com/h/oslstats/internal/db" type ProfileItem struct { name string // Label to display @@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem { // Returns the right portion of the navbar templ navRight() { - {{ user := contexts.CurrentUser(ctx) }} + {{ user := db.CurrentUser(ctx) }} {{ items := getProfileItems() }}
diff --git a/internal/view/component/nav/sidenav.templ b/internal/view/component/nav/sidenav.templ index 23fb9f0..5b50e70 100644 --- a/internal/view/component/nav/sidenav.templ +++ b/internal/view/component/nav/sidenav.templ @@ -1,10 +1,10 @@ package nav -import "git.haelnorr.com/h/oslstats/pkg/contexts" +import "git.haelnorr.com/h/oslstats/internal/db" // Returns the mobile version of the navbar thats only visible when activated templ sideNav(navItems []NavItem) { - {{ user := contexts.CurrentUser(ctx) }} + {{ user := db.CurrentUser(ctx) }}
- + { title } - - - @@ -38,19 +35,6 @@ templ Global(title string) { const bodyData = { showError500: false, showError503: false, - showConfirmPasswordModal: false, - handleHtmxBeforeOnLoad(event) { - const requestPath = event.detail.pathInfo.requestPath; - if (requestPath === "/reauthenticate") { - // handle password incorrect on refresh attempt - if (event.detail.xhr.status === 445) { - event.detail.shouldSwap = true; - event.detail.isError = false; - } else if (event.detail.xhr.status === 200) { - this.showConfirmPasswordModal = false; - } - } - }, // handle errors from the server on HTMX requests handleHtmxError(event) { const errorCode = event.detail.errorInfo.error; @@ -65,11 +49,6 @@ templ Global(title string) { this.showError503 = true; setTimeout(() => (this.showError503 = false), 6000); } - - // user is authorized but needs to refresh their login - if (errorCode.includes("Code 444")) { - this.showConfirmPasswordModal = true; - } }, }; @@ -78,7 +57,6 @@ templ Global(title string) { class="bg-base text-text ubuntu-mono-regular overflow-x-hidden" x-data="bodyData" x-on:htmx:error="handleHtmxError($event)" - x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)" > @popup.Error500Popup() @popup.Error503Popup() diff --git a/internal/view/page/error.templ b/internal/view/page/error.templ index 1561478..1612e91 100644 --- a/internal/view/page/error.templ +++ b/internal/view/page/error.templ @@ -3,32 +3,66 @@ package page import "git.haelnorr.com/h/oslstats/internal/view/layout" import "strconv" -// Page template for Error pages. Error code should be a HTTP status code as -// a string, and err should be the corresponding response title. -// Message is a custom error message displayed below the code and error. +// Original Error template (keep for backwards compatibility where needed) templ Error(code int, err string, message string) { + @ErrorWithDetails(code, err, message, "") +} + +// Enhanced Error template with optional details section +templ ErrorWithDetails(code int, err string, message string, details string) { @layout.Global(err) { -
-
-

{ strconv.Itoa(code) }

-

{ err }

-

{ message }

- Go to homepage +
+
+

{ strconv.Itoa(code) }

+

{ err }

+ // Always show the message from hws.HWSError.Message +

{ message }

+ // Conditionally show technical details in dropdown + if details != "" { +
+
+ + Details + (click to expand) + +
+
{ details }
+
+ +
+
+ } + + Go to homepage +
+ if details != "" { + + } } } diff --git a/internal/view/page/register.templ b/internal/view/page/register.templ new file mode 100644 index 0000000..cfd6517 --- /dev/null +++ b/internal/view/page/register.templ @@ -0,0 +1,29 @@ +package page + +import "git.haelnorr.com/h/oslstats/internal/view/layout" +import "git.haelnorr.com/h/oslstats/internal/view/component/form" + +// Returns the login page +templ Register(username string) { + @layout.Global("Register") { +
+
+
+
+

Set your display name

+

+ Select your display name. This must be unique, and cannot be changed. +

+
+
+ @form.RegisterForm(username) +
+
+
+
+ } +} diff --git a/pkg/contexts/currentuser.go b/pkg/contexts/currentuser.go deleted file mode 100644 index f522563..0000000 --- a/pkg/contexts/currentuser.go +++ /dev/null @@ -1,8 +0,0 @@ -package contexts - -import ( - "git.haelnorr.com/h/golib/hwsauth" - "git.haelnorr.com/h/oslstats/internal/db" -) - -var CurrentUser hwsauth.ContextLoader[*db.User] diff --git a/pkg/contexts/keys.go b/pkg/contexts/keys.go index 996cba7..e089934 100644 --- a/pkg/contexts/keys.go +++ b/pkg/contexts/keys.go @@ -1,7 +1,7 @@ package contexts -type contextKey string +type Key string -func (c contextKey) String() string { +func (c Key) String() string { return "oslstats context key " + string(c) } diff --git a/pkg/embedfs/files/css/input.css b/pkg/embedfs/files/css/input.css index 8839df9..a673dd5 100644 --- a/pkg/embedfs/files/css/input.css +++ b/pkg/embedfs/files/css/input.css @@ -1,19 +1,10 @@ +@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap"); @import "tailwindcss"; -@source "../../../../internal/view/component/footer/footer.templ"; -@source "../../../../internal/view/component/nav/navbarleft.templ"; -@source "../../../../internal/view/component/nav/navbarright.templ"; -@source "../../../../internal/view/component/nav/navbar.templ"; -@source "../../../../internal/view/component/nav/sidenav.templ"; -@source "../../../../internal/view/component/popup/error500Popup.templ"; -@source "../../../../internal/view/component/popup/error503Popup.templ"; -@source "../../../../internal/view/layout/global.templ"; -@source "../../../../internal/view/page/error.templ"; -@source "../../../../internal/view/page/index.templ"; - [x-cloak] { display: none !important; } + @theme inline { --color-rosewater: var(--rosewater); --color-flamingo: var(--flamingo); @@ -43,6 +34,7 @@ --color-mantle: var(--mantle); --color-crust: var(--crust); } + :root { --rosewater: hsl(11, 59%, 67%); --flamingo: hsl(0, 60%, 67%); @@ -102,6 +94,7 @@ --mantle: hsl(240, 21%, 12%); --crust: hsl(240, 23%, 9%); } + .ubuntu-mono-regular { font-family: "Ubuntu Mono", serif; font-weight: 400; diff --git a/pkg/embedfs/files/css/output.css b/pkg/embedfs/files/css/output.css index c506704..1f8ff44 100644 --- a/pkg/embedfs/files/css/output.css +++ b/pkg/embedfs/files/css/output.css @@ -1,4 +1,5 @@ /*! tailwindcss v4.1.18 | MIT License | https://tailwindcss.com */ +@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap"); @layer properties; @layer theme, base, components, utilities; @layer theme { @@ -10,7 +11,10 @@ --spacing: 0.25rem; --breakpoint-xl: 80rem; --container-md: 28rem; + --container-2xl: 42rem; --container-7xl: 80rem; + --text-xs: 0.75rem; + --text-xs--line-height: calc(1 / 0.75); --text-sm: 0.875rem; --text-sm--line-height: calc(1.25 / 0.875); --text-lg: 1.125rem; @@ -28,11 +32,13 @@ --text-9xl: 8rem; --text-9xl--line-height: 1; --font-weight-medium: 500; + --font-weight-semibold: 600; --font-weight-bold: 700; --tracking-tight: -0.025em; --leading-relaxed: 1.625; --radius-sm: 0.25rem; --radius-lg: 0.5rem; + --radius-xl: 0.75rem; --default-transition-duration: 150ms; --default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); --default-font-family: var(--font-sans); @@ -208,6 +214,9 @@ .relative { position: relative; } + .static { + position: static; + } .end-0 { inset-inline-end: calc(var(--spacing) * 0); } @@ -244,9 +253,18 @@ .mt-4 { margin-top: calc(var(--spacing) * 4); } + .mt-5 { + margin-top: calc(var(--spacing) * 5); + } .mt-6 { margin-top: calc(var(--spacing) * 6); } + .mt-7 { + margin-top: calc(var(--spacing) * 7); + } + .mt-8 { + margin-top: calc(var(--spacing) * 8); + } .mt-10 { margin-top: calc(var(--spacing) * 10); } @@ -265,6 +283,9 @@ .mb-auto { margin-bottom: auto; } + .ml-2 { + margin-left: calc(var(--spacing) * 2); + } .ml-auto { margin-left: auto; } @@ -321,9 +342,15 @@ .w-full { width: 100%; } + .max-w-2xl { + max-width: var(--container-2xl); + } .max-w-7xl { max-width: var(--container-7xl); } + .max-w-100 { + max-width: calc(var(--spacing) * 100); + } .max-w-md { max-width: var(--container-md); } @@ -344,6 +371,9 @@ .transform { transform: var(--tw-rotate-x,) var(--tw-rotate-y,) var(--tw-rotate-z,) var(--tw-skew-x,) var(--tw-skew-y,); } + .cursor-pointer { + cursor: pointer; + } .flex-col { flex-direction: column; } @@ -381,6 +411,12 @@ margin-block-end: calc(calc(var(--spacing) * 1) * calc(1 - var(--tw-space-y-reverse))); } } + .gap-x-2 { + column-gap: calc(var(--spacing) * 2); + } + .gap-y-4 { + row-gap: calc(var(--spacing) * 4); + } .divide-y { :where(& > :not(:last-child)) { --tw-divide-y-reverse: 0; @@ -398,9 +434,15 @@ .overflow-hidden { overflow: hidden; } + .overflow-x-auto { + overflow-x: auto; + } .overflow-x-hidden { overflow-x: hidden; } + .rounded { + border-radius: 0.25rem; + } .rounded-full { border-radius: calc(infinity * 1px); } @@ -410,19 +452,35 @@ .rounded-sm { border-radius: var(--radius-sm); } + .rounded-xl { + border-radius: var(--radius-xl); + } .border { border-style: var(--tw-border-style); border-width: 1px; } + .border-2 { + border-style: var(--tw-border-style); + border-width: 2px; + } + .border-green { + border-color: var(--green); + } + .border-overlay0 { + border-color: var(--overlay0); + } + .border-red { + border-color: var(--red); + } .border-surface1 { border-color: var(--surface1); } + .border-transparent { + border-color: transparent; + } .bg-base { background-color: var(--base); } - .bg-blue { - background-color: var(--blue); - } .bg-crust { background-color: var(--crust); } @@ -456,12 +514,21 @@ .p-4 { padding: calc(var(--spacing) * 4); } + .px-2 { + padding-inline: calc(var(--spacing) * 2); + } + .px-3 { + padding-inline: calc(var(--spacing) * 3); + } .px-4 { padding-inline: calc(var(--spacing) * 4); } .px-5 { padding-inline: calc(var(--spacing) * 5); } + .py-1 { + padding-block: calc(var(--spacing) * 1); + } .py-2 { padding-block: calc(var(--spacing) * 2); } @@ -480,6 +547,15 @@ .text-center { text-align: center; } + .text-left { + text-align: left; + } + .text-right { + text-align: right; + } + .font-mono { + font-family: var(--font-mono); + } .text-2xl { font-size: var(--text-2xl); line-height: var(--tw-leading, var(--text-2xl--line-height)); @@ -508,6 +584,10 @@ font-size: var(--text-xl); line-height: var(--tw-leading, var(--text-xl--line-height)); } + .text-xs { + font-size: var(--text-xs); + line-height: var(--tw-leading, var(--text-xs--line-height)); + } .leading-relaxed { --tw-leading: var(--leading-relaxed); line-height: var(--leading-relaxed); @@ -520,10 +600,20 @@ --tw-font-weight: var(--font-weight-medium); font-weight: var(--font-weight-medium); } + .font-semibold { + --tw-font-weight: var(--font-weight-semibold); + font-weight: var(--font-weight-semibold); + } .tracking-tight { --tw-tracking: var(--tracking-tight); letter-spacing: var(--tracking-tight); } + .break-all { + word-break: break-all; + } + .whitespace-pre-wrap { + white-space: pre-wrap; + } .text-crust { color: var(--crust); } @@ -568,6 +658,14 @@ --tw-duration: 200ms; transition-duration: 200ms; } + .outline-none { + --tw-outline-style: none; + outline-style: none; + } + .select-none { + -webkit-user-select: none; + user-select: none; + } .hover\:cursor-pointer { &:hover { @media (hover: hover) { @@ -575,16 +673,6 @@ } } } - .hover\:bg-blue\/75 { - &:hover { - @media (hover: hover) { - background-color: var(--blue); - @supports (color: color-mix(in lab, red, red)) { - background-color: color-mix(in oklab, var(--blue) 75%, transparent); - } - } - } - } .hover\:bg-crust { &:hover { @media (hover: hover) { @@ -673,6 +761,51 @@ } } } + .hover\:text-text { + &:hover { + @media (hover: hover) { + color: var(--text); + } + } + } + .focus\:border-blue { + &:focus { + border-color: var(--blue); + } + } + .focus\:border-green { + &:focus { + border-color: var(--green); + } + } + .focus\:border-red { + &:focus { + border-color: var(--red); + } + } + .disabled\:pointer-events-none { + &:disabled { + pointer-events: none; + } + } + .disabled\:cursor-default { + &:disabled { + cursor: default; + } + } + .disabled\:bg-green\/60 { + &:disabled { + background-color: var(--green); + @supports (color: color-mix(in lab, red, red)) { + background-color: color-mix(in oklab, var(--green) 60%, transparent); + } + } + } + .disabled\:opacity-50 { + &:disabled { + opacity: 50%; + } + } .sm\:end-6 { @media (width >= 40rem) { inset-inline-end: calc(var(--spacing) * 6); @@ -693,11 +826,6 @@ display: none; } } - .sm\:inline { - @media (width >= 40rem) { - display: inline; - } - } .sm\:justify-between { @media (width >= 40rem) { justify-content: space-between; @@ -708,6 +836,11 @@ gap: calc(var(--spacing) * 2); } } + .sm\:p-7 { + @media (width >= 40rem) { + padding: calc(var(--spacing) * 7); + } + } .sm\:px-6 { @media (width >= 40rem) { padding-inline: calc(var(--spacing) * 6); diff --git a/pkg/embedfs/files/js/theme.js b/pkg/embedfs/files/js/theme.js index 291c41f..a2a11c7 100644 --- a/pkg/embedfs/files/js/theme.js +++ b/pkg/embedfs/files/js/theme.js @@ -1,3 +1,5 @@ +// This function prevents the 'flash of unstyled content' +// Include it at the top of (function() { let theme = localStorage.getItem("theme") || "system"; if (theme === "system") { diff --git a/pkg/oauth/config.go b/pkg/oauth/config.go new file mode 100644 index 0000000..c37ef21 --- /dev/null +++ b/pkg/oauth/config.go @@ -0,0 +1,23 @@ +package oauth + +import ( + "git.haelnorr.com/h/golib/env" + "github.com/pkg/errors" +) + +type Config struct { + PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""), + } + + // Check required fields + if cfg.PrivateKey == "" { + return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY") + } + + return cfg, nil +} diff --git a/pkg/oauth/cookies.go b/pkg/oauth/cookies.go new file mode 100644 index 0000000..adf0f54 --- /dev/null +++ b/pkg/oauth/cookies.go @@ -0,0 +1,45 @@ +package oauth + +import ( + "encoding/base64" + "net/http" + + "github.com/pkg/errors" +) + +func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) { + encodedUak := base64.RawURLEncoding.EncodeToString(uak) + http.SetCookie(w, &http.Cookie{ + Name: "oauth_uak", + Value: encodedUak, + Path: "/", + MaxAge: 300, + HttpOnly: true, + Secure: ssl, + SameSite: http.SameSiteLaxMode, + }) +} + +func GetStateCookie(r *http.Request) ([]byte, error) { + if r == nil { + return nil, errors.New("Request cannot be nil") + } + cookie, err := r.Cookie("oauth_uak") + if err != nil { + return nil, err + } + uak, err := base64.RawURLEncoding.DecodeString(cookie.Value) + if err != nil { + return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie") + } + return uak, nil +} + +func DeleteStateCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: "oauth_uak", + Value: "", + Path: "/", + MaxAge: -1, + }) +} diff --git a/pkg/oauth/ezconf.go b/pkg/oauth/ezconf.go new file mode 100644 index 0000000..e8a87ca --- /dev/null +++ b/pkg/oauth/ezconf.go @@ -0,0 +1,41 @@ +package oauth + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv} +} diff --git a/pkg/oauth/state.go b/pkg/oauth/state.go new file mode 100644 index 0000000..5c97ab5 --- /dev/null +++ b/pkg/oauth/state.go @@ -0,0 +1,117 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "slices" + "strings" + + "github.com/pkg/errors" +) + +// STATE FLOW: +// data provided at call time to be retrieved later +// random value generated on the spot +// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client +// privateKey - from config + +func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) { + // signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey) + // state = data + "." + random + "." + signature + if cfg == nil { + return "", nil, errors.New("cfg cannot be nil") + } + if cfg.PrivateKey == "" { + return "", nil, errors.New("private key cannot be empty") + } + if data == "" { + return "", nil, errors.New("data cannot be empty") + } + + // Generate 32 random bytes for random component + randomBytes := make([]byte, 32) + _, err = rand.Read(randomBytes) + if err != nil { + return "", nil, errors.Wrap(err, "failed to generate random bytes") + } + + // Generate 32 random bytes for userAgentKey + userAgentKey = make([]byte, 32) + _, err = rand.Read(userAgentKey) + if err != nil { + return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes") + } + + // Encode random and userAgentKey to base64 + randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes) + userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey) + + // Create payload for signing: data + "." + random + userAgentKey + privateKey + // Note: userAgentKey is concatenated directly with privateKey (no separator) + payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey + + // Generate signature + hash := sha256.Sum256([]byte(payload)) + signature := base64.RawURLEncoding.EncodeToString(hash[:]) + + // Construct state: data + "." + random + "." + signature + state = data + "." + randomEncoded + "." + signature + + return state, userAgentKey, nil +} + +func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) { + // Validate inputs + if cfg == nil { + return "", errors.New("cfg cannot be nil") + } + if cfg.PrivateKey == "" { + return "", errors.New("private key cannot be empty") + } + if state == "" { + return "", errors.New("state cannot be empty") + } + if len(userAgentKey) == 0 { + return "", errors.New("userAgentKey cannot be empty") + } + + // Split state into parts + parts := strings.Split(state, ".") + if len(parts) != 3 { + return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts)) + } + + // Check for empty parts + if slices.Contains(parts, "") { + return "", errors.New("state parts cannot be empty") + } + + data = parts[0] + random := parts[1] + receivedSignature := parts[2] + + // Encode userAgentKey to base64 for payload reconstruction + userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey) + + // Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey + payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey + + // Generate expected hash + hash := sha256.Sum256([]byte(payload)) + + // Decode received signature to bytes + receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature) + if err != nil { + return "", errors.Wrap(err, "failed to decode received signature") + } + + // Compare hash bytes directly with decoded signature using constant-time comparison + // This is more efficient than encoding hash and then decoding both for comparison + if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 { + return data, nil + } + + return "", errors.New("invalid state signature") +} diff --git a/pkg/oauth/state_test.go b/pkg/oauth/state_test.go new file mode 100644 index 0000000..c647e66 --- /dev/null +++ b/pkg/oauth/state_test.go @@ -0,0 +1,817 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "strings" + "testing" +) + +// Helper function to create a test config +func testConfig() *Config { + return &Config{ + PrivateKey: "test_private_key_for_testing_12345", + } +} + +// TestGenerateState_Success tests the happy path of state generation +func TestGenerateState_Success(t *testing.T) { + cfg := testConfig() + data := "test_data_payload" + + state, userAgentKey, err := GenerateState(cfg, data) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if state == "" { + t.Error("Expected non-empty state") + } + + if len(userAgentKey) != 32 { + t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) + } + + // Verify state format: data.random.signature + parts := strings.Split(state, ".") + if len(parts) != 3 { + t.Errorf("Expected state to have 3 parts, got %d", len(parts)) + } + + // Verify data is preserved + if parts[0] != data { + t.Errorf("Expected data to be '%s', got '%s'", data, parts[0]) + } + + // Verify random part is base64 encoded + randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Errorf("Expected random part to be valid base64: %v", err) + } + if len(randomBytes) != 32 { + t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes)) + } + + // Verify signature part is base64 encoded + sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + t.Errorf("Expected signature part to be valid base64: %v", err) + } + if len(sigBytes) != 32 { + t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes)) + } +} + +// TestGenerateState_NilConfig tests that nil config returns error +func TestGenerateState_NilConfig(t *testing.T) { + _, _, err := GenerateState(nil, "test_data") + + if err == nil { + t.Fatal("Expected error for nil config, got nil") + } + + if !strings.Contains(err.Error(), "cfg cannot be nil") { + t.Errorf("Expected error message about nil config, got: %v", err) + } +} + +// TestGenerateState_EmptyPrivateKey tests that empty private key returns error +func TestGenerateState_EmptyPrivateKey(t *testing.T) { + cfg := &Config{PrivateKey: ""} + _, _, err := GenerateState(cfg, "test_data") + + if err == nil { + t.Fatal("Expected error for empty private key, got nil") + } + + if !strings.Contains(err.Error(), "private key cannot be empty") { + t.Errorf("Expected error message about empty private key, got: %v", err) + } +} + +// TestGenerateState_EmptyData tests that empty data returns error +func TestGenerateState_EmptyData(t *testing.T) { + cfg := testConfig() + _, _, err := GenerateState(cfg, "") + + if err == nil { + t.Fatal("Expected error for empty data, got nil") + } + + if !strings.Contains(err.Error(), "data cannot be empty") { + t.Errorf("Expected error message about empty data, got: %v", err) + } +} + +// TestGenerateState_Randomness tests that multiple calls generate different states +func TestGenerateState_Randomness(t *testing.T) { + cfg := testConfig() + data := "same_data" + + state1, _, err1 := GenerateState(cfg, data) + state2, _, err2 := GenerateState(cfg, data) + + if err1 != nil || err2 != nil { + t.Fatalf("Unexpected errors: %v, %v", err1, err2) + } + + if state1 == state2 { + t.Error("Expected different states for multiple calls, got identical states") + } +} + +// TestGenerateState_DifferentData tests states with different data payloads +func TestGenerateState_DifferentData(t *testing.T) { + cfg := testConfig() + + testCases := []string{ + "simple", + "with-dashes", + "with_underscores", + "123456789", + "MixedCase123", + } + + for _, data := range testCases { + t.Run(data, func(t *testing.T) { + state, userAgentKey, err := GenerateState(cfg, data) + + if err != nil { + t.Fatalf("Unexpected error for data '%s': %v", data, err) + } + + if !strings.HasPrefix(state, data+".") { + t.Errorf("Expected state to start with '%s.', got: %s", data, state) + } + + if len(userAgentKey) != 32 { + t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) + } + }) + } +} + +// TestVerifyState_Success tests the happy path of state verification +func TestVerifyState_Success(t *testing.T) { + cfg := testConfig() + data := "test_data" + + // Generate state + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify state + extractedData, err := VerifyState(cfg, state, userAgentKey) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if extractedData != data { + t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData) + } +} + +// TestVerifyState_NilConfig tests that nil config returns error +func TestVerifyState_NilConfig(t *testing.T) { + _, err := VerifyState(nil, "state", []byte("key")) + + if err == nil { + t.Fatal("Expected error for nil config, got nil") + } + + if !strings.Contains(err.Error(), "cfg cannot be nil") { + t.Errorf("Expected error message about nil config, got: %v", err) + } +} + +// TestVerifyState_EmptyPrivateKey tests that empty private key returns error +func TestVerifyState_EmptyPrivateKey(t *testing.T) { + cfg := &Config{PrivateKey: ""} + _, err := VerifyState(cfg, "state", []byte("key")) + + if err == nil { + t.Fatal("Expected error for empty private key, got nil") + } + + if !strings.Contains(err.Error(), "private key cannot be empty") { + t.Errorf("Expected error message about empty private key, got: %v", err) + } +} + +// TestVerifyState_EmptyState tests that empty state returns error +func TestVerifyState_EmptyState(t *testing.T) { + cfg := testConfig() + _, err := VerifyState(cfg, "", []byte("key")) + + if err == nil { + t.Fatal("Expected error for empty state, got nil") + } + + if !strings.Contains(err.Error(), "state cannot be empty") { + t.Errorf("Expected error message about empty state, got: %v", err) + } +} + +// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error +func TestVerifyState_EmptyUserAgentKey(t *testing.T) { + cfg := testConfig() + _, err := VerifyState(cfg, "data.random.signature", []byte{}) + + if err == nil { + t.Fatal("Expected error for empty userAgentKey, got nil") + } + + if !strings.Contains(err.Error(), "userAgentKey cannot be empty") { + t.Errorf("Expected error message about empty userAgentKey, got: %v", err) + } +} + +// TestVerifyState_WrongUserAgentKey tests MITM protection +func TestVerifyState_WrongUserAgentKey(t *testing.T) { + cfg := testConfig() + + // Generate first state + state, _, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Generate a different userAgentKey + _, wrongKey, err := GenerateState(cfg, "other_data") + if err != nil { + t.Fatalf("Failed to generate second state: %v", err) + } + + // Try to verify with wrong key + _, err = VerifyState(cfg, state, wrongKey) + + if err == nil { + t.Error("Expected error for invalid signature") + } + + if !strings.Contains(err.Error(), "invalid state signature") { + t.Errorf("Expected error about invalid signature, got: %v", err) + } +} + +// TestVerifyState_TamperedData tests tampering detection +func TestVerifyState_TamperedData(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "original_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the data portion + parts := strings.Split(state, ".") + parts[0] = "tampered_data" + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered state") + } +} + +// TestVerifyState_TamperedRandom tests tampering with random portion +func TestVerifyState_TamperedRandom(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the random portion + parts := strings.Split(state, ".") + parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12")) + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered state") + } +} + +// TestVerifyState_TamperedSignature tests tampering with signature +func TestVerifyState_TamperedSignature(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the signature portion + parts := strings.Split(state, ".") + // Create a different valid base64 string + parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered"))) + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered signature") + } +} + +// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts +func TestVerifyState_MalformedState_TwoParts(t *testing.T) { + cfg := testConfig() + malformedState := "data.random" + + _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for malformed state") + } + + if !strings.Contains(err.Error(), "must have exactly 3 parts") { + t.Errorf("Expected error about incorrect number of parts, got: %v", err) + } +} + +// TestVerifyState_MalformedState_FourParts tests state with 4 parts +func TestVerifyState_MalformedState_FourParts(t *testing.T) { + cfg := testConfig() + malformedState := "data.random.signature.extra" + + _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for malformed state") + } + + if !strings.Contains(err.Error(), "must have exactly 3 parts") { + t.Errorf("Expected error about incorrect number of parts, got: %v", err) + } +} + +// TestVerifyState_EmptyStateParts tests state with empty parts +func TestVerifyState_EmptyStateParts(t *testing.T) { + cfg := testConfig() + testCases := []struct { + name string + state string + }{ + {"empty data", ".random.signature"}, + {"empty random", "data..signature"}, + {"empty signature", "data.random."}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for state with empty parts") + } + + if !strings.Contains(err.Error(), "state parts cannot be empty") { + t.Errorf("Expected error about empty parts, got: %v", err) + } + }) + } +} + +// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature +func TestVerifyState_InvalidBase64Signature(t *testing.T) { + cfg := testConfig() + invalidState := "data.random.invalid@base64!" + + _, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for invalid base64 signature") + } + + if !strings.Contains(err.Error(), "failed to decode") { + t.Errorf("Expected error about decoding signature, got: %v", err) + } +} + +// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification +func TestVerifyState_DifferentPrivateKey(t *testing.T) { + cfg1 := &Config{PrivateKey: "private_key_1"} + cfg2 := &Config{PrivateKey: "private_key_2"} + + // Generate with first config + state, userAgentKey, err := GenerateState(cfg1, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Try to verify with second config + _, err = VerifyState(cfg2, state, userAgentKey) + + if err == nil { + t.Error("Expected error for mismatched private key") + } +} + +// TestRoundTrip tests complete round trip with various data payloads +func TestRoundTrip(t *testing.T) { + cfg := testConfig() + + testCases := []string{ + "simple", + "with-dashes-and-numbers-123", + "MixedCaseData", + "user_token_abc123", + "link_resource_xyz789", + } + + for _, data := range testCases { + t.Run(data, func(t *testing.T) { + // Generate + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify + extractedData, err := VerifyState(cfg, state, userAgentKey) + if err != nil { + t.Fatalf("Failed to verify state: %v", err) + } + + if extractedData != data { + t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData) + } + }) + } +} + +// TestConcurrentGeneration tests that concurrent state generation works correctly +func TestConcurrentGeneration(t *testing.T) { + cfg := testConfig() + data := "concurrent_test" + + const numGoroutines = 10 + results := make(chan string, numGoroutines) + errors := make(chan error, numGoroutines) + + // Generate states concurrently + for range numGoroutines { + go func() { + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + errors <- err + return + } + + // Verify immediately + _, verifyErr := VerifyState(cfg, state, userAgentKey) + if verifyErr != nil { + errors <- verifyErr + return + } + + results <- state + }() + } + + // Collect results + states := make(map[string]bool) + for range numGoroutines { + select { + case state := <-results: + if states[state] { + t.Errorf("Duplicate state generated: %s", state) + } + states[state] = true + case err := <-errors: + t.Errorf("Concurrent generation error: %v", err) + } + } + + if len(states) != numGoroutines { + t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states)) + } +} + +// TestStateFormatCompatibility ensures state is URL-safe +func TestStateFormatCompatibility(t *testing.T) { + cfg := testConfig() + data := "url_safe_test" + + state, _, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Check that state doesn't contain characters that need URL encoding + unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"} + for _, char := range unsafeChars { + if strings.Contains(state, char) { + t.Errorf("State contains URL-unsafe character '%s': %s", char, state) + } + } +} + +// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works +// An attacker obtains their own valid state but tries to use it with victim's session +func TestMITM_AttackerCannotSubstituteState(t *testing.T) { + cfg := testConfig() + + // Victim generates a state for their login + victimState, victimKey, err := GenerateState(cfg, "victim_data") + if err != nil { + t.Fatalf("Failed to generate victim state: %v", err) + } + + // Attacker generates their own valid state (they can request this from the server) + attackerState, attackerKey, err := GenerateState(cfg, "attacker_data") + if err != nil { + t.Fatalf("Failed to generate attacker state: %v", err) + } + + // Both states should be valid on their own + _, err = VerifyState(cfg, victimState, victimKey) + if err != nil { + t.Fatalf("Victim state should be valid: err=%v", err) + } + + _, err = VerifyState(cfg, attackerState, attackerKey) + if err != nil { + t.Fatalf("Attacker state should be valid: err=%v", err) + } + + // MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie + // This should FAIL because attackerState was signed with attackerKey, not victimKey + _, err = VerifyState(cfg, attackerState, victimKey) + if err == nil { + t.Error("Expected error when attacker substitutes state") + } + + // MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie + // This should also FAIL + _, err = VerifyState(cfg, victimState, attackerKey) + if err == nil { + t.Error("Expected error when attacker uses victim's state") + } + + // The key insight: even though both states are "valid", they are bound to their respective cookies + // An attacker cannot mix and match states and cookies + t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies") +} + +// TestCSRF_AttackerCannotForgeState verifies CSRF protection +// An attacker tries to forge a state parameter without knowing the private key +func TestCSRF_AttackerCannotForgeState(t *testing.T) { + cfg := testConfig() + + // Attacker doesn't know the private key, but tries to forge a state + // They might try to construct: "malicious_data.random.signature" + + // Attempt 1: Use a random signature + randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678")) + forgedState1 := "malicious_data.somefakerandom." + randomSig + + // Generate a real userAgentKey (attacker might try to get this) + _, realKey, err := GenerateState(cfg, "legitimate_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Try to verify forged state + _, err = VerifyState(cfg, forgedState1, realKey) + if err == nil { + t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!") + } + + // Attempt 2: Attacker tries to compute signature without private key + // They use: SHA256(data + "." + random + userAgentKey) - missing privateKey + attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey) + hash := sha256.Sum256([]byte(attackerPayload)) + attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) + forgedState2 := "malicious_data.fakerandom." + attackerSig + + _, err = VerifyState(cfg, forgedState2, realKey) + if err == nil { + t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!") + } + + t.Log("✓ CSRF protection verified: Cannot forge state without private key") +} + +// TestTampering_SignatureDetectsAllModifications verifies tamper detection +func TestTampering_SignatureDetectsAllModifications(t *testing.T) { + cfg := testConfig() + + // Generate a valid state + originalState, userAgentKey, err := GenerateState(cfg, "original_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify original is valid + data, err := VerifyState(cfg, originalState, userAgentKey) + if err != nil || data != "original_data" { + t.Fatalf("Original state should be valid") + } + + parts := strings.Split(originalState, ".") + + // Test 1: Attacker modifies data but keeps signature + tamperedState := "modified_data." + parts[1] + "." + parts[2] + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Modified data not detected!") + } + + // Test 2: Attacker modifies random but keeps signature + newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!")) + tamperedState = parts[0] + "." + newRandom + "." + parts[2] + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Modified random not detected!") + } + + // Test 3: Attacker tries to recompute signature but doesn't have privateKey + // They compute: SHA256(modified_data + "." + random + userAgentKey) + attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey) + hash := sha256.Sum256([]byte(attackerPayload)) + attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) + tamperedState = "modified_data." + parts[1] + "." + attackerSig + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!") + } + + // Test 4: Single bit flip in signature + sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2]) + sigBytes[0] ^= 0x01 // Flip one bit + flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes) + tamperedState = parts[0] + "." + parts[1] + "." + flippedSig + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!") + } + + t.Log("✓ Tamper detection verified: All modifications to state are detected") +} + +// TestReplay_DifferentSessionsCannotReuseState verifies replay protection +func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) { + cfg := testConfig() + + // Session 1: User initiates OAuth flow + state1, key1, err := GenerateState(cfg, "session1_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // State is valid for session 1 + _, err = VerifyState(cfg, state1, key1) + if err != nil { + t.Fatalf("State should be valid for session 1") + } + + // Session 2: Same user (or attacker) initiates a new OAuth flow + state2, key2, err := GenerateState(cfg, "session1_data") // same data + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Replay Attack: Try to use state1 with key2 + _, err = VerifyState(cfg, state1, key2) + if err == nil { + t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!") + } + + // Even with same data, each session should have unique state+key binding + if state1 == state2 { + t.Error("REPLAY VULNERABILITY: Same data produces identical states!") + } + + t.Log("✓ Replay protection verified: States are bound to specific session cookies") +} + +// TestConstantTimeComparison verifies that signature comparison is timing-safe +// This is a behavioral test - we can't easily test timing, but we can verify the function is used +func TestConstantTimeComparison_IsUsed(t *testing.T) { + cfg := testConfig() + + // Generate valid state + state, userAgentKey, err := GenerateState(cfg, "test") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Create states with signatures that differ at different positions + parts := strings.Split(state, ".") + originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2]) + + testCases := []struct { + name string + position int + }{ + {"first byte differs", 0}, + {"middle byte differs", 16}, + {"last byte differs", 31}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create signature that differs at specific position + tamperedSig := make([]byte, len(originalSig)) + copy(tamperedSig, originalSig) + tamperedSig[tc.position] ^= 0xFF // Flip all bits + + tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig) + tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr + + // All should fail verification + _, err := VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Errorf("Tampered signature at position %d should be invalid", tc.position) + } + + // If constant-time comparison is NOT used, early differences would return faster + // While we can't easily test timing here, we verify all positions fail equally + }) + } + + t.Log("✓ Constant-time comparison: All signature positions validated equally") + t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation") +} + +// TestPrivateKey_IsCriticalToSecurity verifies private key is essential +func TestPrivateKey_IsCriticalToSecurity(t *testing.T) { + cfg1 := &Config{PrivateKey: "secret_key_1"} + cfg2 := &Config{PrivateKey: "secret_key_2"} + + // Generate state with key1 + state, userAgentKey, err := GenerateState(cfg1, "data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Should verify with key1 + _, err = VerifyState(cfg1, state, userAgentKey) + if err != nil { + t.Fatalf("State should be valid with correct private key") + } + + // Should NOT verify with key2 (different private key) + _, err = VerifyState(cfg2, state, userAgentKey) + if err == nil { + t.Error("SECURITY VULNERABILITY: State verified with different private key!") + } + + // This proves that the private key is cryptographically involved in the signature + t.Log("✓ Private key security verified: Different keys produce incompatible signatures") +} + +// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload +func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) { + cfg := testConfig() + + // Generate two states with same data but different userAgentKeys (implicit) + state1, key1, err := GenerateState(cfg, "same_data") + if err != nil { + t.Fatalf("Failed to generate state1: %v", err) + } + + state2, key2, err := GenerateState(cfg, "same_data") + if err != nil { + t.Fatalf("Failed to generate state2: %v", err) + } + + // The states should be different even with same data (different random and keys) + if state1 == state2 { + t.Error("States should differ due to different random values") + } + + // Each state should only verify with its own key + _, err1 := VerifyState(cfg, state1, key1) + _, err2 := VerifyState(cfg, state2, key2) + + if err1 != nil || err2 != nil { + t.Fatal("States should be valid with their own keys") + } + + // Cross-verification should fail + _, err1 = VerifyState(cfg, state1, key2) + _, err2 = VerifyState(cfg, state2, key1) + + if err1 == nil || err2 == nil { + t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!") + } + + t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key") +} diff --git a/scripts/README.md b/scripts/README.md deleted file mode 100644 index ae6b33d..0000000 --- a/scripts/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Scripts - -## generate-css-sources.sh - -Automatically generates the `pkg/embedfs/files/css/input.css` file with `@source` directives for all `.templ` files in the project. - -### Why is this needed? - -Tailwind CSS v4 requires explicit `@source` directives to know which files to scan for utility classes. Glob patterns like `**/*.templ` don't work in `@source` directives, so each file must be listed individually. - -This script: -1. Finds all `.templ` files in the `internal/` directory -2. Generates `@source` directives with relative paths from the CSS file location -3. Adds your custom theme and utility classes - -### When does it run? - -The script runs automatically as part of: -- `make build` - Before building the CSS -- `make dev` - Before starting watch mode - -### Manual usage - -If you need to regenerate the sources manually: - -```bash -./scripts/generate-css-sources.sh -``` - -### Adding new template files - -When you add a new `.templ` file, you don't need to do anything special - just run `make build` or `make dev` and the script will automatically pick up the new file. diff --git a/scripts/generate-css-sources.sh b/scripts/generate-css-sources.sh deleted file mode 100755 index 4704473..0000000 --- a/scripts/generate-css-sources.sh +++ /dev/null @@ -1,140 +0,0 @@ -#!/bin/bash - -# Generate @source directives for all .templ files -# Paths are relative to pkg/embedfs/files/css/input.css - -INPUT_CSS="pkg/embedfs/files/css/input.css" - -# Start with the base imports -cat > "$INPUT_CSS" <<'CSSHEAD' -@import "tailwindcss"; - -CSSHEAD - -# Find all .templ files and add @source directives -find internal -name "*.templ" -type f | sort | while read -r file; do - # Convert to relative path from pkg/embedfs/files/css/ - rel_path="../../../../$file" - echo "@source \"$rel_path\";" >> "$INPUT_CSS" -done - -# Add the custom theme and utility classes -cat >> "$INPUT_CSS" <<'CSSBODY' - -[x-cloak] { - display: none !important; -} -@theme inline { - --color-rosewater: var(--rosewater); - --color-flamingo: var(--flamingo); - --color-pink: var(--pink); - --color-mauve: var(--mauve); - --color-red: var(--red); - --color-dark-red: var(--dark-red); - --color-maroon: var(--maroon); - --color-peach: var(--peach); - --color-yellow: var(--yellow); - --color-green: var(--green); - --color-teal: var(--teal); - --color-sky: var(--sky); - --color-sapphire: var(--sapphire); - --color-blue: var(--blue); - --color-lavender: var(--lavender); - --color-text: var(--text); - --color-subtext1: var(--subtext1); - --color-subtext0: var(--subtext0); - --color-overlay2: var(--overlay2); - --color-overlay1: var(--overlay1); - --color-overlay0: var(--overlay0); - --color-surface2: var(--surface2); - --color-surface1: var(--surface1); - --color-surface0: var(--surface0); - --color-base: var(--base); - --color-mantle: var(--mantle); - --color-crust: var(--crust); -} -:root { - --rosewater: hsl(11, 59%, 67%); - --flamingo: hsl(0, 60%, 67%); - --pink: hsl(316, 73%, 69%); - --mauve: hsl(266, 85%, 58%); - --red: hsl(347, 87%, 44%); - --dark-red: hsl(343, 50%, 82%); - --maroon: hsl(355, 76%, 59%); - --peach: hsl(22, 99%, 52%); - --yellow: hsl(35, 77%, 49%); - --green: hsl(109, 58%, 40%); - --teal: hsl(183, 74%, 35%); - --sky: hsl(197, 97%, 46%); - --sapphire: hsl(189, 70%, 42%); - --blue: hsl(220, 91%, 54%); - --lavender: hsl(231, 97%, 72%); - --text: hsl(234, 16%, 35%); - --subtext1: hsl(233, 13%, 41%); - --subtext0: hsl(233, 10%, 47%); - --overlay2: hsl(232, 10%, 53%); - --overlay1: hsl(231, 10%, 59%); - --overlay0: hsl(228, 11%, 65%); - --surface2: hsl(227, 12%, 71%); - --surface1: hsl(225, 14%, 77%); - --surface0: hsl(223, 16%, 83%); - --base: hsl(220, 23%, 95%); - --mantle: hsl(220, 22%, 92%); - --crust: hsl(220, 21%, 89%); -} - -.dark { - --rosewater: hsl(10, 56%, 91%); - --flamingo: hsl(0, 59%, 88%); - --pink: hsl(316, 72%, 86%); - --mauve: hsl(267, 84%, 81%); - --red: hsl(343, 81%, 75%); - --dark-red: hsl(316, 19%, 27%); - --maroon: hsl(350, 65%, 77%); - --peach: hsl(23, 92%, 75%); - --yellow: hsl(41, 86%, 83%); - --green: hsl(115, 54%, 76%); - --teal: hsl(170, 57%, 73%); - --sky: hsl(189, 71%, 73%); - --sapphire: hsl(199, 76%, 69%); - --blue: hsl(217, 92%, 76%); - --lavender: hsl(232, 97%, 85%); - --text: hsl(226, 64%, 88%); - --subtext1: hsl(227, 35%, 80%); - --subtext0: hsl(228, 24%, 72%); - --overlay2: hsl(228, 17%, 64%); - --overlay1: hsl(230, 13%, 55%); - --overlay0: hsl(231, 11%, 47%); - --surface2: hsl(233, 12%, 39%); - --surface1: hsl(234, 13%, 31%); - --surface0: hsl(237, 16%, 23%); - --base: hsl(240, 21%, 15%); - --mantle: hsl(240, 21%, 12%); - --crust: hsl(240, 23%, 9%); -} -.ubuntu-mono-regular { - font-family: "Ubuntu Mono", serif; - font-weight: 400; - font-style: normal; -} - -.ubuntu-mono-bold { - font-family: "Ubuntu Mono", serif; - font-weight: 700; - font-style: normal; -} - -.ubuntu-mono-regular-italic { - font-family: "Ubuntu Mono", serif; - font-weight: 400; - font-style: italic; -} - -.ubuntu-mono-bold-italic { - font-family: "Ubuntu Mono", serif; - font-weight: 700; - font-style: italic; -} -CSSBODY - -echo "Generated $INPUT_CSS with $(grep -c '@source' "$INPUT_CSS") source files"