diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..a888b37 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,327 @@ +# AGENTS.md - Developer Guide for oslstats + +This document provides guidelines for AI coding agents and developers working on the oslstats codebase. + +## Project Overview + +**Module**: `git.haelnorr.com/h/oslstats` +**Language**: Go 1.25.5 +**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates +**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries + +## Build, Test, and Development Commands + +### Building +```bash +# Full production build (tailwind → templ → go generate → go build) +make build + +# Build and run +make run + +# Clean build artifacts +make clean +``` + +### Development Mode +```bash +# Watch mode with hot reload (templ, air, tailwindcss in parallel) +make dev + +# Development server runs on: +# - Proxy: http://localhost:3000 (use this) +# - App: http://localhost:3333 (internal) +``` + +### Testing +```bash +# Run all tests +go test ./... + +# Run tests for a specific package +go test ./pkg/oauth + +# Run a single test function +go test ./pkg/oauth -run TestGenerateState_Success + +# Run tests with verbose output +go test -v ./pkg/oauth + +# Run tests with coverage +go test -cover ./... +go test -coverprofile=coverage.out ./... +go tool cover -html=coverage.out +``` + +### Database +```bash +# Run migrations +make migrate +# OR +./bin/oslstats --migrate +``` + +### Configuration Management +```bash +# Generate .env template file +make genenv +# OR with custom output: make genenv OUT=.env.example + +# Show environment variable documentation +make envdoc + +# Show current environment values +make showenv +``` + +## Code Style Guidelines + +### Import Organization +Organize imports in **3 groups** separated by blank lines: + +```go +import ( + // 1. Standard library + "context" + "net/http" + "fmt" + + // 2. External dependencies + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "github.com/uptrace/bun" + + // 3. Internal packages + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) +``` + +### Naming Conventions + +**Variables**: +- Local: `camelCase` (userAgentKey, httpServer, dbConn) +- Exported: `PascalCase` (Config, User, Token) +- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r` + +**Functions**: +- Exported: `PascalCase` (GetConfig, NewStore, GenerateState) +- Private: `camelCase` (throwError, shouldShowDetails, loadModels) +- HTTP handlers: Return `http.Handler`, use dependency injection pattern +- Database functions: Use `bun.Tx` as parameter for transactions + +**Types**: +- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession) +- Use `-er` suffix for interfaces (implied from usage) + +**Files**: +- Prefer single word: `config.go`, `oauth.go`, `errors.go` +- Use snake_case if needed: `discord_tokens.go`, `state_test.go` +- Test files: `*_test.go` alongside source files + +### Error Handling + +**Always wrap errors** with context using `github.com/pkg/errors`: + +```go +if err != nil { + return errors.Wrap(err, "operation_name") +} +``` + +**Validate inputs at function start**: +```go +func DoSomething(cfg *Config, data string) error { + if cfg == nil { + return errors.New("cfg cannot be nil") + } + if data == "" { + return errors.New("data cannot be empty") + } + // ... rest of function +} +``` + +**HTTP error helpers** (in handlers package): +- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors +- `throwBadRequest(s, w, r, msg, err)` - 400 errors +- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal) +- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level) +- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal) +- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level) +- `throwNotFound(s, w, r, path)` - 404 errors + +### Common Patterns + +**HTTP Handler Pattern**: +```go +func HandlerName(server *hws.Server, deps ...) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Handler logic here + }, + ) +} +``` + +**Database Operation Pattern**: +```go +func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) { + result := new(Result) + err := tx.NewSelect(). + Model(result). + Where("id = ?", id). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil // Return nil, nil for not found + } + return nil, errors.Wrap(err, "tx.Select") + } + return result, nil +} +``` + +**Setup Function Pattern** (returns instance, cleanup func, error): +```go +func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) { + instance := newInstance() + + err := configure(instance) + if err != nil { + return nil, nil, errors.Wrap(err, "configure") + } + + return instance, instance.Close, nil +} +``` + +**Configuration Pattern** (using ezconf): +```go +type Config struct { + Field string // ENV FIELD_NAME: Description (required/default: value) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + Field: env.String("FIELD_NAME", "default"), + } + // Validation here + return cfg, nil +} +``` + +### Formatting & Types + +**Formatting**: +- Use `gofmt` (standard Go formatting) +- No tabs vs spaces debate - Go uses tabs + +**Types**: +- Prefer explicit types over inference when it improves clarity +- Use struct tags for ORM and JSON marshaling: + ```go + type User struct { + bun.BaseModel `bun:"table:users,alias:u"` + ID int `bun:"id,pk,autoincrement"` + Username string `bun:"username,unique"` + AccessToken string `json:"access_token"` + } + ``` + +**Comments**: +- Document exported functions and types +- Use inline comments for ENV var documentation in Config structs +- Explain security-critical code flows + +### Testing + +**Test File Location**: Place `*_test.go` files alongside source files + +**Test Naming**: +```go +func TestFunctionName_Scenario(t *testing.T) +func TestGenerateState_Success(t *testing.T) +func TestVerifyState_WrongUserAgentKey(t *testing.T) +``` + +**Test Structure**: +- Use subtests with `t.Run()` for related scenarios +- Use table-driven tests for multiple similar cases +- Create helper functions for common setup (e.g., `testConfig()`) +- Test happy paths, error cases, edge cases, and security properties + +**Test Categories** (from pkg/oauth/state_test.go example): +1. Happy path tests +2. Error handling (nil params, empty fields, malformed input) +3. Security tests (MITM, CSRF, replay attacks, tampering) +4. Edge cases (concurrency, constant-time comparison) +5. Integration tests (round-trip verification) + +### Security + +**Critical Practices**: +- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons +- Implement CSRF protection via state tokens +- Store sensitive cookies as HttpOnly +- Use separate logging levels for security violations (WARN) +- Validate all inputs at function boundaries +- Use parameterized queries (Bun ORM handles this) +- Never commit secrets (.env, keys/ are gitignored) + +## Project Structure + +``` +oslstats/ +├── cmd/oslstats/ # Application entry point +│ ├── main.go # Entry point with flag parsing +│ ├── run.go # Server initialization & graceful shutdown +│ ├── httpserver.go # HTTP server setup +│ ├── routes.go # Route registration +│ ├── middleware.go # Middleware registration +│ ├── auth.go # Authentication setup +│ └── db.go # Database connection & migrations +├── internal/ # Private application code +│ ├── config/ # Configuration aggregation +│ ├── db/ # Database models & queries (Bun ORM) +│ ├── discord/ # Discord OAuth integration +│ ├── handlers/ # HTTP request handlers +│ ├── session/ # Session store (in-memory) +│ └── view/ # Templ templates +│ ├── component/ # Reusable UI components +│ ├── layout/ # Page layouts +│ └── page/ # Full pages +├── pkg/ # Reusable packages +│ ├── contexts/ # Context key definitions +│ ├── embedfs/ # Embedded static files +│ └── oauth/ # OAuth state management +├── bin/ # Compiled binaries (gitignored) +├── keys/ # Private keys (gitignored) +├── tmp/ # Air hot reload temp files (gitignored) +├── Makefile # Build automation +├── .air.toml # Hot reload configuration +└── go.mod # Go module definition +``` + +## Key Dependencies + +- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt) +- **github.com/a-h/templ** - Type-safe HTML templating +- **github.com/uptrace/bun** - PostgreSQL ORM +- **github.com/bwmarrin/discordgo** - Discord API client +- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf) +- **github.com/joho/godotenv** - .env file loading + +## Notes for AI Agents + +1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css) +2. **Database operations** should use `bun.Tx` for transaction safety +3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes +4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/ +5. **Error messages** should be descriptive and use errors.Wrap for context +6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples) +7. **Air proxy** runs on port 3000 during development; app runs on 3333 +8. **Test coverage** is currently limited - prioritize testing security-critical code +9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples +10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern +11. When in plan mode, always use the interactive question tool if available diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index 8c0ff26..bd2763f 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -4,14 +4,15 @@ import ( "io/fs" "net/http" - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/oslstats/internal/session" - "git.haelnorr.com/h/golib/hlog" + "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func setupHttpServer( @@ -19,7 +20,8 @@ func setupHttpServer( config *config.Config, logger *hlog.Logger, bun *bun.DB, - store *session.Store, + store *store.Store, + discordAPI *discord.APIClient, ) (server *hws.Server, err error) { if staticFS == nil { return nil, errors.New("No filesystem provided") @@ -55,7 +57,7 @@ func setupHttpServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, config, bun, auth, store) + err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI) if err != nil { return nil, errors.Wrap(err, "addRoutes") } diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index e247234..e7161b7 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -5,13 +5,14 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/oslstats/internal/session" - "github.com/pkg/errors" "github.com/uptrace/bun" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/store" ) func addRoutes( @@ -20,7 +21,8 @@ func addRoutes( cfg *config.Config, conn *bun.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], - store *session.Store, + store *store.Store, + discordAPI *discord.APIClient, ) error { // Create the routes routes := []hws.Route{ @@ -37,12 +39,12 @@ func addRoutes( { Path: "/login", Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Login(server, cfg)), + Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)), }, { Path: "/auth/callback", Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store)), + Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store, discordAPI)), }, { Path: "/register", diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 5dd1922..32395e8 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -9,10 +9,12 @@ import ( "time" "git.haelnorr.com/h/golib/hlog" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/session" - "git.haelnorr.com/h/oslstats/pkg/embedfs" "github.com/pkg/errors" + + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/pkg/embedfs" ) // Initializes and runs the server @@ -44,10 +46,14 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { // Setup session store logger.Debug().Msg("Setting up session store") - store := session.NewStore() + store := store.NewStore() + + // Setup Discord API client + logger.Debug().Msg("Setting up Discord API client") + discordAPI := discord.NewRateLimitedClient(logger) logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store) + httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store, discordAPI) if err != nil { return errors.Wrap(err, "setupHttpServer") } diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go index 73644ea..780a71b 100644 --- a/internal/discord/oauth.go +++ b/internal/discord/oauth.go @@ -43,7 +43,7 @@ func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) { return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil } -func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) { +func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) { if code == "" { return nil, errors.New("code cannot be empty") } @@ -53,6 +53,9 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) { if trustedHost == "" { return nil, errors.New("trustedHost cannot be empty") } + if apiClient == nil { + return nil, errors.New("apiClient cannot be nil") + } // Prepare form data data := url.Values{} data.Set("grant_type", "authorization_code") @@ -72,9 +75,8 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) { // Set basic auth (client_id and client_secret) req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request - client := &http.Client{} - resp, err := client.Do(req) + // Execute request with rate limit handling + resp, err := apiClient.Do(req) if err != nil { return nil, errors.Wrap(err, "failed to execute request") } @@ -96,13 +98,16 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) { return &tokenResp, nil } -func RefreshToken(cfg *Config, token *Token) (*Token, error) { +func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) { if token == nil { return nil, errors.New("token cannot be nil") } if cfg == nil { return nil, errors.New("config cannot be nil") } + if apiClient == nil { + return nil, errors.New("apiClient cannot be nil") + } // Prepare form data data := url.Values{} data.Set("grant_type", "refresh_token") @@ -121,9 +126,8 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) { // Set basic auth (client_id and client_secret) req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request - client := &http.Client{} - resp, err := client.Do(req) + // Execute request with rate limit handling + resp, err := apiClient.Do(req) if err != nil { return nil, errors.Wrap(err, "failed to execute request") } @@ -145,13 +149,16 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) { return &tokenResp, nil } -func RevokeToken(cfg *Config, token *Token) error { +func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error { if token == nil { return errors.New("token cannot be nil") } if cfg == nil { return errors.New("config cannot be nil") } + if apiClient == nil { + return errors.New("apiClient cannot be nil") + } // Prepare form data data := url.Values{} data.Set("token", token.AccessToken) @@ -170,9 +177,8 @@ func RevokeToken(cfg *Config, token *Token) error { // Set basic auth (client_id and client_secret) req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request - client := &http.Client{} - resp, err := client.Do(req) + // Execute request with rate limit handling + resp, err := apiClient.Do(req) if err != nil { return errors.Wrap(err, "failed to execute request") } diff --git a/internal/discord/ratelimit.go b/internal/discord/ratelimit.go new file mode 100644 index 0000000..4cde26d --- /dev/null +++ b/internal/discord/ratelimit.go @@ -0,0 +1,235 @@ +package discord + +import ( + "net" + "net/http" + "strconv" + "sync" + "time" + + "git.haelnorr.com/h/golib/hlog" + "github.com/pkg/errors" +) + +// RateLimitState tracks rate limit information for a specific bucket +type RateLimitState struct { + Remaining int // Requests remaining in current window + Limit int // Total requests allowed in window + Reset time.Time // When the rate limit resets + Bucket string // Discord's bucket identifier +} + +// APIClient is an HTTP client wrapper that handles Discord API rate limits +type APIClient struct { + client *http.Client + logger *hlog.Logger + mu sync.RWMutex + buckets map[string]*RateLimitState +} + +// NewRateLimitedClient creates a new Discord API client with rate limit handling +func NewRateLimitedClient(logger *hlog.Logger) *APIClient { + return &APIClient{ + client: &http.Client{Timeout: 30 * time.Second}, + logger: logger, + buckets: make(map[string]*RateLimitState), + } +} + +// Do executes an HTTP request with automatic rate limit handling +// It will wait if rate limits are about to be exceeded and retry once if a 429 is received +func (c *APIClient) Do(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + + // Step 1: Check if we need to wait before making request + bucket := c.getBucketFromRequest(req) + if err := c.waitIfNeeded(bucket); err != nil { + return nil, err + } + + // Step 2: Execute request + resp, err := c.client.Do(req) + if err != nil { + // Check if it's a network timeout + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "request timed out") + } + return nil, errors.Wrap(err, "http request failed") + } + + // Step 3: Update rate limit state from response headers + c.updateRateLimit(resp.Header) + + // Step 4: Handle 429 (rate limited) + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() // Close original response + + retryAfter := c.parseRetryAfter(resp.Header) + + // No Retry-After header, can't retry safely + if retryAfter == 0 { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Rate limited but no Retry-After header provided") + return nil, errors.New("discord API rate limited but no Retry-After header provided") + } + + // Retry-After exceeds 30 second cap + if retryAfter > 30*time.Second { + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited with Retry-After exceeding 30s cap, not retrying") + return nil, errors.Errorf( + "discord API rate limited (retry after %s exceeds 30s cap)", + retryAfter, + ) + } + + // Wait and retry + c.logger.Warn(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Dur("retry_after", retryAfter). + Msg("Rate limited, waiting before retry") + + time.Sleep(retryAfter) + + // Retry the request + resp, err = c.client.Do(req) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.Wrap(err, "retry request timed out") + } + return nil, errors.Wrap(err, "retry request failed") + } + + // Update rate limit again after retry + c.updateRateLimit(resp.Header) + + // If STILL rate limited after retry, return error + if resp.StatusCode == http.StatusTooManyRequests { + resp.Body.Close() + c.logger.Error(). + Str("bucket", bucket). + Str("method", req.Method). + Str("path", req.URL.Path). + Msg("Still rate limited after retry, Discord may be experiencing issues") + return nil, errors.Errorf( + "discord API still rate limited after retry (waited %s), Discord may be experiencing issues", + retryAfter, + ) + } + } + + return resp, nil +} + +// getBucketFromRequest extracts or generates bucket ID from request +// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers +func (c *APIClient) getBucketFromRequest(req *http.Request) string { + return req.Method + ":" + req.URL.Path +} + +// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits +func (c *APIClient) waitIfNeeded(bucket string) error { + c.mu.RLock() + state, exists := c.buckets[bucket] + c.mu.RUnlock() + + if !exists { + return nil // No state yet, proceed + } + + now := time.Now() + + // If we have no remaining requests and reset hasn't occurred, wait + if state.Remaining == 0 && now.Before(state.Reset) { + waitDuration := time.Until(state.Reset) + // Add small buffer (100ms) to ensure reset has occurred + waitDuration += 100 * time.Millisecond + + if waitDuration > 0 { + c.logger.Debug(). + Str("bucket", bucket). + Dur("wait_duration", waitDuration). + Msg("Proactively waiting for rate limit reset") + time.Sleep(waitDuration) + } + } + + return nil +} + +// updateRateLimit parses response headers and updates bucket state +func (c *APIClient) updateRateLimit(headers http.Header) { + bucket := headers.Get("X-RateLimit-Bucket") + if bucket == "" { + return // No bucket info, can't track + } + + // Parse headers + limit := c.parseInt(headers.Get("X-RateLimit-Limit")) + remaining := c.parseInt(headers.Get("X-RateLimit-Remaining")) + resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After")) + + state := &RateLimitState{ + Bucket: bucket, + Limit: limit, + Remaining: remaining, + Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))), + } + + c.mu.Lock() + c.buckets[bucket] = state + c.mu.Unlock() + + // Log rate limit state for debugging + c.logger.Debug(). + Str("bucket", bucket). + Int("remaining", remaining). + Int("limit", limit). + Dur("reset_in", time.Until(state.Reset)). + Msg("Rate limit state updated") +} + +// parseRetryAfter extracts retry delay from Retry-After header +func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration { + retryAfter := headers.Get("Retry-After") + if retryAfter == "" { + return 0 + } + + // Discord returns seconds as float + seconds := c.parseFloat(retryAfter) + if seconds <= 0 { + return 0 + } + + return time.Duration(seconds * float64(time.Second)) +} + +// parseInt parses an integer from a header value, returns 0 on error +func (c *APIClient) parseInt(s string) int { + if s == "" { + return 0 + } + i, _ := strconv.Atoi(s) + return i +} + +// parseFloat parses a float from a header value, returns 0 on error +func (c *APIClient) parseFloat(s string) float64 { + if s == "" { + return 0 + } + f, _ := strconv.ParseFloat(s, 64) + return f +} diff --git a/internal/discord/ratelimit_test.go b/internal/discord/ratelimit_test.go new file mode 100644 index 0000000..33f6b57 --- /dev/null +++ b/internal/discord/ratelimit_test.go @@ -0,0 +1,459 @@ +package discord + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "git.haelnorr.com/h/golib/hlog" +) + +// testLogger creates a test logger for testing +func testLogger(t *testing.T) *hlog.Logger { + level, _ := hlog.LogLevel("debug") + cfg := &hlog.Config{ + LogLevel: level, + LogOutput: "console", + } + logger, err := hlog.NewLogger(cfg, io.Discard) + if err != nil { + t.Fatalf("failed to create test logger: %v", err) + } + return logger +} + +func TestNewRateLimitedClient(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + if client == nil { + t.Fatal("NewRateLimitedClient returned nil") + } + if client.client == nil { + t.Error("client.client is nil") + } + if client.logger == nil { + t.Error("client.logger is nil") + } + if client.buckets == nil { + t.Error("client.buckets map is nil") + } +} + +func TestAPIClient_Do_Success(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + // Mock server that returns success with rate limit headers + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "3") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + + // Check that rate limit state was updated + client.mu.RLock() + state, exists := client.buckets["test-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("rate limit state not stored") + } + if state.Remaining != 3 { + t.Errorf("expected remaining=3, got %d", state.Remaining) + } + if state.Limit != 5 { + t.Errorf("expected limit=5, got %d", state.Limit) + } +} + +func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + if attemptCount == 1 { + // First request: return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.1") // 100ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + "error_description": "You are being rate limited", + }) + return + } + // Second request: success + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("X-RateLimit-Limit", "5") + w.Header().Set("X-RateLimit-Remaining", "4") + w.Header().Set("X-RateLimit-Reset-After", "2.5") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Do() returned error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200 after retry, got %d", resp.StatusCode) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount) + } + + // Should have waited approximately 100ms + if elapsed < 100*time.Millisecond { + t.Errorf("expected delay of ~100ms, but took %v", elapsed) + } +} + +func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + attemptCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 429 + w.Header().Set("X-RateLimit-Bucket", "test-bucket") + w.Header().Set("Retry-After", "0.05") // 50ms + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error after failed retry") + } + + if !strings.Contains(err.Error(), "still rate limited after retry") { + t.Errorf("expected 'still rate limited after retry' error, got: %v", err) + } + + if attemptCount != 2 { + t.Errorf("expected 2 attempts, got %d", attemptCount) + } +} + +func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error for Retry-After > 30s") + } + + if !strings.Contains(err.Error(), "exceeds 30s cap") { + t.Errorf("expected 'exceeds 30s cap' error, got: %v", err) + } + + // Should NOT have waited (immediate error) + if elapsed > 1*time.Second { + t.Errorf("should return immediately, but took %v", elapsed) + } +} + +func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 429 but NO Retry-After header + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{ + "error": "rate_limited", + }) + })) + defer server.Close() + + req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + + if err == nil { + resp.Body.Close() + t.Fatal("Do() should have returned error when no Retry-After header") + } + + if !strings.Contains(err.Error(), "no Retry-After header") { + t.Errorf("expected 'no Retry-After header' error, got: %v", err) + } +} + +func TestAPIClient_UpdateRateLimit(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + headers := http.Header{} + headers.Set("X-RateLimit-Bucket", "global") + headers.Set("X-RateLimit-Limit", "10") + headers.Set("X-RateLimit-Remaining", "7") + headers.Set("X-RateLimit-Reset-After", "5.5") + + client.updateRateLimit(headers) + + client.mu.RLock() + state, exists := client.buckets["global"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not created") + } + + if state.Bucket != "global" { + t.Errorf("expected bucket='global', got '%s'", state.Bucket) + } + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d", state.Limit) + } + if state.Remaining != 7 { + t.Errorf("expected remaining=7, got %d", state.Remaining) + } + + // Check reset time is approximately 5.5 seconds from now + resetIn := time.Until(state.Reset) + if resetIn < 5*time.Second || resetIn > 6*time.Second { + t.Errorf("expected reset in ~5.5s, got %v", resetIn) + } +} + +func TestAPIClient_WaitIfNeeded(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + // Set up a bucket with 0 remaining and reset in future + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 0, + Reset: time.Now().Add(200 * time.Millisecond), + } + client.mu.Unlock() + + start := time.Now() + err := client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should have waited ~200ms + 100ms buffer + if elapsed < 200*time.Millisecond { + t.Errorf("expected wait of ~300ms, but took %v", elapsed) + } + if elapsed > 500*time.Millisecond { + t.Errorf("waited too long: %v", elapsed) + } +} + +func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + // Set up a bucket with remaining requests + bucket := "test-bucket" + client.mu.Lock() + client.buckets[bucket] = &RateLimitState{ + Bucket: bucket, + Limit: 5, + Remaining: 3, + Reset: time.Now().Add(5 * time.Second), + } + client.mu.Unlock() + + start := time.Now() + err := client.waitIfNeeded(bucket) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("waitIfNeeded returned error: %v", err) + } + + // Should NOT wait (has remaining requests) + if elapsed > 10*time.Millisecond { + t.Errorf("should not wait when remaining > 0, but took %v", elapsed) + } +} + +func TestAPIClient_Do_Concurrent(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + requestCount := 0 + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestCount++ + count := requestCount + mu.Unlock() + + w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket") + w.Header().Set("X-RateLimit-Limit", "10") + w.Header().Set("X-RateLimit-Remaining", "5") + w.Header().Set("X-RateLimit-Reset-After", "1.0") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))}) + })) + defer server.Close() + + // Launch 10 concurrent requests + var wg sync.WaitGroup + errors := make(chan error, 10) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + errors <- err + return + } + + resp, err := client.Do(req) + if err != nil { + errors <- err + return + } + resp.Body.Close() + }() + } + + wg.Wait() + close(errors) + + // Check for any errors + for err := range errors { + t.Errorf("concurrent request failed: %v", err) + } + + // All requests should have completed + mu.Lock() + finalCount := requestCount + mu.Unlock() + + if finalCount != 10 { + t.Errorf("expected 10 requests, got %d", finalCount) + } + + // Check rate limit state is consistent (no data races) + client.mu.RLock() + state, exists := client.buckets["concurrent-bucket"] + client.mu.RUnlock() + + if !exists { + t.Fatal("bucket state not found after concurrent requests") + } + + // State should exist and be valid + if state.Limit != 10 { + t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit) + } +} + +func TestAPIClient_ParseRetryAfter(t *testing.T) { + logger := testLogger(t) + client := NewRateLimitedClient(logger) + + tests := []struct { + name string + header string + expected time.Duration + }{ + {"integer seconds", "2", 2 * time.Second}, + {"float seconds", "2.5", 2500 * time.Millisecond}, + {"zero", "0", 0}, + {"empty", "", 0}, + {"invalid", "abc", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := http.Header{} + headers.Set("Retry-After", tt.header) + + result := client.parseRetryAfter(headers) + + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index d0677aa..47239c5 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -6,18 +6,49 @@ import ( "time" "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "github.com/uptrace/bun" + "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" - "git.haelnorr.com/h/oslstats/internal/session" + "git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/pkg/oauth" - "github.com/pkg/errors" - "github.com/uptrace/bun" ) -func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *session.Store) http.Handler { +func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *store.Store, discordAPI *discord.APIClient) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + // Track callback redirect attempts + attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) + + if exceeded { + // Build detailed error for logging + err := errors.Errorf( + "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + // Clear the tracking entry + store.ClearRedirectTrack(r, "/callback") + + // Show error page + throwError( + server, + w, + r, + http.StatusBadRequest, + "OAuth callback failed: Too many redirect attempts. Please try logging in again.", + err, + "warn", + ) + return + } + state := r.URL.Query().Get("state") code := r.URL.Query().Get("code") if state == "" && code == "" { @@ -41,6 +72,10 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi } return } + // SUCCESS POINT: State verified successfully + // Clear redirect tracking - OAuth callback completed successfully + store.ClearRedirectTrack(r, "/callback") + switch data { case "login": ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) @@ -51,7 +86,7 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi return } defer tx.Rollback() - redirect, err := login(ctx, tx, cfg, w, r, code, store) + redirect, err := login(ctx, tx, cfg, w, r, code, store, discordAPI) if err != nil { throwInternalServiceError(server, w, r, "OAuth login failed", err) return @@ -122,9 +157,10 @@ func login( w http.ResponseWriter, r *http.Request, code string, - store *session.Store, + store *store.Store, + discordAPI *discord.APIClient, ) (func(), error) { - token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost) + token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost, discordAPI) if err != nil { return nil, errors.Wrap(err, "discord.AuthorizeWithCode") } diff --git a/internal/handlers/login.go b/internal/handlers/login.go index 3ff126f..e46312a 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -4,14 +4,47 @@ import ( "net/http" "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/pkg/oauth" ) -func Login(server *hws.Server, cfg *config.Config) http.Handler { +func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + // Track login redirect attempts + attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) + + if exceeded { + // Build detailed error for logging + err := errors.Errorf( + "login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) + + // Clear the tracking entry + st.ClearRedirectTrack(r, "/login") + + // Show error page + throwError( + server, + w, + r, + http.StatusBadRequest, + "Login failed: Too many redirect attempts. Please clear your browser cookies and try again.", + err, + "warn", + ) + return + } + state, uak, err := oauth.GenerateState(cfg.OAuth, "login") if err != nil { throwInternalServiceError(server, w, r, "Failed to generate state token", err) @@ -24,6 +57,11 @@ func Login(server *hws.Server, cfg *config.Config) http.Handler { throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) return } + + // SUCCESS POINT: OAuth link generated, redirecting to Discord + // Clear redirect tracking - user successfully initiated OAuth + st.ClearRedirectTrack(r, "/login") + http.Redirect(w, r, link, http.StatusSeeOther) }, ) diff --git a/internal/handlers/register.go b/internal/handlers/register.go index b871e45..d70691f 100644 --- a/internal/handlers/register.go +++ b/internal/handlers/register.go @@ -6,31 +6,63 @@ import ( "time" "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" + "github.com/uptrace/bun" + "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/session" + "git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/view/page" - "github.com/uptrace/bun" ) func Register( server *hws.Server, conn *bun.DB, cfg *config.Config, - store *session.Store, + store *store.Store, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + attempts, exceeded, track := store.TrackRedirect(r, "/register", 3) + + if exceeded { + err := errors.Errorf( + "registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t", + attempts, + track.IP, + track.UserAgent, + track.Path, + track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + cfg.HWSAuth.SSL, + ) + + store.ClearRedirectTrack(r, "/register") + + throwError( + server, + w, + r, + http.StatusBadRequest, + "Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.", + err, + "warn", + ) + return + } + sessionCookie, err := r.Cookie("registration_session") if err != nil { http.Redirect(w, r, "/login", http.StatusSeeOther) return } details, ok := store.GetRegistrationSession(sessionCookie.Value) + ok = false if !ok { http.Redirect(w, r, "/login", http.StatusSeeOther) return } + + store.ClearRedirectTrack(r, "/register") ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() tx, err := conn.BeginTx(ctx, nil) @@ -65,12 +97,11 @@ func IsUsernameUnique( server *hws.Server, conn *bun.DB, cfg *config.Config, - store *session.Store, + store *store.Store, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { username := r.FormValue("username") - // check if its unique ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() tx, err := conn.BeginTx(ctx, nil) diff --git a/internal/session/store.go b/internal/session/store.go deleted file mode 100644 index fe77e62..0000000 --- a/internal/session/store.go +++ /dev/null @@ -1,46 +0,0 @@ -package session - -import ( - "crypto/rand" - "encoding/base64" - "sync" - "time" -) - -type Store struct { - sessions sync.Map - cleanup *time.Ticker -} - -func NewStore() *Store { - s := &Store{ - cleanup: time.NewTicker(1 * time.Minute), - } - - // Background cleanup of expired sessions - go func() { - for range s.cleanup.C { - s.cleanupExpired() - } - }() - - return s -} - -func (s *Store) Delete(id string) { - s.sessions.Delete(id) -} -func (s *Store) cleanupExpired() { - s.sessions.Range(func(key, value any) bool { - session := value.(*RegistrationSession) - if time.Now().After(session.ExpiresAt) { - s.sessions.Delete(key) - } - return true - }) -} -func generateID() string { - b := make([]byte, 32) - rand.Read(b) - return base64.RawURLEncoding.EncodeToString(b) -} diff --git a/internal/session/newlogin.go b/internal/store/newlogin.go similarity index 98% rename from internal/session/newlogin.go rename to internal/store/newlogin.go index d1c60ba..0e62529 100644 --- a/internal/session/newlogin.go +++ b/internal/store/newlogin.go @@ -1,4 +1,4 @@ -package session +package store import ( "errors" diff --git a/internal/store/redirects.go b/internal/store/redirects.go new file mode 100644 index 0000000..e2e7852 --- /dev/null +++ b/internal/store/redirects.go @@ -0,0 +1,95 @@ +package store + +import ( + "net" + "net/http" + "strings" + "time" +) + +// getClientIP extracts the client IP address, checking X-Forwarded-For first +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (comma-separated list, first is client) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the list + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port") + // Use net.SplitHostPort to properly handle both IPv4 and IPv6 + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr) + return r.RemoteAddr + } + return host +} + +// TrackRedirect increments the redirect counter for this IP+UA+Path combination +// Returns the current attempt count, whether limit was exceeded, and the track details +func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) { + if r == nil { + return 0, false, nil + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + + now := time.Now() + expiresAt := now.Add(5 * time.Minute) + + // Try to load existing track + val, exists := s.redirectTracks.Load(key) + if exists { + track = val.(*RedirectTrack) + + // Check if expired + if now.After(track.ExpiresAt) { + // Expired, start fresh + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track + } + + // Increment existing + track.Attempts++ + track.ExpiresAt = expiresAt // Extend expiry + exceeded = track.Attempts >= maxAttempts + return track.Attempts, exceeded, track + } + + // Create new track + track = &RedirectTrack{ + IP: ip, + UserAgent: userAgent, + Path: path, + Attempts: 1, + FirstSeen: now, + ExpiresAt: expiresAt, + } + s.redirectTracks.Store(key, track) + return 1, false, track +} + +// ClearRedirectTrack removes a redirect tracking entry (called after successful completion) +func (s *Store) ClearRedirectTrack(r *http.Request, path string) { + if r == nil { + return + } + + ip := getClientIP(r) + userAgent := r.UserAgent() + key := redirectKey(ip, userAgent, path) + s.redirectTracks.Delete(key) +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..5620e58 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,80 @@ +package store + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "sync" + "time" +) + +// RedirectTrack represents a single redirect attempt tracking entry +type RedirectTrack struct { + IP string // Client IP (X-Forwarded-For aware) + UserAgent string // Full User-Agent string for debugging + Path string // Request path (without query params) + Attempts int // Number of redirect attempts + FirstSeen time.Time // When first redirect was tracked + ExpiresAt time.Time // When to clean up this entry +} + +type Store struct { + sessions sync.Map // key: string, value: *RegistrationSession + redirectTracks sync.Map // key: string, value: *RedirectTrack + cleanup *time.Ticker +} + +func NewStore() *Store { + s := &Store{ + cleanup: time.NewTicker(1 * time.Minute), + } + + // Background cleanup of expired sessions + go func() { + for range s.cleanup.C { + s.cleanupExpired() + } + }() + + return s +} + +func (s *Store) Delete(id string) { + s.sessions.Delete(id) +} +func (s *Store) cleanupExpired() { + now := time.Now() + + // Clean up expired registration sessions + s.sessions.Range(func(key, value any) bool { + session := value.(*RegistrationSession) + if now.After(session.ExpiresAt) { + s.sessions.Delete(key) + } + return true + }) + + // Clean up expired redirect tracks + s.redirectTracks.Range(func(key, value any) bool { + track := value.(*RedirectTrack) + if now.After(track.ExpiresAt) { + s.redirectTracks.Delete(key) + } + return true + }) +} +func generateID() string { + b := make([]byte, 32) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +// redirectKey generates a unique key for tracking redirects +// Uses IP + first 100 chars of UA + path as key (not hashed for debugging) +func redirectKey(ip, userAgent, path string) string { + ua := userAgent + if len(ua) > 100 { + ua = ua[:100] + } + return fmt.Sprintf("%s:%s:%s", ip, ua, path) +} diff --git a/pkg/embedfs/files/css/output.css b/pkg/embedfs/files/css/output.css index 022341c..0643362 100644 --- a/pkg/embedfs/files/css/output.css +++ b/pkg/embedfs/files/css/output.css @@ -232,18 +232,12 @@ .top-0 { top: calc(var(--spacing) * 0); } - .top-2 { - top: calc(var(--spacing) * 2); - } .top-4 { top: calc(var(--spacing) * 4); } .right-0 { right: calc(var(--spacing) * 0); } - .right-2 { - right: calc(var(--spacing) * 2); - } .bottom-0 { bottom: calc(var(--spacing) * 0); } @@ -253,18 +247,9 @@ .z-10 { z-index: 10; } - .float-left { - float: left; - } - .m-0 { - margin: calc(var(--spacing) * 0); - } .mx-auto { margin-inline: auto; } - .mt-1 { - margin-top: calc(var(--spacing) * 1); - } .mt-1\.5 { margin-top: calc(var(--spacing) * 1.5); } @@ -298,21 +283,12 @@ .mt-24 { margin-top: calc(var(--spacing) * 24); } - .mr-0 { - margin-right: calc(var(--spacing) * 0); - } - .mr-2 { - margin-right: calc(var(--spacing) * 2); - } .mr-5 { margin-right: calc(var(--spacing) * 5); } .mb-auto { margin-bottom: auto; } - .ml-0 { - margin-left: calc(var(--spacing) * 0); - } .ml-2 { margin-left: calc(var(--spacing) * 2); } @@ -322,9 +298,6 @@ .block { display: block; } - .contents { - display: contents; - } .flex { display: flex; } @@ -343,9 +316,6 @@ .inline-flex { display: inline-flex; } - .table { - display: table; - } .size-5 { width: calc(var(--spacing) * 5); height: calc(var(--spacing) * 5); @@ -396,9 +366,6 @@ .flex-1 { flex: 1; } - .border-collapse { - border-collapse: collapse; - } .translate-x-0 { --tw-translate-x: calc(var(--spacing) * 0); translate: var(--tw-translate-x) var(--tw-translate-y); @@ -413,9 +380,6 @@ .cursor-pointer { cursor: pointer; } - .resize { - resize: both; - } .flex-col { flex-direction: column; } @@ -670,9 +634,6 @@ .text-text { color: var(--text); } - .underline { - text-decoration-line: underline; - } .opacity-0 { opacity: 0%; } @@ -687,10 +648,6 @@ --tw-shadow: 0 1px 3px 0 var(--tw-shadow-color, rgb(0 0 0 / 0.1)), 0 1px 2px -1px var(--tw-shadow-color, rgb(0 0 0 / 0.1)); box-shadow: var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow); } - .outline { - outline-style: var(--tw-outline-style); - outline-width: 1px; - } .transition { transition-property: color, background-color, border-color, outline-color, text-decoration-color, fill, stroke, --tw-gradient-from, --tw-gradient-via, --tw-gradient-to, opacity, box-shadow, transform, translate, scale, rotate, filter, -webkit-backdrop-filter, backdrop-filter, display, content-visibility, overlay, pointer-events; transition-timing-function: var(--tw-ease, var(--default-transition-timing-function)); @@ -1165,11 +1122,6 @@ inherits: false; initial-value: 0 0 #0000; } -@property --tw-outline-style { - syntax: "*"; - inherits: false; - initial-value: solid; -} @property --tw-duration { syntax: "*"; inherits: false; @@ -1205,7 +1157,6 @@ --tw-ring-offset-width: 0px; --tw-ring-offset-color: #fff; --tw-ring-offset-shadow: 0 0 #0000; - --tw-outline-style: solid; --tw-duration: initial; } }