added discord api limiting
This commit is contained in:
327
AGENTS.md
Normal file
327
AGENTS.md
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# AGENTS.md - Developer Guide for oslstats
|
||||||
|
|
||||||
|
This document provides guidelines for AI coding agents and developers working on the oslstats codebase.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
**Module**: `git.haelnorr.com/h/oslstats`
|
||||||
|
**Language**: Go 1.25.5
|
||||||
|
**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates
|
||||||
|
**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
|
||||||
|
### Building
|
||||||
|
```bash
|
||||||
|
# Full production build (tailwind → templ → go generate → go build)
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Build and run
|
||||||
|
make run
|
||||||
|
|
||||||
|
# Clean build artifacts
|
||||||
|
make clean
|
||||||
|
```
|
||||||
|
|
||||||
|
### Development Mode
|
||||||
|
```bash
|
||||||
|
# Watch mode with hot reload (templ, air, tailwindcss in parallel)
|
||||||
|
make dev
|
||||||
|
|
||||||
|
# Development server runs on:
|
||||||
|
# - Proxy: http://localhost:3000 (use this)
|
||||||
|
# - App: http://localhost:3333 (internal)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
# Run tests for a specific package
|
||||||
|
go test ./pkg/oauth
|
||||||
|
|
||||||
|
# Run a single test function
|
||||||
|
go test ./pkg/oauth -run TestGenerateState_Success
|
||||||
|
|
||||||
|
# Run tests with verbose output
|
||||||
|
go test -v ./pkg/oauth
|
||||||
|
|
||||||
|
# Run tests with coverage
|
||||||
|
go test -cover ./...
|
||||||
|
go test -coverprofile=coverage.out ./...
|
||||||
|
go tool cover -html=coverage.out
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database
|
||||||
|
```bash
|
||||||
|
# Run migrations
|
||||||
|
make migrate
|
||||||
|
# OR
|
||||||
|
./bin/oslstats --migrate
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Management
|
||||||
|
```bash
|
||||||
|
# Generate .env template file
|
||||||
|
make genenv
|
||||||
|
# OR with custom output: make genenv OUT=.env.example
|
||||||
|
|
||||||
|
# Show environment variable documentation
|
||||||
|
make envdoc
|
||||||
|
|
||||||
|
# Show current environment values
|
||||||
|
make showenv
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Style Guidelines
|
||||||
|
|
||||||
|
### Import Organization
|
||||||
|
Organize imports in **3 groups** separated by blank lines:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
// 1. Standard library
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
// 2. External dependencies
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
|
// 3. Internal packages
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Naming Conventions
|
||||||
|
|
||||||
|
**Variables**:
|
||||||
|
- Local: `camelCase` (userAgentKey, httpServer, dbConn)
|
||||||
|
- Exported: `PascalCase` (Config, User, Token)
|
||||||
|
- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r`
|
||||||
|
|
||||||
|
**Functions**:
|
||||||
|
- Exported: `PascalCase` (GetConfig, NewStore, GenerateState)
|
||||||
|
- Private: `camelCase` (throwError, shouldShowDetails, loadModels)
|
||||||
|
- HTTP handlers: Return `http.Handler`, use dependency injection pattern
|
||||||
|
- Database functions: Use `bun.Tx` as parameter for transactions
|
||||||
|
|
||||||
|
**Types**:
|
||||||
|
- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession)
|
||||||
|
- Use `-er` suffix for interfaces (implied from usage)
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- Prefer single word: `config.go`, `oauth.go`, `errors.go`
|
||||||
|
- Use snake_case if needed: `discord_tokens.go`, `state_test.go`
|
||||||
|
- Test files: `*_test.go` alongside source files
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
**Always wrap errors** with context using `github.com/pkg/errors`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "operation_name")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Validate inputs at function start**:
|
||||||
|
```go
|
||||||
|
func DoSomething(cfg *Config, data string) error {
|
||||||
|
if cfg == nil {
|
||||||
|
return errors.New("cfg cannot be nil")
|
||||||
|
}
|
||||||
|
if data == "" {
|
||||||
|
return errors.New("data cannot be empty")
|
||||||
|
}
|
||||||
|
// ... rest of function
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**HTTP error helpers** (in handlers package):
|
||||||
|
- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors
|
||||||
|
- `throwBadRequest(s, w, r, msg, err)` - 400 errors
|
||||||
|
- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal)
|
||||||
|
- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level)
|
||||||
|
- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal)
|
||||||
|
- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level)
|
||||||
|
- `throwNotFound(s, w, r, path)` - 404 errors
|
||||||
|
|
||||||
|
### Common Patterns
|
||||||
|
|
||||||
|
**HTTP Handler Pattern**:
|
||||||
|
```go
|
||||||
|
func HandlerName(server *hws.Server, deps ...) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Handler logic here
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Database Operation Pattern**:
|
||||||
|
```go
|
||||||
|
func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) {
|
||||||
|
result := new(Result)
|
||||||
|
err := tx.NewSelect().
|
||||||
|
Model(result).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "sql: no rows in result set" {
|
||||||
|
return nil, nil // Return nil, nil for not found
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, "tx.Select")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Setup Function Pattern** (returns instance, cleanup func, error):
|
||||||
|
```go
|
||||||
|
func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) {
|
||||||
|
instance := newInstance()
|
||||||
|
|
||||||
|
err := configure(instance)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "configure")
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance, instance.Close, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Configuration Pattern** (using ezconf):
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
Field string // ENV FIELD_NAME: Description (required/default: value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigFromEnv() (any, error) {
|
||||||
|
cfg := &Config{
|
||||||
|
Field: env.String("FIELD_NAME", "default"),
|
||||||
|
}
|
||||||
|
// Validation here
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Formatting & Types
|
||||||
|
|
||||||
|
**Formatting**:
|
||||||
|
- Use `gofmt` (standard Go formatting)
|
||||||
|
- No tabs vs spaces debate - Go uses tabs
|
||||||
|
|
||||||
|
**Types**:
|
||||||
|
- Prefer explicit types over inference when it improves clarity
|
||||||
|
- Use struct tags for ORM and JSON marshaling:
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
bun.BaseModel `bun:"table:users,alias:u"`
|
||||||
|
ID int `bun:"id,pk,autoincrement"`
|
||||||
|
Username string `bun:"username,unique"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Comments**:
|
||||||
|
- Document exported functions and types
|
||||||
|
- Use inline comments for ENV var documentation in Config structs
|
||||||
|
- Explain security-critical code flows
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
**Test File Location**: Place `*_test.go` files alongside source files
|
||||||
|
|
||||||
|
**Test Naming**:
|
||||||
|
```go
|
||||||
|
func TestFunctionName_Scenario(t *testing.T)
|
||||||
|
func TestGenerateState_Success(t *testing.T)
|
||||||
|
func TestVerifyState_WrongUserAgentKey(t *testing.T)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Test Structure**:
|
||||||
|
- Use subtests with `t.Run()` for related scenarios
|
||||||
|
- Use table-driven tests for multiple similar cases
|
||||||
|
- Create helper functions for common setup (e.g., `testConfig()`)
|
||||||
|
- Test happy paths, error cases, edge cases, and security properties
|
||||||
|
|
||||||
|
**Test Categories** (from pkg/oauth/state_test.go example):
|
||||||
|
1. Happy path tests
|
||||||
|
2. Error handling (nil params, empty fields, malformed input)
|
||||||
|
3. Security tests (MITM, CSRF, replay attacks, tampering)
|
||||||
|
4. Edge cases (concurrency, constant-time comparison)
|
||||||
|
5. Integration tests (round-trip verification)
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
**Critical Practices**:
|
||||||
|
- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons
|
||||||
|
- Implement CSRF protection via state tokens
|
||||||
|
- Store sensitive cookies as HttpOnly
|
||||||
|
- Use separate logging levels for security violations (WARN)
|
||||||
|
- Validate all inputs at function boundaries
|
||||||
|
- Use parameterized queries (Bun ORM handles this)
|
||||||
|
- Never commit secrets (.env, keys/ are gitignored)
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
oslstats/
|
||||||
|
├── cmd/oslstats/ # Application entry point
|
||||||
|
│ ├── main.go # Entry point with flag parsing
|
||||||
|
│ ├── run.go # Server initialization & graceful shutdown
|
||||||
|
│ ├── httpserver.go # HTTP server setup
|
||||||
|
│ ├── routes.go # Route registration
|
||||||
|
│ ├── middleware.go # Middleware registration
|
||||||
|
│ ├── auth.go # Authentication setup
|
||||||
|
│ └── db.go # Database connection & migrations
|
||||||
|
├── internal/ # Private application code
|
||||||
|
│ ├── config/ # Configuration aggregation
|
||||||
|
│ ├── db/ # Database models & queries (Bun ORM)
|
||||||
|
│ ├── discord/ # Discord OAuth integration
|
||||||
|
│ ├── handlers/ # HTTP request handlers
|
||||||
|
│ ├── session/ # Session store (in-memory)
|
||||||
|
│ └── view/ # Templ templates
|
||||||
|
│ ├── component/ # Reusable UI components
|
||||||
|
│ ├── layout/ # Page layouts
|
||||||
|
│ └── page/ # Full pages
|
||||||
|
├── pkg/ # Reusable packages
|
||||||
|
│ ├── contexts/ # Context key definitions
|
||||||
|
│ ├── embedfs/ # Embedded static files
|
||||||
|
│ └── oauth/ # OAuth state management
|
||||||
|
├── bin/ # Compiled binaries (gitignored)
|
||||||
|
├── keys/ # Private keys (gitignored)
|
||||||
|
├── tmp/ # Air hot reload temp files (gitignored)
|
||||||
|
├── Makefile # Build automation
|
||||||
|
├── .air.toml # Hot reload configuration
|
||||||
|
└── go.mod # Go module definition
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Dependencies
|
||||||
|
|
||||||
|
- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt)
|
||||||
|
- **github.com/a-h/templ** - Type-safe HTML templating
|
||||||
|
- **github.com/uptrace/bun** - PostgreSQL ORM
|
||||||
|
- **github.com/bwmarrin/discordgo** - Discord API client
|
||||||
|
- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf)
|
||||||
|
- **github.com/joho/godotenv** - .env file loading
|
||||||
|
|
||||||
|
## Notes for AI Agents
|
||||||
|
|
||||||
|
1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css)
|
||||||
|
2. **Database operations** should use `bun.Tx` for transaction safety
|
||||||
|
3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes
|
||||||
|
4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/
|
||||||
|
5. **Error messages** should be descriptive and use errors.Wrap for context
|
||||||
|
6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples)
|
||||||
|
7. **Air proxy** runs on port 3000 during development; app runs on 3333
|
||||||
|
8. **Test coverage** is currently limited - prioritize testing security-critical code
|
||||||
|
9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples
|
||||||
|
10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern
|
||||||
|
11. When in plan mode, always use the interactive question tool if available
|
||||||
@@ -4,14 +4,15 @@ import (
|
|||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/session"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupHttpServer(
|
func setupHttpServer(
|
||||||
@@ -19,7 +20,8 @@ func setupHttpServer(
|
|||||||
config *config.Config,
|
config *config.Config,
|
||||||
logger *hlog.Logger,
|
logger *hlog.Logger,
|
||||||
bun *bun.DB,
|
bun *bun.DB,
|
||||||
store *session.Store,
|
store *store.Store,
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
) (server *hws.Server, err error) {
|
) (server *hws.Server, err error) {
|
||||||
if staticFS == nil {
|
if staticFS == nil {
|
||||||
return nil, errors.New("No filesystem provided")
|
return nil, errors.New("No filesystem provided")
|
||||||
@@ -55,7 +57,7 @@ func setupHttpServer(
|
|||||||
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addRoutes(httpServer, &fs, config, bun, auth, store)
|
err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "addRoutes")
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ import (
|
|||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/golib/hwsauth"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/session"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addRoutes(
|
func addRoutes(
|
||||||
@@ -20,7 +21,8 @@ func addRoutes(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
store *session.Store,
|
store *store.Store,
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
) error {
|
) error {
|
||||||
// Create the routes
|
// Create the routes
|
||||||
routes := []hws.Route{
|
routes := []hws.Route{
|
||||||
@@ -37,12 +39,12 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/login",
|
Path: "/login",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: auth.LogoutReq(handlers.Login(server, cfg)),
|
Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/auth/callback",
|
Path: "/auth/callback",
|
||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store)),
|
Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store, discordAPI)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/register",
|
Path: "/register",
|
||||||
|
|||||||
@@ -9,10 +9,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/session"
|
|
||||||
"git.haelnorr.com/h/oslstats/pkg/embedfs"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
|
"git.haelnorr.com/h/oslstats/pkg/embedfs"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initializes and runs the server
|
// Initializes and runs the server
|
||||||
@@ -44,10 +46,14 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error {
|
|||||||
|
|
||||||
// Setup session store
|
// Setup session store
|
||||||
logger.Debug().Msg("Setting up session store")
|
logger.Debug().Msg("Setting up session store")
|
||||||
store := session.NewStore()
|
store := store.NewStore()
|
||||||
|
|
||||||
|
// Setup Discord API client
|
||||||
|
logger.Debug().Msg("Setting up Discord API client")
|
||||||
|
discordAPI := discord.NewRateLimitedClient(logger)
|
||||||
|
|
||||||
logger.Debug().Msg("Setting up HTTP server")
|
logger.Debug().Msg("Setting up HTTP server")
|
||||||
httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store)
|
httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "setupHttpServer")
|
return errors.Wrap(err, "setupHttpServer")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) {
|
|||||||
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("code cannot be empty")
|
return nil, errors.New("code cannot be empty")
|
||||||
}
|
}
|
||||||
@@ -53,6 +53,9 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
|||||||
if trustedHost == "" {
|
if trustedHost == "" {
|
||||||
return nil, errors.New("trustedHost cannot be empty")
|
return nil, errors.New("trustedHost cannot be empty")
|
||||||
}
|
}
|
||||||
|
if apiClient == nil {
|
||||||
|
return nil, errors.New("apiClient cannot be nil")
|
||||||
|
}
|
||||||
// Prepare form data
|
// Prepare form data
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("grant_type", "authorization_code")
|
data.Set("grant_type", "authorization_code")
|
||||||
@@ -72,9 +75,8 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
|||||||
|
|
||||||
// Set basic auth (client_id and client_secret)
|
// Set basic auth (client_id and client_secret)
|
||||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||||
// Execute request
|
// Execute request with rate limit handling
|
||||||
client := &http.Client{}
|
resp, err := apiClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to execute request")
|
return nil, errors.Wrap(err, "failed to execute request")
|
||||||
}
|
}
|
||||||
@@ -96,13 +98,16 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
|||||||
return &tokenResp, nil
|
return &tokenResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) {
|
||||||
if token == nil {
|
if token == nil {
|
||||||
return nil, errors.New("token cannot be nil")
|
return nil, errors.New("token cannot be nil")
|
||||||
}
|
}
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("config cannot be nil")
|
return nil, errors.New("config cannot be nil")
|
||||||
}
|
}
|
||||||
|
if apiClient == nil {
|
||||||
|
return nil, errors.New("apiClient cannot be nil")
|
||||||
|
}
|
||||||
// Prepare form data
|
// Prepare form data
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("grant_type", "refresh_token")
|
data.Set("grant_type", "refresh_token")
|
||||||
@@ -121,9 +126,8 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
|||||||
|
|
||||||
// Set basic auth (client_id and client_secret)
|
// Set basic auth (client_id and client_secret)
|
||||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||||
// Execute request
|
// Execute request with rate limit handling
|
||||||
client := &http.Client{}
|
resp, err := apiClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to execute request")
|
return nil, errors.Wrap(err, "failed to execute request")
|
||||||
}
|
}
|
||||||
@@ -145,13 +149,16 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
|||||||
return &tokenResp, nil
|
return &tokenResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RevokeToken(cfg *Config, token *Token) error {
|
func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
|
||||||
if token == nil {
|
if token == nil {
|
||||||
return errors.New("token cannot be nil")
|
return errors.New("token cannot be nil")
|
||||||
}
|
}
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return errors.New("config cannot be nil")
|
return errors.New("config cannot be nil")
|
||||||
}
|
}
|
||||||
|
if apiClient == nil {
|
||||||
|
return errors.New("apiClient cannot be nil")
|
||||||
|
}
|
||||||
// Prepare form data
|
// Prepare form data
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("token", token.AccessToken)
|
data.Set("token", token.AccessToken)
|
||||||
@@ -170,9 +177,8 @@ func RevokeToken(cfg *Config, token *Token) error {
|
|||||||
|
|
||||||
// Set basic auth (client_id and client_secret)
|
// Set basic auth (client_id and client_secret)
|
||||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||||
// Execute request
|
// Execute request with rate limit handling
|
||||||
client := &http.Client{}
|
resp, err := apiClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to execute request")
|
return errors.Wrap(err, "failed to execute request")
|
||||||
}
|
}
|
||||||
|
|||||||
235
internal/discord/ratelimit.go
Normal file
235
internal/discord/ratelimit.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package discord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitState tracks rate limit information for a specific bucket
|
||||||
|
type RateLimitState struct {
|
||||||
|
Remaining int // Requests remaining in current window
|
||||||
|
Limit int // Total requests allowed in window
|
||||||
|
Reset time.Time // When the rate limit resets
|
||||||
|
Bucket string // Discord's bucket identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||||
|
type APIClient struct {
|
||||||
|
client *http.Client
|
||||||
|
logger *hlog.Logger
|
||||||
|
mu sync.RWMutex
|
||||||
|
buckets map[string]*RateLimitState
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimitedClient creates a new Discord API client with rate limit handling
|
||||||
|
func NewRateLimitedClient(logger *hlog.Logger) *APIClient {
|
||||||
|
return &APIClient{
|
||||||
|
client: &http.Client{Timeout: 30 * time.Second},
|
||||||
|
logger: logger,
|
||||||
|
buckets: make(map[string]*RateLimitState),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do executes an HTTP request with automatic rate limit handling
|
||||||
|
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
|
||||||
|
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, errors.New("request cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Check if we need to wait before making request
|
||||||
|
bucket := c.getBucketFromRequest(req)
|
||||||
|
if err := c.waitIfNeeded(bucket); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Execute request
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
// Check if it's a network timeout
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return nil, errors.Wrap(err, "request timed out")
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, "http request failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Update rate limit state from response headers
|
||||||
|
c.updateRateLimit(resp.Header)
|
||||||
|
|
||||||
|
// Step 4: Handle 429 (rate limited)
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
resp.Body.Close() // Close original response
|
||||||
|
|
||||||
|
retryAfter := c.parseRetryAfter(resp.Header)
|
||||||
|
|
||||||
|
// No Retry-After header, can't retry safely
|
||||||
|
if retryAfter == 0 {
|
||||||
|
c.logger.Warn().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Str("method", req.Method).
|
||||||
|
Str("path", req.URL.Path).
|
||||||
|
Msg("Rate limited but no Retry-After header provided")
|
||||||
|
return nil, errors.New("discord API rate limited but no Retry-After header provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry-After exceeds 30 second cap
|
||||||
|
if retryAfter > 30*time.Second {
|
||||||
|
c.logger.Warn().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Str("method", req.Method).
|
||||||
|
Str("path", req.URL.Path).
|
||||||
|
Dur("retry_after", retryAfter).
|
||||||
|
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
|
||||||
|
return nil, errors.Errorf(
|
||||||
|
"discord API rate limited (retry after %s exceeds 30s cap)",
|
||||||
|
retryAfter,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait and retry
|
||||||
|
c.logger.Warn().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Str("method", req.Method).
|
||||||
|
Str("path", req.URL.Path).
|
||||||
|
Dur("retry_after", retryAfter).
|
||||||
|
Msg("Rate limited, waiting before retry")
|
||||||
|
|
||||||
|
time.Sleep(retryAfter)
|
||||||
|
|
||||||
|
// Retry the request
|
||||||
|
resp, err = c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return nil, errors.Wrap(err, "retry request timed out")
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, "retry request failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update rate limit again after retry
|
||||||
|
c.updateRateLimit(resp.Header)
|
||||||
|
|
||||||
|
// If STILL rate limited after retry, return error
|
||||||
|
if resp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
resp.Body.Close()
|
||||||
|
c.logger.Error().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Str("method", req.Method).
|
||||||
|
Str("path", req.URL.Path).
|
||||||
|
Msg("Still rate limited after retry, Discord may be experiencing issues")
|
||||||
|
return nil, errors.Errorf(
|
||||||
|
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
|
||||||
|
retryAfter,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBucketFromRequest extracts or generates bucket ID from request
|
||||||
|
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
|
||||||
|
func (c *APIClient) getBucketFromRequest(req *http.Request) string {
|
||||||
|
return req.Method + ":" + req.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
|
||||||
|
func (c *APIClient) waitIfNeeded(bucket string) error {
|
||||||
|
c.mu.RLock()
|
||||||
|
state, exists := c.buckets[bucket]
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil // No state yet, proceed
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// If we have no remaining requests and reset hasn't occurred, wait
|
||||||
|
if state.Remaining == 0 && now.Before(state.Reset) {
|
||||||
|
waitDuration := time.Until(state.Reset)
|
||||||
|
// Add small buffer (100ms) to ensure reset has occurred
|
||||||
|
waitDuration += 100 * time.Millisecond
|
||||||
|
|
||||||
|
if waitDuration > 0 {
|
||||||
|
c.logger.Debug().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Dur("wait_duration", waitDuration).
|
||||||
|
Msg("Proactively waiting for rate limit reset")
|
||||||
|
time.Sleep(waitDuration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRateLimit parses response headers and updates bucket state
|
||||||
|
func (c *APIClient) updateRateLimit(headers http.Header) {
|
||||||
|
bucket := headers.Get("X-RateLimit-Bucket")
|
||||||
|
if bucket == "" {
|
||||||
|
return // No bucket info, can't track
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse headers
|
||||||
|
limit := c.parseInt(headers.Get("X-RateLimit-Limit"))
|
||||||
|
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining"))
|
||||||
|
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After"))
|
||||||
|
|
||||||
|
state := &RateLimitState{
|
||||||
|
Bucket: bucket,
|
||||||
|
Limit: limit,
|
||||||
|
Remaining: remaining,
|
||||||
|
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.buckets[bucket] = state
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// Log rate limit state for debugging
|
||||||
|
c.logger.Debug().
|
||||||
|
Str("bucket", bucket).
|
||||||
|
Int("remaining", remaining).
|
||||||
|
Int("limit", limit).
|
||||||
|
Dur("reset_in", time.Until(state.Reset)).
|
||||||
|
Msg("Rate limit state updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRetryAfter extracts retry delay from Retry-After header
|
||||||
|
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
|
||||||
|
retryAfter := headers.Get("Retry-After")
|
||||||
|
if retryAfter == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discord returns seconds as float
|
||||||
|
seconds := c.parseFloat(retryAfter)
|
||||||
|
if seconds <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Duration(seconds * float64(time.Second))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInt parses an integer from a header value, returns 0 on error
|
||||||
|
func (c *APIClient) parseInt(s string) int {
|
||||||
|
if s == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
i, _ := strconv.Atoi(s)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFloat parses a float from a header value, returns 0 on error
|
||||||
|
func (c *APIClient) parseFloat(s string) float64 {
|
||||||
|
if s == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
f, _ := strconv.ParseFloat(s, 64)
|
||||||
|
return f
|
||||||
|
}
|
||||||
459
internal/discord/ratelimit_test.go
Normal file
459
internal/discord/ratelimit_test.go
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
package discord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testLogger creates a test logger for testing
|
||||||
|
func testLogger(t *testing.T) *hlog.Logger {
|
||||||
|
level, _ := hlog.LogLevel("debug")
|
||||||
|
cfg := &hlog.Config{
|
||||||
|
LogLevel: level,
|
||||||
|
LogOutput: "console",
|
||||||
|
}
|
||||||
|
logger, err := hlog.NewLogger(cfg, io.Discard)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create test logger: %v", err)
|
||||||
|
}
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRateLimitedClient(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("NewRateLimitedClient returned nil")
|
||||||
|
}
|
||||||
|
if client.client == nil {
|
||||||
|
t.Error("client.client is nil")
|
||||||
|
}
|
||||||
|
if client.logger == nil {
|
||||||
|
t.Error("client.logger is nil")
|
||||||
|
}
|
||||||
|
if client.buckets == nil {
|
||||||
|
t.Error("client.buckets map is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_Success(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
// Mock server that returns success with rate limit headers
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||||
|
w.Header().Set("X-RateLimit-Limit", "5")
|
||||||
|
w.Header().Set("X-RateLimit-Remaining", "3")
|
||||||
|
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that rate limit state was updated
|
||||||
|
client.mu.RLock()
|
||||||
|
state, exists := client.buckets["test-bucket"]
|
||||||
|
client.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("rate limit state not stored")
|
||||||
|
}
|
||||||
|
if state.Remaining != 3 {
|
||||||
|
t.Errorf("expected remaining=3, got %d", state.Remaining)
|
||||||
|
}
|
||||||
|
if state.Limit != 5 {
|
||||||
|
t.Errorf("expected limit=5, got %d", state.Limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
attemptCount := 0
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount++
|
||||||
|
if attemptCount == 1 {
|
||||||
|
// First request: return 429
|
||||||
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||||
|
w.Header().Set("Retry-After", "0.1") // 100ms
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "rate_limited",
|
||||||
|
"error_description": "You are being rate limited",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Second request: success
|
||||||
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||||
|
w.Header().Set("X-RateLimit-Limit", "5")
|
||||||
|
w.Header().Set("X-RateLimit-Remaining", "4")
|
||||||
|
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() returned error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if attemptCount != 2 {
|
||||||
|
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have waited approximately 100ms
|
||||||
|
if elapsed < 100*time.Millisecond {
|
||||||
|
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
attemptCount := 0
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount++
|
||||||
|
// Always return 429
|
||||||
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||||
|
w.Header().Set("Retry-After", "0.05") // 50ms
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "rate_limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
t.Fatal("Do() should have returned error after failed retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "still rate limited after retry") {
|
||||||
|
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if attemptCount != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attemptCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "rate_limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
t.Fatal("Do() should have returned error for Retry-After > 30s")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "exceeds 30s cap") {
|
||||||
|
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT have waited (immediate error)
|
||||||
|
if elapsed > 1*time.Second {
|
||||||
|
t.Errorf("should return immediately, but took %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Return 429 but NO Retry-After header
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "rate_limited",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
t.Fatal("Do() should have returned error when no Retry-After header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "no Retry-After header") {
|
||||||
|
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("X-RateLimit-Bucket", "global")
|
||||||
|
headers.Set("X-RateLimit-Limit", "10")
|
||||||
|
headers.Set("X-RateLimit-Remaining", "7")
|
||||||
|
headers.Set("X-RateLimit-Reset-After", "5.5")
|
||||||
|
|
||||||
|
client.updateRateLimit(headers)
|
||||||
|
|
||||||
|
client.mu.RLock()
|
||||||
|
state, exists := client.buckets["global"]
|
||||||
|
client.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("bucket state not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Bucket != "global" {
|
||||||
|
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
|
||||||
|
}
|
||||||
|
if state.Limit != 10 {
|
||||||
|
t.Errorf("expected limit=10, got %d", state.Limit)
|
||||||
|
}
|
||||||
|
if state.Remaining != 7 {
|
||||||
|
t.Errorf("expected remaining=7, got %d", state.Remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check reset time is approximately 5.5 seconds from now
|
||||||
|
resetIn := time.Until(state.Reset)
|
||||||
|
if resetIn < 5*time.Second || resetIn > 6*time.Second {
|
||||||
|
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
// Set up a bucket with 0 remaining and reset in future
|
||||||
|
bucket := "test-bucket"
|
||||||
|
client.mu.Lock()
|
||||||
|
client.buckets[bucket] = &RateLimitState{
|
||||||
|
Bucket: bucket,
|
||||||
|
Limit: 5,
|
||||||
|
Remaining: 0,
|
||||||
|
Reset: time.Now().Add(200 * time.Millisecond),
|
||||||
|
}
|
||||||
|
client.mu.Unlock()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err := client.waitIfNeeded(bucket)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have waited ~200ms + 100ms buffer
|
||||||
|
if elapsed < 200*time.Millisecond {
|
||||||
|
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
|
||||||
|
}
|
||||||
|
if elapsed > 500*time.Millisecond {
|
||||||
|
t.Errorf("waited too long: %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
// Set up a bucket with remaining requests
|
||||||
|
bucket := "test-bucket"
|
||||||
|
client.mu.Lock()
|
||||||
|
client.buckets[bucket] = &RateLimitState{
|
||||||
|
Bucket: bucket,
|
||||||
|
Limit: 5,
|
||||||
|
Remaining: 3,
|
||||||
|
Reset: time.Now().Add(5 * time.Second),
|
||||||
|
}
|
||||||
|
client.mu.Unlock()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err := client.waitIfNeeded(bucket)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT wait (has remaining requests)
|
||||||
|
if elapsed > 10*time.Millisecond {
|
||||||
|
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
requestCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
requestCount++
|
||||||
|
count := requestCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
|
||||||
|
w.Header().Set("X-RateLimit-Limit", "10")
|
||||||
|
w.Header().Set("X-RateLimit-Remaining", "5")
|
||||||
|
w.Header().Set("X-RateLimit-Reset-After", "1.0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Launch 10 concurrent requests
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan error, 10)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
// Check for any errors
|
||||||
|
for err := range errors {
|
||||||
|
t.Errorf("concurrent request failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All requests should have completed
|
||||||
|
mu.Lock()
|
||||||
|
finalCount := requestCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if finalCount != 10 {
|
||||||
|
t.Errorf("expected 10 requests, got %d", finalCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rate limit state is consistent (no data races)
|
||||||
|
client.mu.RLock()
|
||||||
|
state, exists := client.buckets["concurrent-bucket"]
|
||||||
|
client.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("bucket state not found after concurrent requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// State should exist and be valid
|
||||||
|
if state.Limit != 10 {
|
||||||
|
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIClient_ParseRetryAfter(t *testing.T) {
|
||||||
|
logger := testLogger(t)
|
||||||
|
client := NewRateLimitedClient(logger)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{"integer seconds", "2", 2 * time.Second},
|
||||||
|
{"float seconds", "2.5", 2500 * time.Millisecond},
|
||||||
|
{"zero", "0", 0},
|
||||||
|
{"empty", "", 0},
|
||||||
|
{"invalid", "abc", 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Retry-After", tt.header)
|
||||||
|
|
||||||
|
result := client.parseRetryAfter(headers)
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,18 +6,49 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/session"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *session.Store) http.Handler {
|
func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *store.Store, discordAPI *discord.APIClient) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Track callback redirect attempts
|
||||||
|
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
||||||
|
|
||||||
|
if exceeded {
|
||||||
|
// Build detailed error for logging
|
||||||
|
err := errors.Errorf(
|
||||||
|
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||||
|
attempts,
|
||||||
|
track.IP,
|
||||||
|
track.UserAgent,
|
||||||
|
track.Path,
|
||||||
|
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Clear the tracking entry
|
||||||
|
store.ClearRedirectTrack(r, "/callback")
|
||||||
|
|
||||||
|
// Show error page
|
||||||
|
throwError(
|
||||||
|
server,
|
||||||
|
w,
|
||||||
|
r,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
|
||||||
|
err,
|
||||||
|
"warn",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
if state == "" && code == "" {
|
if state == "" && code == "" {
|
||||||
@@ -41,6 +72,10 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// SUCCESS POINT: State verified successfully
|
||||||
|
// Clear redirect tracking - OAuth callback completed successfully
|
||||||
|
store.ClearRedirectTrack(r, "/callback")
|
||||||
|
|
||||||
switch data {
|
switch data {
|
||||||
case "login":
|
case "login":
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
@@ -51,7 +86,7 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
redirect, err := login(ctx, tx, cfg, w, r, code, store)
|
redirect, err := login(ctx, tx, cfg, w, r, code, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
||||||
return
|
return
|
||||||
@@ -122,9 +157,10 @@ func login(
|
|||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
code string,
|
code string,
|
||||||
store *session.Store,
|
store *store.Store,
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
) (func(), error) {
|
) (func(), error) {
|
||||||
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost)
|
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "discord.AuthorizeWithCode")
|
return nil, errors.Wrap(err, "discord.AuthorizeWithCode")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,47 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Login(server *hws.Server, cfg *config.Config) http.Handler {
|
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Track login redirect attempts
|
||||||
|
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||||
|
|
||||||
|
if exceeded {
|
||||||
|
// Build detailed error for logging
|
||||||
|
err := errors.Errorf(
|
||||||
|
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||||
|
attempts,
|
||||||
|
track.IP,
|
||||||
|
track.UserAgent,
|
||||||
|
track.Path,
|
||||||
|
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Clear the tracking entry
|
||||||
|
st.ClearRedirectTrack(r, "/login")
|
||||||
|
|
||||||
|
// Show error page
|
||||||
|
throwError(
|
||||||
|
server,
|
||||||
|
w,
|
||||||
|
r,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
|
||||||
|
err,
|
||||||
|
"warn",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
|
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
|
||||||
@@ -24,6 +57,11 @@ func Login(server *hws.Server, cfg *config.Config) http.Handler {
|
|||||||
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SUCCESS POINT: OAuth link generated, redirecting to Discord
|
||||||
|
// Clear redirect tracking - user successfully initiated OAuth
|
||||||
|
st.ClearRedirectTrack(r, "/login")
|
||||||
|
|
||||||
http.Redirect(w, r, link, http.StatusSeeOther)
|
http.Redirect(w, r, link, http.StatusSeeOther)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,31 +6,63 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/session"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Register(
|
func Register(
|
||||||
server *hws.Server,
|
server *hws.Server,
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
store *session.Store,
|
store *store.Store,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
|
||||||
|
|
||||||
|
if exceeded {
|
||||||
|
err := errors.Errorf(
|
||||||
|
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
|
||||||
|
attempts,
|
||||||
|
track.IP,
|
||||||
|
track.UserAgent,
|
||||||
|
track.Path,
|
||||||
|
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
cfg.HWSAuth.SSL,
|
||||||
|
)
|
||||||
|
|
||||||
|
store.ClearRedirectTrack(r, "/register")
|
||||||
|
|
||||||
|
throwError(
|
||||||
|
server,
|
||||||
|
w,
|
||||||
|
r,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
|
||||||
|
err,
|
||||||
|
"warn",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
sessionCookie, err := r.Cookie("registration_session")
|
sessionCookie, err := r.Cookie("registration_session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
details, ok := store.GetRegistrationSession(sessionCookie.Value)
|
details, ok := store.GetRegistrationSession(sessionCookie.Value)
|
||||||
|
ok = false
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
store.ClearRedirectTrack(r, "/register")
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
@@ -65,12 +97,11 @@ func IsUsernameUnique(
|
|||||||
server *hws.Server,
|
server *hws.Server,
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
store *session.Store,
|
store *store.Store,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
username := r.FormValue("username")
|
username := r.FormValue("username")
|
||||||
// check if its unique
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store struct {
|
|
||||||
sessions sync.Map
|
|
||||||
cleanup *time.Ticker
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStore() *Store {
|
|
||||||
s := &Store{
|
|
||||||
cleanup: time.NewTicker(1 * time.Minute),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Background cleanup of expired sessions
|
|
||||||
go func() {
|
|
||||||
for range s.cleanup.C {
|
|
||||||
s.cleanupExpired()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) Delete(id string) {
|
|
||||||
s.sessions.Delete(id)
|
|
||||||
}
|
|
||||||
func (s *Store) cleanupExpired() {
|
|
||||||
s.sessions.Range(func(key, value any) bool {
|
|
||||||
session := value.(*RegistrationSession)
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
s.sessions.Delete(key)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func generateID() string {
|
|
||||||
b := make([]byte, 32)
|
|
||||||
rand.Read(b)
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b)
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package session
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
95
internal/store/redirects.go
Normal file
95
internal/store/redirects.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check X-Forwarded-For header (comma-separated list, first is client)
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
// Take the first IP in the list
|
||||||
|
ips := strings.Split(xff, ",")
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return strings.TrimSpace(ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||||
|
// Use net.SplitHostPort to properly handle both IPv4 and IPv6
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
// If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr)
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackRedirect increments the redirect counter for this IP+UA+Path combination
|
||||||
|
// Returns the current attempt count, whether limit was exceeded, and the track details
|
||||||
|
func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) {
|
||||||
|
if r == nil {
|
||||||
|
return 0, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := getClientIP(r)
|
||||||
|
userAgent := r.UserAgent()
|
||||||
|
key := redirectKey(ip, userAgent, path)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expiresAt := now.Add(5 * time.Minute)
|
||||||
|
|
||||||
|
// Try to load existing track
|
||||||
|
val, exists := s.redirectTracks.Load(key)
|
||||||
|
if exists {
|
||||||
|
track = val.(*RedirectTrack)
|
||||||
|
|
||||||
|
// Check if expired
|
||||||
|
if now.After(track.ExpiresAt) {
|
||||||
|
// Expired, start fresh
|
||||||
|
track = &RedirectTrack{
|
||||||
|
IP: ip,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
Path: path,
|
||||||
|
Attempts: 1,
|
||||||
|
FirstSeen: now,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
s.redirectTracks.Store(key, track)
|
||||||
|
return 1, false, track
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment existing
|
||||||
|
track.Attempts++
|
||||||
|
track.ExpiresAt = expiresAt // Extend expiry
|
||||||
|
exceeded = track.Attempts >= maxAttempts
|
||||||
|
return track.Attempts, exceeded, track
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new track
|
||||||
|
track = &RedirectTrack{
|
||||||
|
IP: ip,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
Path: path,
|
||||||
|
Attempts: 1,
|
||||||
|
FirstSeen: now,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
s.redirectTracks.Store(key, track)
|
||||||
|
return 1, false, track
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearRedirectTrack removes a redirect tracking entry (called after successful completion)
|
||||||
|
func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := getClientIP(r)
|
||||||
|
userAgent := r.UserAgent()
|
||||||
|
key := redirectKey(ip, userAgent, path)
|
||||||
|
s.redirectTracks.Delete(key)
|
||||||
|
}
|
||||||
80
internal/store/store.go
Normal file
80
internal/store/store.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedirectTrack represents a single redirect attempt tracking entry
|
||||||
|
type RedirectTrack struct {
|
||||||
|
IP string // Client IP (X-Forwarded-For aware)
|
||||||
|
UserAgent string // Full User-Agent string for debugging
|
||||||
|
Path string // Request path (without query params)
|
||||||
|
Attempts int // Number of redirect attempts
|
||||||
|
FirstSeen time.Time // When first redirect was tracked
|
||||||
|
ExpiresAt time.Time // When to clean up this entry
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
sessions sync.Map // key: string, value: *RegistrationSession
|
||||||
|
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
||||||
|
cleanup *time.Ticker
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore() *Store {
|
||||||
|
s := &Store{
|
||||||
|
cleanup: time.NewTicker(1 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Background cleanup of expired sessions
|
||||||
|
go func() {
|
||||||
|
for range s.cleanup.C {
|
||||||
|
s.cleanupExpired()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Delete(id string) {
|
||||||
|
s.sessions.Delete(id)
|
||||||
|
}
|
||||||
|
func (s *Store) cleanupExpired() {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Clean up expired registration sessions
|
||||||
|
s.sessions.Range(func(key, value any) bool {
|
||||||
|
session := value.(*RegistrationSession)
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
|
s.sessions.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Clean up expired redirect tracks
|
||||||
|
s.redirectTracks.Range(func(key, value any) bool {
|
||||||
|
track := value.(*RedirectTrack)
|
||||||
|
if now.After(track.ExpiresAt) {
|
||||||
|
s.redirectTracks.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func generateID() string {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
rand.Read(b)
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectKey generates a unique key for tracking redirects
|
||||||
|
// Uses IP + first 100 chars of UA + path as key (not hashed for debugging)
|
||||||
|
func redirectKey(ip, userAgent, path string) string {
|
||||||
|
ua := userAgent
|
||||||
|
if len(ua) > 100 {
|
||||||
|
ua = ua[:100]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s:%s", ip, ua, path)
|
||||||
|
}
|
||||||
@@ -232,18 +232,12 @@
|
|||||||
.top-0 {
|
.top-0 {
|
||||||
top: calc(var(--spacing) * 0);
|
top: calc(var(--spacing) * 0);
|
||||||
}
|
}
|
||||||
.top-2 {
|
|
||||||
top: calc(var(--spacing) * 2);
|
|
||||||
}
|
|
||||||
.top-4 {
|
.top-4 {
|
||||||
top: calc(var(--spacing) * 4);
|
top: calc(var(--spacing) * 4);
|
||||||
}
|
}
|
||||||
.right-0 {
|
.right-0 {
|
||||||
right: calc(var(--spacing) * 0);
|
right: calc(var(--spacing) * 0);
|
||||||
}
|
}
|
||||||
.right-2 {
|
|
||||||
right: calc(var(--spacing) * 2);
|
|
||||||
}
|
|
||||||
.bottom-0 {
|
.bottom-0 {
|
||||||
bottom: calc(var(--spacing) * 0);
|
bottom: calc(var(--spacing) * 0);
|
||||||
}
|
}
|
||||||
@@ -253,18 +247,9 @@
|
|||||||
.z-10 {
|
.z-10 {
|
||||||
z-index: 10;
|
z-index: 10;
|
||||||
}
|
}
|
||||||
.float-left {
|
|
||||||
float: left;
|
|
||||||
}
|
|
||||||
.m-0 {
|
|
||||||
margin: calc(var(--spacing) * 0);
|
|
||||||
}
|
|
||||||
.mx-auto {
|
.mx-auto {
|
||||||
margin-inline: auto;
|
margin-inline: auto;
|
||||||
}
|
}
|
||||||
.mt-1 {
|
|
||||||
margin-top: calc(var(--spacing) * 1);
|
|
||||||
}
|
|
||||||
.mt-1\.5 {
|
.mt-1\.5 {
|
||||||
margin-top: calc(var(--spacing) * 1.5);
|
margin-top: calc(var(--spacing) * 1.5);
|
||||||
}
|
}
|
||||||
@@ -298,21 +283,12 @@
|
|||||||
.mt-24 {
|
.mt-24 {
|
||||||
margin-top: calc(var(--spacing) * 24);
|
margin-top: calc(var(--spacing) * 24);
|
||||||
}
|
}
|
||||||
.mr-0 {
|
|
||||||
margin-right: calc(var(--spacing) * 0);
|
|
||||||
}
|
|
||||||
.mr-2 {
|
|
||||||
margin-right: calc(var(--spacing) * 2);
|
|
||||||
}
|
|
||||||
.mr-5 {
|
.mr-5 {
|
||||||
margin-right: calc(var(--spacing) * 5);
|
margin-right: calc(var(--spacing) * 5);
|
||||||
}
|
}
|
||||||
.mb-auto {
|
.mb-auto {
|
||||||
margin-bottom: auto;
|
margin-bottom: auto;
|
||||||
}
|
}
|
||||||
.ml-0 {
|
|
||||||
margin-left: calc(var(--spacing) * 0);
|
|
||||||
}
|
|
||||||
.ml-2 {
|
.ml-2 {
|
||||||
margin-left: calc(var(--spacing) * 2);
|
margin-left: calc(var(--spacing) * 2);
|
||||||
}
|
}
|
||||||
@@ -322,9 +298,6 @@
|
|||||||
.block {
|
.block {
|
||||||
display: block;
|
display: block;
|
||||||
}
|
}
|
||||||
.contents {
|
|
||||||
display: contents;
|
|
||||||
}
|
|
||||||
.flex {
|
.flex {
|
||||||
display: flex;
|
display: flex;
|
||||||
}
|
}
|
||||||
@@ -343,9 +316,6 @@
|
|||||||
.inline-flex {
|
.inline-flex {
|
||||||
display: inline-flex;
|
display: inline-flex;
|
||||||
}
|
}
|
||||||
.table {
|
|
||||||
display: table;
|
|
||||||
}
|
|
||||||
.size-5 {
|
.size-5 {
|
||||||
width: calc(var(--spacing) * 5);
|
width: calc(var(--spacing) * 5);
|
||||||
height: calc(var(--spacing) * 5);
|
height: calc(var(--spacing) * 5);
|
||||||
@@ -396,9 +366,6 @@
|
|||||||
.flex-1 {
|
.flex-1 {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
}
|
}
|
||||||
.border-collapse {
|
|
||||||
border-collapse: collapse;
|
|
||||||
}
|
|
||||||
.translate-x-0 {
|
.translate-x-0 {
|
||||||
--tw-translate-x: calc(var(--spacing) * 0);
|
--tw-translate-x: calc(var(--spacing) * 0);
|
||||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||||
@@ -413,9 +380,6 @@
|
|||||||
.cursor-pointer {
|
.cursor-pointer {
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
}
|
}
|
||||||
.resize {
|
|
||||||
resize: both;
|
|
||||||
}
|
|
||||||
.flex-col {
|
.flex-col {
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
}
|
}
|
||||||
@@ -670,9 +634,6 @@
|
|||||||
.text-text {
|
.text-text {
|
||||||
color: var(--text);
|
color: var(--text);
|
||||||
}
|
}
|
||||||
.underline {
|
|
||||||
text-decoration-line: underline;
|
|
||||||
}
|
|
||||||
.opacity-0 {
|
.opacity-0 {
|
||||||
opacity: 0%;
|
opacity: 0%;
|
||||||
}
|
}
|
||||||
@@ -687,10 +648,6 @@
|
|||||||
--tw-shadow: 0 1px 3px 0 var(--tw-shadow-color, rgb(0 0 0 / 0.1)), 0 1px 2px -1px var(--tw-shadow-color, rgb(0 0 0 / 0.1));
|
--tw-shadow: 0 1px 3px 0 var(--tw-shadow-color, rgb(0 0 0 / 0.1)), 0 1px 2px -1px var(--tw-shadow-color, rgb(0 0 0 / 0.1));
|
||||||
box-shadow: var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow);
|
box-shadow: var(--tw-inset-shadow), var(--tw-inset-ring-shadow), var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow);
|
||||||
}
|
}
|
||||||
.outline {
|
|
||||||
outline-style: var(--tw-outline-style);
|
|
||||||
outline-width: 1px;
|
|
||||||
}
|
|
||||||
.transition {
|
.transition {
|
||||||
transition-property: color, background-color, border-color, outline-color, text-decoration-color, fill, stroke, --tw-gradient-from, --tw-gradient-via, --tw-gradient-to, opacity, box-shadow, transform, translate, scale, rotate, filter, -webkit-backdrop-filter, backdrop-filter, display, content-visibility, overlay, pointer-events;
|
transition-property: color, background-color, border-color, outline-color, text-decoration-color, fill, stroke, --tw-gradient-from, --tw-gradient-via, --tw-gradient-to, opacity, box-shadow, transform, translate, scale, rotate, filter, -webkit-backdrop-filter, backdrop-filter, display, content-visibility, overlay, pointer-events;
|
||||||
transition-timing-function: var(--tw-ease, var(--default-transition-timing-function));
|
transition-timing-function: var(--tw-ease, var(--default-transition-timing-function));
|
||||||
@@ -1165,11 +1122,6 @@
|
|||||||
inherits: false;
|
inherits: false;
|
||||||
initial-value: 0 0 #0000;
|
initial-value: 0 0 #0000;
|
||||||
}
|
}
|
||||||
@property --tw-outline-style {
|
|
||||||
syntax: "*";
|
|
||||||
inherits: false;
|
|
||||||
initial-value: solid;
|
|
||||||
}
|
|
||||||
@property --tw-duration {
|
@property --tw-duration {
|
||||||
syntax: "*";
|
syntax: "*";
|
||||||
inherits: false;
|
inherits: false;
|
||||||
@@ -1205,7 +1157,6 @@
|
|||||||
--tw-ring-offset-width: 0px;
|
--tw-ring-offset-width: 0px;
|
||||||
--tw-ring-offset-color: #fff;
|
--tw-ring-offset-color: #fff;
|
||||||
--tw-ring-offset-shadow: 0 0 #0000;
|
--tw-ring-offset-shadow: 0 0 #0000;
|
||||||
--tw-outline-style: solid;
|
|
||||||
--tw-duration: initial;
|
--tw-duration: initial;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user