Merge branch 'discord-oauth'

This commit is contained in:
2026-01-24 15:37:03 +11:00
52 changed files with 3927 additions and 424 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
*.db*
.logs/
server.log
keys/
bin/
tmp/
static/css/output.css

327
AGENTS.md Normal file
View File

@@ -0,0 +1,327 @@
# AGENTS.md - Developer Guide for oslstats
This document provides guidelines for AI coding agents and developers working on the oslstats codebase.
## Project Overview
**Module**: `git.haelnorr.com/h/oslstats`
**Language**: Go 1.25.5
**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates
**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries
## Build, Test, and Development Commands
### Building
```bash
# Full production build (tailwind → templ → go generate → go build)
make build
# Build and run
make run
# Clean build artifacts
make clean
```
### Development Mode
```bash
# Watch mode with hot reload (templ, air, tailwindcss in parallel)
make dev
# Development server runs on:
# - Proxy: http://localhost:3000 (use this)
# - App: http://localhost:3333 (internal)
```
### Testing
```bash
# Run all tests
go test ./...
# Run tests for a specific package
go test ./pkg/oauth
# Run a single test function
go test ./pkg/oauth -run TestGenerateState_Success
# Run tests with verbose output
go test -v ./pkg/oauth
# Run tests with coverage
go test -cover ./...
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
### Database
```bash
# Run migrations
make migrate
# OR
./bin/oslstats --migrate
```
### Configuration Management
```bash
# Generate .env template file
make genenv
# OR with custom output: make genenv OUT=.env.example
# Show environment variable documentation
make envdoc
# Show current environment values
make showenv
```
## Code Style Guidelines
### Import Organization
Organize imports in **3 groups** separated by blank lines:
```go
import (
// 1. Standard library
"context"
"net/http"
"fmt"
// 2. External dependencies
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
// 3. Internal packages
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
```
### Naming Conventions
**Variables**:
- Local: `camelCase` (userAgentKey, httpServer, dbConn)
- Exported: `PascalCase` (Config, User, Token)
- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r`
**Functions**:
- Exported: `PascalCase` (GetConfig, NewStore, GenerateState)
- Private: `camelCase` (throwError, shouldShowDetails, loadModels)
- HTTP handlers: Return `http.Handler`, use dependency injection pattern
- Database functions: Use `bun.Tx` as parameter for transactions
**Types**:
- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession)
- Use `-er` suffix for interfaces (implied from usage)
**Files**:
- Prefer single word: `config.go`, `oauth.go`, `errors.go`
- Use snake_case if needed: `discord_tokens.go`, `state_test.go`
- Test files: `*_test.go` alongside source files
### Error Handling
**Always wrap errors** with context using `github.com/pkg/errors`:
```go
if err != nil {
return errors.Wrap(err, "operation_name")
}
```
**Validate inputs at function start**:
```go
func DoSomething(cfg *Config, data string) error {
if cfg == nil {
return errors.New("cfg cannot be nil")
}
if data == "" {
return errors.New("data cannot be empty")
}
// ... rest of function
}
```
**HTTP error helpers** (in handlers package):
- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors
- `throwBadRequest(s, w, r, msg, err)` - 400 errors
- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal)
- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level)
- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal)
- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level)
- `throwNotFound(s, w, r, path)` - 404 errors
### Common Patterns
**HTTP Handler Pattern**:
```go
func HandlerName(server *hws.Server, deps ...) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Handler logic here
},
)
}
```
**Database Operation Pattern**:
```go
func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) {
result := new(Result)
err := tx.NewSelect().
Model(result).
Where("id = ?", id).
Scan(ctx)
if err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, nil // Return nil, nil for not found
}
return nil, errors.Wrap(err, "tx.Select")
}
return result, nil
}
```
**Setup Function Pattern** (returns instance, cleanup func, error):
```go
func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) {
instance := newInstance()
err := configure(instance)
if err != nil {
return nil, nil, errors.Wrap(err, "configure")
}
return instance, instance.Close, nil
}
```
**Configuration Pattern** (using ezconf):
```go
type Config struct {
Field string // ENV FIELD_NAME: Description (required/default: value)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
Field: env.String("FIELD_NAME", "default"),
}
// Validation here
return cfg, nil
}
```
### Formatting & Types
**Formatting**:
- Use `gofmt` (standard Go formatting)
- No tabs vs spaces debate - Go uses tabs
**Types**:
- Prefer explicit types over inference when it improves clarity
- Use struct tags for ORM and JSON marshaling:
```go
type User struct {
bun.BaseModel `bun:"table:users,alias:u"`
ID int `bun:"id,pk,autoincrement"`
Username string `bun:"username,unique"`
AccessToken string `json:"access_token"`
}
```
**Comments**:
- Document exported functions and types
- Use inline comments for ENV var documentation in Config structs
- Explain security-critical code flows
### Testing
**Test File Location**: Place `*_test.go` files alongside source files
**Test Naming**:
```go
func TestFunctionName_Scenario(t *testing.T)
func TestGenerateState_Success(t *testing.T)
func TestVerifyState_WrongUserAgentKey(t *testing.T)
```
**Test Structure**:
- Use subtests with `t.Run()` for related scenarios
- Use table-driven tests for multiple similar cases
- Create helper functions for common setup (e.g., `testConfig()`)
- Test happy paths, error cases, edge cases, and security properties
**Test Categories** (from pkg/oauth/state_test.go example):
1. Happy path tests
2. Error handling (nil params, empty fields, malformed input)
3. Security tests (MITM, CSRF, replay attacks, tampering)
4. Edge cases (concurrency, constant-time comparison)
5. Integration tests (round-trip verification)
### Security
**Critical Practices**:
- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons
- Implement CSRF protection via state tokens
- Store sensitive cookies as HttpOnly
- Use separate logging levels for security violations (WARN)
- Validate all inputs at function boundaries
- Use parameterized queries (Bun ORM handles this)
- Never commit secrets (.env, keys/ are gitignored)
## Project Structure
```
oslstats/
├── cmd/oslstats/ # Application entry point
│ ├── main.go # Entry point with flag parsing
│ ├── run.go # Server initialization & graceful shutdown
│ ├── httpserver.go # HTTP server setup
│ ├── routes.go # Route registration
│ ├── middleware.go # Middleware registration
│ ├── auth.go # Authentication setup
│ └── db.go # Database connection & migrations
├── internal/ # Private application code
│ ├── config/ # Configuration aggregation
│ ├── db/ # Database models & queries (Bun ORM)
│ ├── discord/ # Discord OAuth integration
│ ├── handlers/ # HTTP request handlers
│ ├── session/ # Session store (in-memory)
│ └── view/ # Templ templates
│ ├── component/ # Reusable UI components
│ ├── layout/ # Page layouts
│ └── page/ # Full pages
├── pkg/ # Reusable packages
│ ├── contexts/ # Context key definitions
│ ├── embedfs/ # Embedded static files
│ └── oauth/ # OAuth state management
├── bin/ # Compiled binaries (gitignored)
├── keys/ # Private keys (gitignored)
├── tmp/ # Air hot reload temp files (gitignored)
├── Makefile # Build automation
├── .air.toml # Hot reload configuration
└── go.mod # Go module definition
```
## Key Dependencies
- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt)
- **github.com/a-h/templ** - Type-safe HTML templating
- **github.com/uptrace/bun** - PostgreSQL ORM
- **github.com/bwmarrin/discordgo** - Discord API client
- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf)
- **github.com/joho/godotenv** - .env file loading
## Notes for AI Agents
1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css)
2. **Database operations** should use `bun.Tx` for transaction safety
3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes
4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/
5. **Error messages** should be descriptive and use errors.Wrap for context
6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples)
7. **Air proxy** runs on port 3000 during development; app runs on 3333
8. **Test coverage** is currently limited - prioritize testing security-critical code
9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples
10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern
11. When in plan mode, always use the interactive question tool if available

View File

@@ -4,7 +4,6 @@
BINARY_NAME=oslstats
build:
./scripts/generate-css-sources.sh && \
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
go mod tidy && \
templ generate && \
@@ -16,7 +15,6 @@ run:
./bin/${BINARY_NAME}${SUFFIX}
dev:
./scripts/generate-css-sources.sh && \
templ generate --watch &\
air &\
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
@@ -36,3 +34,6 @@ showenv:
make build
./bin/${BINARY_NAME} --showenv
migrate:
make build
./bin/${BINARY_NAME}${SUFFIX} --migrate

View File

@@ -8,7 +8,6 @@ import (
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/pkg/contexts"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
@@ -31,6 +30,7 @@ func setupAuth(
beginTx,
logger,
handlers.ErrorPage,
conn.DB,
)
if err != nil {
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
@@ -38,7 +38,9 @@ func setupAuth(
auth.IgnorePaths(ignoredPaths...)
contexts.CurrentUser = auth.CurrentModel
db.CurrentUser = auth.CurrentModel
return auth, nil
}
// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh

View File

@@ -20,7 +20,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
conn = bun.NewDB(sqldb, pgdialect.New())
close = sqldb.Close
err = loadModels(ctx, conn, cfg.Flags.ResetDB)
err = loadModels(ctx, conn, cfg.Flags.MigrateDB)
if err != nil {
return nil, nil, errors.Wrap(err, "loadModels")
}
@@ -31,6 +31,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error {
models := []any{
(*db.User)(nil),
(*db.DiscordToken)(nil),
}
for _, model := range models {

View File

@@ -4,13 +4,15 @@ import (
"io/fs"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/store"
)
func setupHttpServer(
@@ -18,6 +20,8 @@ func setupHttpServer(
config *config.Config,
logger *hlog.Logger,
bun *bun.DB,
store *store.Store,
discordAPI *discord.APIClient,
) (server *hws.Server, err error) {
if staticFS == nil {
return nil, errors.New("No filesystem provided")
@@ -53,7 +57,7 @@ func setupHttpServer(
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
}
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI)
if err != nil {
return nil, errors.Wrap(err, "addRoutes")
}

View File

@@ -29,6 +29,16 @@ func main() {
return
}
if flags.MigrateDB {
_, closedb, err := setupBun(ctx, cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
closedb()
return
}
if err := run(ctx, os.Stdout, cfg); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)

View File

@@ -5,22 +5,24 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/store"
)
func addRoutes(
server *hws.Server,
staticFS *http.FileSystem,
config *config.Config,
logger *hlog.Logger,
cfg *config.Config,
conn *bun.DB,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
store *store.Store,
discordAPI *discord.APIClient,
) error {
// Create the routes
routes := []hws.Route{
@@ -34,10 +36,38 @@ func addRoutes(
Method: hws.MethodGET,
Handler: handlers.Index(server),
},
{
Path: "/login",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)),
},
{
Path: "/auth/callback",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)),
},
{
Path: "/register",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)),
},
{
Path: "/logout",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)),
},
}
htmxRoutes := []hws.Route{
{
Path: "/htmx/isusernameunique",
Method: hws.MethodPOST,
Handler: handlers.IsUsernameUnique(server, conn, cfg, store),
},
}
// Register the routes with the server
err := server.AddRoutes(routes...)
err := server.AddRoutes(append(routes, htmxRoutes...)...)
if err != nil {
return errors.Wrap(err, "server.AddRoutes")
}

View File

@@ -9,18 +9,21 @@ import (
"time"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/pkg/embedfs"
"github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/embedfs"
)
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, config *config.Config) error {
func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
// Setup the logger
logger, err := hlog.NewLogger(config.HLOG, w)
logger, err := hlog.NewLogger(cfg.HLOG, w)
if err != nil {
return errors.Wrap(err, "hlog.NewLogger")
}
@@ -28,7 +31,7 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error {
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
bun, closedb, err := setupBun(ctx, config)
bun, closedb, err := setupBun(ctx, cfg)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
@@ -41,8 +44,19 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error {
return errors.Wrap(err, "getStaticFiles")
}
// Setup session store
logger.Debug().Msg("Setting up session store")
store := store.NewStore()
// Setup Discord API client
logger.Debug().Msg("Setting up Discord API client")
discordAPI, err := discord.NewAPIClient(cfg.Discord, logger, cfg.HWSAuth.TrustedHost)
if err != nil {
return errors.Wrap(err, "discord.NewAPIClient")
}
logger.Debug().Msg("Setting up HTTP server")
httpServer, err := setupHttpServer(&staticFS, config, logger, bun)
httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI)
if err != nil {
return errors.Wrap(err, "setupHttpServer")
}

15
go.mod
View File

@@ -6,20 +6,25 @@ require (
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/ezconf v0.1.1
git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.2.3
git.haelnorr.com/h/golib/hwsauth v0.3.4
git.haelnorr.com/h/golib/hws v0.3.1
git.haelnorr.com/h/golib/hwsauth v0.5.2
github.com/a-h/templ v0.3.977
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/uptrace/bun v1.2.16
github.com/uptrace/bun/dialect/pgdialect v1.2.16
github.com/uptrace/bun/driver/pgdriver v1.2.16
golang.org/x/crypto v0.45.0
)
require (
git.haelnorr.com/h/golib/cookies v0.9.0 // indirect
git.haelnorr.com/h/golib/jwt v0.10.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
golang.org/x/crypto v0.45.0 // indirect
)
require (
git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
github.com/bwmarrin/discordgo v0.29.0
github.com/go-logr/logr v1.4.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect

22
go.sum
View File

@@ -6,16 +6,18 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI
git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.2.3 h1:gZQkBciXKh3jYw05vZncSR2lvIqi0H2MVfIWySySsmw=
git.haelnorr.com/h/golib/hws v0.2.3/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
git.haelnorr.com/h/golib/hwsauth v0.3.4 h1:wwYBb6cQQ+x9hxmYuZBF4mVmCv/n4PjJV//e1+SgPOo=
git.haelnorr.com/h/golib/hwsauth v0.3.4/go.mod h1:LI7Qz68GPNIW8732Zwptb//ybjiFJOoXf4tgUuUEqHI=
git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E=
git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI=
git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag=
git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -28,6 +30,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
@@ -66,13 +70,19 @@ go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=

View File

@@ -6,6 +6,8 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/pkg/oauth"
"github.com/joho/godotenv"
"github.com/pkg/errors"
)
@@ -15,6 +17,8 @@ type Config struct {
HWS *hws.Config
HWSAuth *hwsauth.Config
HLOG *hlog.Config
Discord *discord.Config
OAuth *oauth.Config
Flags *Flags
}
@@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
hws.NewEZConfIntegration(),
hwsauth.NewEZConfIntegration(),
db.NewEZConfIntegration(),
discord.NewEZConfIntegration(),
oauth.NewEZConfIntegration(),
)
if err := loader.ParseEnvVars(); err != nil {
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
@@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
return nil, nil, errors.New("DB Config not loaded")
}
discordcfg, ok := loader.GetConfig("discord")
if !ok {
return nil, nil, errors.New("Dicord Config not loaded")
}
oauthcfg, ok := loader.GetConfig("oauth")
if !ok {
return nil, nil, errors.New("OAuth Config not loaded")
}
config := &Config{
DB: dbcfg.(*db.Config),
HWS: hwscfg.(*hws.Config),
HWSAuth: hwsauthcfg.(*hwsauth.Config),
HLOG: hlogcfg.(*hlog.Config),
Discord: discordcfg.(*discord.Config),
OAuth: oauthcfg.(*oauth.Config),
Flags: flags,
}

View File

@@ -5,16 +5,16 @@ import (
)
type Flags struct {
ResetDB bool
EnvDoc bool
ShowEnv bool
GenEnv string
EnvFile string
MigrateDB bool
EnvDoc bool
ShowEnv bool
GenEnv string
EnvFile string
}
func SetupFlags() *Flags {
// Parse commandline args
resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models")
migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models")
envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation")
showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation")
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
@@ -22,11 +22,11 @@ func SetupFlags() *Flags {
flag.Parse()
flags := &Flags{
ResetDB: *resetDB,
EnvDoc: *envDoc,
ShowEnv: *showEnv,
GenEnv: *genEnv,
EnvFile: *envfile,
MigrateDB: *migrateDB,
EnvDoc: *envDoc,
ShowEnv: *showEnv,
GenEnv: *genEnv,
EnvFile: *envfile,
}
return flags
}

View File

@@ -0,0 +1,95 @@
package db
import (
"context"
"time"
"git.haelnorr.com/h/oslstats/internal/discord"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type DiscordToken struct {
bun.BaseModel `bun:"table:discord_tokens,alias:dt"`
DiscordID string `bun:"discord_id,pk,notnull"`
AccessToken string `bun:"access_token,notnull"`
RefreshToken string `bun:"refresh_token,notnull"`
ExpiresAt int64 `bun:"expires_at,notnull"`
Scope string `bun:"scope,notnull"`
TokenType string `bun:"token_type,notnull"`
}
// UpdateDiscordToken adds the provided discord token to the database.
// If the user already has a token stored, it will replace that token instead.
func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
if token == nil {
return errors.New("token cannot be nil")
}
expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
discordToken := &DiscordToken{
DiscordID: user.DiscordID,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
ExpiresAt: expiresAt,
Scope: token.Scope,
TokenType: token.TokenType,
}
_, err := tx.NewInsert().
Model(discordToken).
On("CONFLICT (discord_id) DO UPDATE").
Set("access_token = EXCLUDED.access_token").
Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
}
return nil
}
// DeleteDiscordTokens deletes a users discord OAuth tokens from the database.
// It returns the DiscordToken so that it can be revoked via the discord API
func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := user.GetDiscordToken(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "user.GetDiscordToken")
}
_, err = tx.NewDelete().
Model((*DiscordToken)(nil)).
Where("discord_id = ?", user.DiscordID).
Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewDelete")
}
return token, nil
}
// GetDiscordToken retrieves the users discord token from the database
func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token := new(DiscordToken)
err := tx.NewSelect().
Model(token).
Where("discord_id = ?", user.DiscordID).
Limit(1).
Scan(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return token, nil
}
// Convert reverts the token back into a *discord.Token
func (t *DiscordToken) Convert() *discord.Token {
token := &discord.Token{
AccessToken: t.AccessToken,
RefreshToken: t.RefreshToken,
ExpiresIn: int(t.ExpiresAt - time.Now().Unix()),
Scope: t.Scope,
TokenType: t.TokenType,
}
return token
}

View File

@@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string {
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "db", configFunc: ConfigFromEnv}
return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv}
}

View File

@@ -2,65 +2,30 @@ package db
import (
"context"
"fmt"
"time"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
var CurrentUser hwsauth.ContextLoader[*User]
type User struct {
bun.BaseModel `bun:"table:users,alias:u"`
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
Username string `bun:"username,unique"` // Username (unique)
PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON)
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
Bio string `bun:"bio"` // Short byline set by the user
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
Username string `bun:"username,unique"` // Username (unique)
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
DiscordID string `bun:"discord_id,unique"`
}
func (user *User) GetID() int {
return user.ID
}
// Uses bcrypt to set the users password_hash from the given password
func (user *User) SetPassword(ctx context.Context, tx bun.Tx, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
}
newPassword := string(hashedPassword)
_, err = tx.NewUpdate().
Model(user).
Set("password_hash = ?", newPassword).
Where("id = ?", user.ID).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.Update")
}
return nil
}
// Uses bcrypt to check if the given password matches the users password_hash
func (user *User) CheckPassword(ctx context.Context, tx bun.Tx, password string) error {
var hashedPassword string
err := tx.NewSelect().
Table("users").
Column("password_hash").
Where("id = ?", user.ID).
Limit(1).
Scan(ctx, &hashedPassword)
if err != nil {
return errors.Wrap(err, "tx.Select")
}
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
if err != nil {
return errors.Wrap(err, "Username or password incorrect")
}
return nil
}
// Change the user's username
func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error {
_, err := tx.NewUpdate().
@@ -75,35 +40,18 @@ func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername str
return nil
}
// Change the user's bio
func (user *User) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error {
_, err := tx.NewUpdate().
Model(user).
Set("bio = ?", newBio).
Where("id = ?", user.ID).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.Update")
}
user.Bio = newBio
return nil
}
// CreateUser creates a new user with the given username and password
func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*User, error) {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword")
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) {
if discorduser == nil {
return nil, errors.New("user cannot be nil")
}
user := &User{
Username: username,
PasswordHash: string(hashedPassword),
CreatedAt: 0, // You may want to set this to time.Now().Unix()
Bio: "",
Username: username,
CreatedAt: time.Now().Unix(),
DiscordID: discorduser.ID,
}
_, err = tx.NewInsert().
_, err := tx.NewInsert().
Model(user).
Exec(ctx)
if err != nil {
@@ -116,6 +64,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*Use
// GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
fmt.Printf("user id requested: %v", id)
user := new(User)
err := tx.NewSelect().
Model(user).
@@ -149,6 +98,24 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
return user, nil
}
// GetUserByDiscordID queries the database for a user matching the given discord id
// Returns nil, nil if no user is found
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
user := new(User)
err := tx.NewSelect().
Model(user).
Where("discord_id = ?", discordID).
Limit(1).
Scan(ctx)
if err != nil {
if err.Error() == "sql: no rows in result set" {
return nil, nil
}
return nil, errors.Wrap(err, "tx.Select")
}
return user, nil
}
// IsUsernameUnique checks if the given username is unique (not already taken)
// Returns true if the username is available, false if it's taken
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {

61
internal/discord/api.go Normal file
View File

@@ -0,0 +1,61 @@
package discord
import (
"net/http"
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
)
type OAuthSession struct {
*discordgo.Session
}
func NewOAuthSession(token *Token) (*OAuthSession, error) {
session, err := discordgo.New("Bearer " + token.AccessToken)
if err != nil {
return nil, errors.Wrap(err, "discordgo.New")
}
return &OAuthSession{Session: session}, nil
}
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
user, err := s.User("@me")
if err != nil {
return nil, errors.Wrap(err, "s.User")
}
return user, nil
}
// APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct {
cfg *Config
client *http.Client
logger *hlog.Logger
mu sync.RWMutex
buckets map[string]*RateLimitState
trustedHost string
}
// NewAPIClient creates a new Discord API client with rate limit handling
func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) {
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if logger == nil {
return nil, errors.New("logger cannot be nil")
}
if trustedhost == "" {
return nil, errors.New("trustedhost cannot be empty")
}
return &APIClient{
client: &http.Client{Timeout: 30 * time.Second},
logger: logger,
buckets: make(map[string]*RateLimitState),
cfg: cfg,
trustedHost: trustedhost,
}, nil
}

View File

@@ -0,0 +1,50 @@
package discord
import (
"strings"
"git.haelnorr.com/h/golib/env"
"github.com/pkg/errors"
)
type Config struct {
ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required)
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
OAuthScopes string // Authorisation scopes for OAuth
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
ClientID: env.String("DISCORD_CLIENT_ID", ""),
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
OAuthScopes: getOAuthScopes(),
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
}
// Check required fields
if cfg.ClientID == "" {
return nil, errors.New("Envar not set: DISCORD_CLIENT_ID")
}
if cfg.ClientSecret == "" {
return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET")
}
if cfg.RedirectPath == "" {
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
}
return cfg, nil
}
func getOAuthScopes() string {
list := []string{
"connections",
"email",
"guilds",
"gdm.join",
"guilds.members.read",
"identify",
}
scopes := strings.Join(list, "+")
return scopes
}

View File

@@ -0,0 +1,41 @@
package discord
import (
"runtime"
"strings"
)
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct {
configFunc func() (any, error)
name string
}
// PackagePath returns the path to the config package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return e.configFunc()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return strings.ToLower(e.name)
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return e.name
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv}
}

148
internal/discord/oauth.go Normal file
View File

@@ -0,0 +1,148 @@
package discord
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/pkg/errors"
)
// Token represents a response from the Discord OAuth API after a successful authorization request
type Token struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
}
const oauthurl string = "https://discord.com/oauth2/authorize"
const apiurl string = "https://discord.com/api/v10"
// GetOAuthLink generates a new Discord OAuth2 link for user authentication
func (api *APIClient) GetOAuthLink(state string) (string, error) {
if state == "" {
return "", errors.New("state cannot be empty")
}
values := url.Values{}
values.Add("response_type", "code")
values.Add("client_id", api.cfg.ClientID)
values.Add("scope", api.cfg.OAuthScopes)
values.Add("state", state)
values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
values.Add("prompt", "none")
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
}
// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for
// making requests to the API on behalf of the user
func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) {
if code == "" {
return nil, errors.New("code cannot be empty")
}
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token",
strings.NewReader(data.Encode()),
)
if err != nil {
return nil, errors.Wrap(err, "failed to create request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
if resp.StatusCode != http.StatusOK {
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
}
var tokenResp Token
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "failed to parse token response")
}
return &tokenResp, nil
}
// RefreshToken uses the refresh token to generate a new token pair
func (api *APIClient) RefreshToken(token *Token) (*Token, error) {
if token == nil {
return nil, errors.New("token cannot be nil")
}
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", token.RefreshToken)
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token",
strings.NewReader(data.Encode()),
)
if err != nil {
return nil, errors.Wrap(err, "failed to create request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
if resp.StatusCode != http.StatusOK {
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
}
var tokenResp Token
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "failed to parse token response")
}
return &tokenResp, nil
}
// RevokeToken sends a request to the Discord API to revoke the token pair
func (api *APIClient) RevokeToken(token *Token) error {
if token == nil {
return errors.New("token cannot be nil")
}
data := url.Values{}
data.Set("token", token.AccessToken)
data.Set("token_type_hint", "access_token")
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token/revoke",
strings.NewReader(data.Encode()),
)
if err != nil {
return errors.Wrap(err, "failed to create request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("discord API returned status %d", resp.StatusCode)
}
return nil
}

View File

@@ -0,0 +1,216 @@
package discord
import (
"net"
"net/http"
"strconv"
"time"
"github.com/pkg/errors"
)
// RateLimitState tracks rate limit information for a specific bucket
type RateLimitState struct {
Remaining int // Requests remaining in current window
Limit int // Total requests allowed in window
Reset time.Time // When the rate limit resets
Bucket string // Discord's bucket identifier
}
// Do executes an HTTP request with automatic rate limit handling
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
if req == nil {
return nil, errors.New("request cannot be nil")
}
// Step 1: Check if we need to wait before making request
bucket := c.getBucketFromRequest(req)
if err := c.waitIfNeeded(bucket); err != nil {
return nil, err
}
// Step 2: Execute request
resp, err := c.client.Do(req)
if err != nil {
// Check if it's a network timeout
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, errors.Wrap(err, "request timed out")
}
return nil, errors.Wrap(err, "http request failed")
}
// Step 3: Update rate limit state from response headers
c.updateRateLimit(resp.Header)
// Step 4: Handle 429 (rate limited)
if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close() // Close original response
retryAfter := c.parseRetryAfter(resp.Header)
// No Retry-After header, can't retry safely
if retryAfter == 0 {
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Msg("Rate limited but no Retry-After header provided")
return nil, errors.New("discord API rate limited but no Retry-After header provided")
}
// Retry-After exceeds 30 second cap
if retryAfter > 30*time.Second {
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Dur("retry_after", retryAfter).
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
return nil, errors.Errorf(
"discord API rate limited (retry after %s exceeds 30s cap)",
retryAfter,
)
}
// Wait and retry
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Dur("retry_after", retryAfter).
Msg("Rate limited, waiting before retry")
time.Sleep(retryAfter)
// Retry the request
resp, err = c.client.Do(req)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, errors.Wrap(err, "retry request timed out")
}
return nil, errors.Wrap(err, "retry request failed")
}
// Update rate limit again after retry
c.updateRateLimit(resp.Header)
// If STILL rate limited after retry, return error
if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close()
c.logger.Error().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Msg("Still rate limited after retry, Discord may be experiencing issues")
return nil, errors.Errorf(
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
retryAfter,
)
}
}
return resp, nil
}
// getBucketFromRequest extracts or generates bucket ID from request
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
func (c *APIClient) getBucketFromRequest(req *http.Request) string {
return req.Method + ":" + req.URL.Path
}
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
func (c *APIClient) waitIfNeeded(bucket string) error {
c.mu.RLock()
state, exists := c.buckets[bucket]
c.mu.RUnlock()
if !exists {
return nil // No state yet, proceed
}
now := time.Now()
// If we have no remaining requests and reset hasn't occurred, wait
if state.Remaining == 0 && now.Before(state.Reset) {
waitDuration := time.Until(state.Reset)
// Add small buffer (100ms) to ensure reset has occurred
waitDuration += 100 * time.Millisecond
if waitDuration > 0 {
c.logger.Debug().
Str("bucket", bucket).
Dur("wait_duration", waitDuration).
Msg("Proactively waiting for rate limit reset")
time.Sleep(waitDuration)
}
}
return nil
}
// updateRateLimit parses response headers and updates bucket state
func (c *APIClient) updateRateLimit(headers http.Header) {
bucket := headers.Get("X-RateLimit-Bucket")
if bucket == "" {
return // No bucket info, can't track
}
// Parse headers
limit := c.parseInt(headers.Get("X-RateLimit-Limit"))
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining"))
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After"))
state := &RateLimitState{
Bucket: bucket,
Limit: limit,
Remaining: remaining,
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
}
c.mu.Lock()
c.buckets[bucket] = state
c.mu.Unlock()
// Log rate limit state for debugging
c.logger.Debug().
Str("bucket", bucket).
Int("remaining", remaining).
Int("limit", limit).
Dur("reset_in", time.Until(state.Reset)).
Msg("Rate limit state updated")
}
// parseRetryAfter extracts retry delay from Retry-After header
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
retryAfter := headers.Get("Retry-After")
if retryAfter == "" {
return 0
}
// Discord returns seconds as float
seconds := c.parseFloat(retryAfter)
if seconds <= 0 {
return 0
}
return time.Duration(seconds * float64(time.Second))
}
// parseInt parses an integer from a header value, returns 0 on error
func (c *APIClient) parseInt(s string) int {
if s == "" {
return 0
}
i, _ := strconv.Atoi(s)
return i
}
// parseFloat parses a float from a header value, returns 0 on error
func (c *APIClient) parseFloat(s string) float64 {
if s == "" {
return 0
}
f, _ := strconv.ParseFloat(s, 64)
return f
}

View File

@@ -0,0 +1,517 @@
package discord
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"git.haelnorr.com/h/golib/hlog"
)
// testLogger creates a test logger for testing
func testLogger(t *testing.T) *hlog.Logger {
level, _ := hlog.LogLevel("debug")
cfg := &hlog.Config{
LogLevel: level,
LogOutput: "console",
}
logger, err := hlog.NewLogger(cfg, io.Discard)
if err != nil {
t.Fatalf("failed to create test logger: %v", err)
}
return logger
}
// testConfig creates a test config for testing
func testConfig() *Config {
return &Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
OAuthScopes: "identify+email",
RedirectPath: "/oauth/callback",
}
}
func TestNewRateLimitedClient(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
if client == nil {
t.Fatal("NewAPIClient returned nil")
}
if client.client == nil {
t.Error("client.client is nil")
}
if client.logger == nil {
t.Error("client.logger is nil")
}
if client.buckets == nil {
t.Error("client.buckets map is nil")
}
if client.cfg == nil {
t.Error("client.cfg is nil")
}
if client.trustedHost != "trusted-host.example.com" {
t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost)
}
}
func TestAPIClient_Do_Success(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Mock server that returns success with rate limit headers
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("X-RateLimit-Limit", "5")
w.Header().Set("X-RateLimit-Remaining", "3")
w.Header().Set("X-RateLimit-Reset-After", "2.5")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Do() returned error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
// Check that rate limit state was updated
client.mu.RLock()
state, exists := client.buckets["test-bucket"]
client.mu.RUnlock()
if !exists {
t.Fatal("rate limit state not stored")
}
if state.Remaining != 3 {
t.Errorf("expected remaining=3, got %d", state.Remaining)
}
if state.Limit != 5 {
t.Errorf("expected limit=5, got %d", state.Limit)
}
}
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
if attemptCount == 1 {
// First request: return 429
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("Retry-After", "0.1") // 100ms
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
"error_description": "You are being rate limited",
})
return
}
// Second request: success
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("X-RateLimit-Limit", "5")
w.Header().Set("X-RateLimit-Remaining", "4")
w.Header().Set("X-RateLimit-Reset-After", "2.5")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
start := time.Now()
resp, err := client.Do(req)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Do() returned error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
}
if attemptCount != 2 {
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
}
// Should have waited approximately 100ms
if elapsed < 100*time.Millisecond {
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
}
}
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
// Always return 429
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("Retry-After", "0.05") // 50ms
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error after failed retry")
}
if !strings.Contains(err.Error(), "still rate limited after retry") {
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
}
if attemptCount != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount)
}
}
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
start := time.Now()
resp, err := client.Do(req)
elapsed := time.Since(start)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error for Retry-After > 30s")
}
if !strings.Contains(err.Error(), "exceeds 30s cap") {
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
}
// Should NOT have waited (immediate error)
if elapsed > 1*time.Second {
t.Errorf("should return immediately, but took %v", elapsed)
}
}
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return 429 but NO Retry-After header
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error when no Retry-After header")
}
if !strings.Contains(err.Error(), "no Retry-After header") {
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
}
}
func TestAPIClient_UpdateRateLimit(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
headers := http.Header{}
headers.Set("X-RateLimit-Bucket", "global")
headers.Set("X-RateLimit-Limit", "10")
headers.Set("X-RateLimit-Remaining", "7")
headers.Set("X-RateLimit-Reset-After", "5.5")
client.updateRateLimit(headers)
client.mu.RLock()
state, exists := client.buckets["global"]
client.mu.RUnlock()
if !exists {
t.Fatal("bucket state not created")
}
if state.Bucket != "global" {
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
}
if state.Limit != 10 {
t.Errorf("expected limit=10, got %d", state.Limit)
}
if state.Remaining != 7 {
t.Errorf("expected remaining=7, got %d", state.Remaining)
}
// Check reset time is approximately 5.5 seconds from now
resetIn := time.Until(state.Reset)
if resetIn < 5*time.Second || resetIn > 6*time.Second {
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
}
}
func TestAPIClient_WaitIfNeeded(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Set up a bucket with 0 remaining and reset in future
bucket := "test-bucket"
client.mu.Lock()
client.buckets[bucket] = &RateLimitState{
Bucket: bucket,
Limit: 5,
Remaining: 0,
Reset: time.Now().Add(200 * time.Millisecond),
}
client.mu.Unlock()
start := time.Now()
err = client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
t.Errorf("waitIfNeeded returned error: %v", err)
}
// Should have waited ~200ms + 100ms buffer
if elapsed < 200*time.Millisecond {
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
}
if elapsed > 500*time.Millisecond {
t.Errorf("waited too long: %v", elapsed)
}
}
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Set up a bucket with remaining requests
bucket := "test-bucket"
client.mu.Lock()
client.buckets[bucket] = &RateLimitState{
Bucket: bucket,
Limit: 5,
Remaining: 3,
Reset: time.Now().Add(5 * time.Second),
}
client.mu.Unlock()
start := time.Now()
err = client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
t.Errorf("waitIfNeeded returned error: %v", err)
}
// Should NOT wait (has remaining requests)
if elapsed > 10*time.Millisecond {
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
}
}
func TestAPIClient_Do_Concurrent(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
requestCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
requestCount++
count := requestCount
mu.Unlock()
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
w.Header().Set("X-RateLimit-Limit", "10")
w.Header().Set("X-RateLimit-Remaining", "5")
w.Header().Set("X-RateLimit-Reset-After", "1.0")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
}))
defer server.Close()
// Launch 10 concurrent requests
var wg sync.WaitGroup
errors := make(chan error, 10)
for range 10 {
wg.Go(
func() {
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
errors <- err
return
}
resp, err := client.Do(req)
if err != nil {
errors <- err
return
}
resp.Body.Close()
})
}
wg.Wait()
close(errors)
// Check for any errors
for err := range errors {
t.Errorf("concurrent request failed: %v", err)
}
// All requests should have completed
mu.Lock()
finalCount := requestCount
mu.Unlock()
if finalCount != 10 {
t.Errorf("expected 10 requests, got %d", finalCount)
}
// Check rate limit state is consistent (no data races)
client.mu.RLock()
state, exists := client.buckets["concurrent-bucket"]
client.mu.RUnlock()
if !exists {
t.Fatal("bucket state not found after concurrent requests")
}
// State should exist and be valid
if state.Limit != 10 {
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
}
}
func TestAPIClient_ParseRetryAfter(t *testing.T) {
logger := testLogger(t)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
tests := []struct {
name string
header string
expected time.Duration
}{
{"integer seconds", "2", 2 * time.Second},
{"float seconds", "2.5", 2500 * time.Millisecond},
{"zero", "0", 0},
{"empty", "", 0},
{"invalid", "abc", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headers := http.Header{}
headers.Set("Retry-After", tt.header)
result := client.parseRetryAfter(headers)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@@ -0,0 +1,205 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
func Callback(
server *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB,
cfg *config.Config,
store *store.Store,
discordAPI *discord.APIClient,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
if exceeded {
err := errors.Errorf(
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
store.ClearRedirectTrack(r, "/callback")
throwError(
server,
w,
r,
http.StatusBadRequest,
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
err,
"warn",
)
return
}
state := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")
if state == "" && code == "" {
http.Redirect(w, r, "/", http.StatusBadRequest)
return
}
data, err := verifyState(cfg.OAuth, w, r, state)
if err != nil {
if vsErr, ok := err.(*verifyStateError); ok {
if vsErr.IsCookieError() {
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
} else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
}
} else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
}
return
}
store.ClearRedirectTrack(r, "/callback")
switch data {
case "login":
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
return
}
defer tx.Rollback()
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil {
throwInternalServiceError(server, w, r, "OAuth login failed", err)
return
}
tx.Commit()
redirect()
return
}
},
)
}
type verifyStateError struct {
err error
cookieError bool
}
func (e *verifyStateError) Error() string {
return e.err.Error()
}
func (e *verifyStateError) IsCookieError() bool {
return e.cookieError
}
func verifyState(
cfg *oauth.Config,
w http.ResponseWriter,
r *http.Request,
state string,
) (string, error) {
if r == nil {
return "", errors.New("request cannot be nil")
}
if state == "" {
return "", errors.New("state param field is empty")
}
uak, err := oauth.GetStateCookie(r)
if err != nil {
return "", &verifyStateError{
err: errors.Wrap(err, "oauth.GetStateCookie"),
cookieError: true,
}
}
data, err := oauth.VerifyState(cfg, state, uak)
if err != nil {
return "", &verifyStateError{
err: errors.Wrap(err, "oauth.VerifyState"),
cookieError: false,
}
}
oauth.DeleteStateCookie(w)
return data, nil
}
func login(
ctx context.Context,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
tx bun.Tx,
cfg *config.Config,
w http.ResponseWriter,
r *http.Request,
code string,
store *store.Store,
discordAPI *discord.APIClient,
) (func(), error) {
token, err := discordAPI.AuthorizeWithCode(code)
if err != nil {
return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode")
}
session, err := discord.NewOAuthSession(token)
if err != nil {
return nil, errors.Wrap(err, "discord.NewOAuthSession")
}
discorduser, err := session.GetUser()
if err != nil {
return nil, errors.Wrap(err, "session.GetUser")
}
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserByDiscordID")
}
var redirect string
if user == nil {
sessionID, err := store.CreateRegistrationSession(discorduser, token)
if err != nil {
return nil, errors.Wrap(err, "store.CreateRegistrationSession")
}
http.SetCookie(w, &http.Cookie{
Name: "registration_session",
Path: "/",
Value: sessionID,
MaxAge: 300, // 5 minutes
HttpOnly: true,
Secure: cfg.HWSAuth.SSL,
SameSite: http.SameSiteLaxMode,
})
redirect = "/register"
} else {
err = user.UpdateDiscordToken(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "user.UpdateDiscordToken")
}
err := auth.Login(w, r, user, true)
if err != nil {
return nil, errors.Wrap(err, "auth.Login")
}
redirect = cookies.CheckPageFrom(w, r)
}
return func() {
http.Redirect(w, r, redirect, http.StatusSeeOther)
}, nil
}

View File

@@ -5,24 +5,93 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/view/page"
"github.com/pkg/errors"
)
func ErrorPage(
errorCode int,
) (hws.ErrorPage, error) {
// func ErrorPage(
// error hws.HWSError,
// ) (hws.ErrorPage, error) {
// messages := map[int]string{
// 400: "The request you made was malformed or unexpected.",
// 401: "You need to login to view this page.",
// 403: "You do not have permission to view this page.",
// 404: "The page or resource you have requested does not exist.",
// 500: `An error occured on the server. Please try again, and if this
// continues to happen contact an administrator.`,
// 503: "The server is currently down for maintenance and should be back soon. =)",
// }
// msg, exists := messages[error.StatusCode]
// if !exists {
// return nil, errors.New("No valid message for the given code")
// }
// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil
// }
func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
// Determine if this status code should show technical details
showDetails := shouldShowDetails(hwsError.StatusCode)
// Get the user-friendly message
message := hwsError.Message
if message == "" {
// Fallback to default messages if no custom message provided
message = getDefaultMessage(hwsError.StatusCode)
}
// Get technical details if applicable
var details string
if showDetails && hwsError.Error != nil {
details = hwsError.Error.Error()
}
// Render appropriate template
if details != "" {
return page.ErrorWithDetails(
hwsError.StatusCode,
http.StatusText(hwsError.StatusCode),
message,
details,
), nil
}
return page.Error(
hwsError.StatusCode,
http.StatusText(hwsError.StatusCode),
message,
), nil
}
// shouldShowDetails determines if a status code should display technical details
func shouldShowDetails(statusCode int) bool {
switch statusCode {
case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable
return true
case 401, 403, 404: // Unauthorized, Forbidden, Not Found
return false
default:
// For unknown codes, show details for 5xx errors
return statusCode >= 500
}
}
// getDefaultMessage provides fallback messages for status codes
func getDefaultMessage(statusCode int) string {
messages := map[int]string{
400: "The request you made was malformed or unexpected.",
401: "You need to login to view this page.",
403: "You do not have permission to view this page.",
404: "The page or resource you have requested does not exist.",
500: `An error occured on the server. Please try again, and if this
continues to happen contact an administrator.`,
500: `An error occurred on the server. Please try again, and if this
continues to happen contact an administrator.`,
503: "The server is currently down for maintenance and should be back soon. =)",
}
msg, exists := messages[errorCode]
msg, exists := messages[statusCode]
if !exists {
return nil, errors.New("No valid message for the given code")
if statusCode >= 500 {
return "A server error occurred. Please try again later."
}
return "An error occurred while processing your request."
}
return page.Error(errorCode, http.StatusText(errorCode), msg), nil
return msg
}

109
internal/handlers/errors.go Normal file
View File

@@ -0,0 +1,109 @@
package handlers
import (
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
)
// throwError is a generic helper that all throw* functions use internally
func throwError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
statusCode int,
msg string,
err error,
level string,
) {
err = s.ThrowError(w, r, hws.HWSError{
StatusCode: statusCode,
Message: msg,
Error: err,
Level: hws.ErrorLevel(level),
RenderErrorPage: true, // throw* family always renders error pages
})
if err != nil {
s.ThrowFatal(w, err)
}
}
// throwInternalServiceError handles 500 errors (server failures)
func throwInternalServiceError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
}
// throwBadRequest handles 400 errors (malformed requests)
func throwBadRequest(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusBadRequest, msg, err, "debug")
}
// throwForbidden handles 403 errors (normal permission denials)
func throwForbidden(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, "debug")
}
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
func throwForbiddenSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, "warn")
}
// throwUnauthorized handles 401 errors (not authenticated)
func throwUnauthorized(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug")
}
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
func throwUnauthorizedSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn")
}
// throwNotFound handles 404 errors
func throwNotFound(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
path string,
) {
msg := fmt.Sprintf("The requested resource was not found: %s", path)
err := errors.New("Resource not found")
throwError(s, w, r, http.StatusNotFound, msg, err, "debug")
}

View File

@@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
page, err := ErrorPage(http.StatusNotFound)
if err != nil {
err = server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured trying to generate the error page",
Error: err,
Level: hws.ErrorLevel("error"),
RenderErrorPage: false,
})
if err != nil {
server.ThrowFatal(w, err)
}
return
}
err = page.Render(r.Context(), w)
if err != nil {
err = server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured trying to render the error page",
Error: err,
Level: hws.ErrorLevel("error"),
RenderErrorPage: false,
})
if err != nil {
server.ThrowFatal(w, err)
}
return
}
throwNotFound(server, w, r, r.URL.Path)
}
page.Index().Render(r.Context(), w)
},

View File

@@ -0,0 +1,45 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/store"
"github.com/uptrace/bun"
)
func IsUsernameUnique(
server *hws.Server,
conn *bun.DB,
cfg *config.Config,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database transaction failed", err)
return
}
defer tx.Rollback()
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
throwInternalServiceError(server, w, r, "Database query failed", err)
return
}
tx.Commit()
if !unique {
w.WriteHeader(http.StatusConflict)
} else {
w.WriteHeader(http.StatusOK)
}
},
)
}

View File

@@ -0,0 +1,63 @@
package handlers
import (
"net/http"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
if exceeded {
err := errors.Errorf(
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
st.ClearRedirectTrack(r, "/login")
throwError(
server,
w,
r,
http.StatusBadRequest,
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
err,
"warn",
)
return
}
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
if err != nil {
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
return
}
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
link, err := discordAPI.GetOAuthLink(state)
if err != nil {
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
return
}
st.ClearRedirectTrack(r, "/login")
http.Redirect(w, r, link, http.StatusSeeOther)
},
)
}

View File

@@ -0,0 +1,59 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func Logout(
server *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB,
discordAPI *discord.APIClient,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer tx.Rollback()
user := db.CurrentUser(r.Context())
if user == nil {
// JIC - should be impossible to get here if route is protected by LoginReq
w.Header().Set("HX-Redirect", "/")
return
}
token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
return
}
err = discordAPI.RevokeToken(token.Convert())
if err != nil {
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return
}
err = auth.Logout(tx, w, r)
if err != nil {
throwInternalServiceError(server, w, r, "Logout failed", err)
return
}
tx.Commit()
w.Header().Set("HX-Redirect", "/")
},
)
}

View File

@@ -0,0 +1,129 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/view/page"
)
func Register(
server *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB,
cfg *config.Config,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
if exceeded {
err := errors.Errorf(
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
cfg.HWSAuth.SSL,
)
store.ClearRedirectTrack(r, "/register")
throwError(
server,
w,
r,
http.StatusBadRequest,
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
err,
"warn",
)
return
}
sessionCookie, err := r.Cookie("registration_session")
if err != nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
details, ok := store.GetRegistrationSession(sessionCookie.Value)
if !ok {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
store.ClearRedirectTrack(r, "/register")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database transaction failed", err)
return
}
defer tx.Rollback()
method := r.Method
if method == "GET" {
tx.Commit()
page.Register(details.DiscordUser.Username).Render(r.Context(), w)
return
}
if method == "POST" {
username := r.FormValue("username")
user, err := registerUser(ctx, tx, username, details)
if err != nil {
throwInternalServiceError(server, w, r, "Registration failed", err)
return
}
tx.Commit()
if user == nil {
w.WriteHeader(http.StatusConflict)
} else {
err = auth.Login(w, r, user, true)
if err != nil {
throwInternalServiceError(server, w, r, "Login failed", err)
return
}
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
}
return
}
},
)
}
func registerUser(
ctx context.Context,
tx bun.Tx,
username string,
details *store.RegistrationSession,
) (*db.User, error) {
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "db.IsUsernameUnique")
}
if !unique {
return nil, nil
}
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
if err != nil {
return nil, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
}
return user, nil
}

View File

@@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
if err != nil {
// If we can't create the file server, return a handler that always errors
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err = server.ThrowError(w, r, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured trying to load the file system",
Error: err,
Level: hws.ErrorLevel("error"),
RenderErrorPage: true,
})
if err != nil {
server.ThrowFatal(w, err)
}
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
})
}

View File

@@ -0,0 +1,46 @@
package store
import (
"errors"
"time"
"git.haelnorr.com/h/oslstats/internal/discord"
"github.com/bwmarrin/discordgo"
)
type RegistrationSession struct {
DiscordUser *discordgo.User
Token *discord.Token
ExpiresAt time.Time
}
func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) {
if user == nil {
return "", errors.New("user cannot be nil")
}
if token == nil {
return "", errors.New("token cannot be nil")
}
id := generateID()
s.sessions.Store(id, &RegistrationSession{
DiscordUser: user,
Token: token,
ExpiresAt: time.Now().Add(5 * time.Minute),
})
return id, nil
}
func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) {
val, ok := s.sessions.Load(id)
if !ok {
return nil, false
}
session := val.(*RegistrationSession)
if time.Now().After(session.ExpiresAt) {
s.sessions.Delete(id)
return nil, false
}
return session, true
}

View File

@@ -0,0 +1,95 @@
package store
import (
"net"
"net/http"
"strings"
"time"
)
// getClientIP extracts the client IP address, checking X-Forwarded-For first
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (comma-separated list, first is client)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port")
// Use net.SplitHostPort to properly handle both IPv4 and IPv6
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr)
return r.RemoteAddr
}
return host
}
// TrackRedirect increments the redirect counter for this IP+UA+Path combination
// Returns the current attempt count, whether limit was exceeded, and the track details
func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) {
if r == nil {
return 0, false, nil
}
ip := getClientIP(r)
userAgent := r.UserAgent()
key := redirectKey(ip, userAgent, path)
now := time.Now()
expiresAt := now.Add(5 * time.Minute)
// Try to load existing track
val, exists := s.redirectTracks.Load(key)
if exists {
track = val.(*RedirectTrack)
// Check if expired
if now.After(track.ExpiresAt) {
// Expired, start fresh
track = &RedirectTrack{
IP: ip,
UserAgent: userAgent,
Path: path,
Attempts: 1,
FirstSeen: now,
ExpiresAt: expiresAt,
}
s.redirectTracks.Store(key, track)
return 1, false, track
}
// Increment existing
track.Attempts++
track.ExpiresAt = expiresAt // Extend expiry
exceeded = track.Attempts >= maxAttempts
return track.Attempts, exceeded, track
}
// Create new track
track = &RedirectTrack{
IP: ip,
UserAgent: userAgent,
Path: path,
Attempts: 1,
FirstSeen: now,
ExpiresAt: expiresAt,
}
s.redirectTracks.Store(key, track)
return 1, false, track
}
// ClearRedirectTrack removes a redirect tracking entry (called after successful completion)
func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
if r == nil {
return
}
ip := getClientIP(r)
userAgent := r.UserAgent()
key := redirectKey(ip, userAgent, path)
s.redirectTracks.Delete(key)
}

80
internal/store/store.go Normal file
View File

@@ -0,0 +1,80 @@
package store
import (
"crypto/rand"
"encoding/base64"
"fmt"
"sync"
"time"
)
// RedirectTrack represents a single redirect attempt tracking entry
type RedirectTrack struct {
IP string // Client IP (X-Forwarded-For aware)
UserAgent string // Full User-Agent string for debugging
Path string // Request path (without query params)
Attempts int // Number of redirect attempts
FirstSeen time.Time // When first redirect was tracked
ExpiresAt time.Time // When to clean up this entry
}
type Store struct {
sessions sync.Map // key: string, value: *RegistrationSession
redirectTracks sync.Map // key: string, value: *RedirectTrack
cleanup *time.Ticker
}
func NewStore() *Store {
s := &Store{
cleanup: time.NewTicker(1 * time.Minute),
}
// Background cleanup of expired sessions
go func() {
for range s.cleanup.C {
s.cleanupExpired()
}
}()
return s
}
func (s *Store) Delete(id string) {
s.sessions.Delete(id)
}
func (s *Store) cleanupExpired() {
now := time.Now()
// Clean up expired registration sessions
s.sessions.Range(func(key, value any) bool {
session := value.(*RegistrationSession)
if now.After(session.ExpiresAt) {
s.sessions.Delete(key)
}
return true
})
// Clean up expired redirect tracks
s.redirectTracks.Range(func(key, value any) bool {
track := value.(*RedirectTrack)
if now.After(track.ExpiresAt) {
s.redirectTracks.Delete(key)
}
return true
})
}
func generateID() string {
b := make([]byte, 32)
rand.Read(b)
return base64.RawURLEncoding.EncodeToString(b)
}
// redirectKey generates a unique key for tracking redirects
// Uses IP + first 100 chars of UA + path as key (not hashed for debugging)
func redirectKey(ip, userAgent, path string) string {
ua := userAgent
if len(ua) > 100 {
ua = ua[:100]
}
return fmt.Sprintf("%s:%s:%s", ip, ua, path)
}

View File

@@ -0,0 +1,89 @@
package form
templ RegisterForm(username string) {
<form
hx-post="/register"
hx-swap="none"
x-data={ templ.JSFuncCall("registerFormData").CallInline }
@submit="handleSubmit()"
@htmx:after-request="if(submitTimeout) clearTimeout(submitTimeout); const redirect = $event.detail.xhr.getResponseHeader('HX-Redirect'); if(redirect) return; if(!$event.detail.successful) { isSubmitting=false; buttontext='Register'; if($event.detail.xhr.status === 409) { errorMessage='Username is already taken'; isUnique=false; } else { errorMessage='An error occurred. Please try again.'; } }"
>
<script>
function registerFormData() {
return {
canSubmit: false,
buttontext: "Register",
errorMessage: "",
isChecking: false,
isUnique: false,
isEmpty: true,
isSubmitting: false,
submitTimeout: null,
resetErr() {
this.errorMessage = "";
this.isChecking = false;
this.isUnique = false;
},
enableSubmit() {
this.canSubmit = true;
},
handleSubmit() {
this.isSubmitting = true;
this.buttontext = 'Loading...';
// Set timeout for 10 seconds
this.submitTimeout = setTimeout(() => {
this.isSubmitting = false;
this.buttontext = 'Register';
this.errorMessage = 'Request timed out. Please try again.';
}, 10000);
}
};
}
</script>
<div
class="grid gap-y-4"
>
<div>
<div class="relative">
<input
type="text"
id="username"
name="username"
x-bind:class="{
'py-3 px-4 block w-full rounded-lg text-sm bg-base disabled:opacity-50 disabled:pointer-events-none border-2 outline-none': true,
'border-overlay0 focus:border-blue': !isUnique && !errorMessage,
'border-green focus:border-green': isUnique && !isChecking && !errorMessage,
'border-red focus:border-red': errorMessage && !isChecking && !isSubmitting
}"
required
aria-describedby="username-error"
value={ username }
@input="resetErr(); isEmpty = $el.value.trim() === ''; if(isEmpty) { errorMessage='Username is required'; isUnique=false; }"
hx-post="/htmx/isusernameunique"
hx-trigger="load delay:100ms, input changed delay:500ms"
hx-swap="none"
@htmx:before-request="if($el.value.trim() === '') { isEmpty=true; return; } isEmpty=false; isChecking=true; isUnique=false; errorMessage=''"
@htmx:after-request="isChecking=false; if($event.detail.successful) { isUnique=true; canSubmit=true; } else if($event.detail.xhr.status === 409) { errorMessage='Username is already taken'; isUnique=false; canSubmit=false; }"
/>
<p
class="text-center text-xs text-red mt-2"
id="username-error"
x-show="errorMessage && !isSubmitting"
x-cloak
x-text="errorMessage"
></p>
</div>
</div>
<button
x-bind:disabled="isEmpty || !isUnique || isChecking || isSubmitting"
x-text="buttontext"
type="submit"
class="w-full py-3 px-4 inline-flex justify-center items-center
gap-x-2 rounded-lg border border-transparent transition
bg-green hover:bg-green/75 text-mantle hover:cursor-pointer
disabled:bg-green/60 disabled:cursor-default"
></button>
</div>
</form>
}

View File

@@ -1,6 +1,6 @@
package nav
import "git.haelnorr.com/h/oslstats/pkg/contexts"
import "git.haelnorr.com/h/oslstats/internal/db"
type ProfileItem struct {
name string // Label to display
@@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem {
// Returns the right portion of the navbar
templ navRight() {
{{ user := contexts.CurrentUser(ctx) }}
{{ user := db.CurrentUser(ctx) }}
{{ items := getProfileItems() }}
<div class="flex items-center gap-2">
<div class="sm:flex sm:gap-2">

View File

@@ -1,10 +1,10 @@
package nav
import "git.haelnorr.com/h/oslstats/pkg/contexts"
import "git.haelnorr.com/h/oslstats/internal/db"
// Returns the mobile version of the navbar thats only visible when activated
templ sideNav(navItems []NavItem) {
{{ user := contexts.CurrentUser(ctx) }}
{{ user := db.CurrentUser(ctx) }}
<div
x-show="open"
x-transition

View File

@@ -18,14 +18,11 @@ templ Global(title string) {
window.matchMedia('(prefers-color-scheme: dark)').matches)}"
>
<head>
<!-- <script src="/static/js/theme.js"></script> -->
<script src="/static/js/theme.js"></script>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>{ title }</title>
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/>
<link rel="preconnect" href="https://fonts.googleapis.com"/>
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin="anonymous"/>
<link href="https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap" rel="stylesheet"/>
<link href="/static/css/output.css" rel="stylesheet"/>
<script src="https://unpkg.com/htmx.org@2.0.4" integrity="sha384-HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+" crossorigin="anonymous"></script>
<script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/persist@3.x.x/dist/cdn.min.js"></script>
@@ -38,19 +35,6 @@ templ Global(title string) {
const bodyData = {
showError500: false,
showError503: false,
showConfirmPasswordModal: false,
handleHtmxBeforeOnLoad(event) {
const requestPath = event.detail.pathInfo.requestPath;
if (requestPath === "/reauthenticate") {
// handle password incorrect on refresh attempt
if (event.detail.xhr.status === 445) {
event.detail.shouldSwap = true;
event.detail.isError = false;
} else if (event.detail.xhr.status === 200) {
this.showConfirmPasswordModal = false;
}
}
},
// handle errors from the server on HTMX requests
handleHtmxError(event) {
const errorCode = event.detail.errorInfo.error;
@@ -65,11 +49,6 @@ templ Global(title string) {
this.showError503 = true;
setTimeout(() => (this.showError503 = false), 6000);
}
// user is authorized but needs to refresh their login
if (errorCode.includes("Code 444")) {
this.showConfirmPasswordModal = true;
}
},
};
</script>
@@ -78,7 +57,6 @@ templ Global(title string) {
class="bg-base text-text ubuntu-mono-regular overflow-x-hidden"
x-data="bodyData"
x-on:htmx:error="handleHtmxError($event)"
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
>
@popup.Error500Popup()
@popup.Error503Popup()

View File

@@ -3,32 +3,66 @@ package page
import "git.haelnorr.com/h/oslstats/internal/view/layout"
import "strconv"
// Page template for Error pages. Error code should be a HTTP status code as
// a string, and err should be the corresponding response title.
// Message is a custom error message displayed below the code and error.
// Original Error template (keep for backwards compatibility where needed)
templ Error(code int, err string, message string) {
@ErrorWithDetails(code, err, message, "")
}
// Enhanced Error template with optional details section
templ ErrorWithDetails(code int, err string, message string, details string) {
@layout.Global(err) {
<div
class="grid mt-24 left-0 right-0 top-0 bottom-0
place-content-center bg-base px-4"
>
<div class="text-center">
<h1
class="text-9xl text-text"
>{ strconv.Itoa(code) }</h1>
<p
class="text-2xl font-bold tracking-tight text-subtext1
sm:text-4xl"
>{ err }</p>
<p
class="mt-4 text-subtext0"
>{ message }</p>
<a
href="/"
class="mt-6 inline-block rounded-lg bg-mauve px-5 py-3
text-sm text-crust transition hover:bg-mauve/75"
>Go to homepage</a>
<div class="grid mt-24 left-0 right-0 top-0 bottom-0 place-content-center bg-base px-4">
<div class="text-center max-w-2xl mx-auto">
<h1 class="text-9xl text-text">{ strconv.Itoa(code) }</h1>
<p class="text-2xl font-bold tracking-tight text-subtext1 sm:text-4xl">{ err }</p>
// Always show the message from hws.HWSError.Message
<p class="mt-4 text-subtext0">{ message }</p>
// Conditionally show technical details in dropdown
if details != "" {
<div class="mt-8 text-left">
<details class="bg-surface0 rounded-lg p-4 text-right">
<summary class="text-left cursor-pointer text-subtext1 font-semibold select-none hover:text-text">
Details
<span class="text-xs text-subtext0 ml-2">(click to expand)</span>
</summary>
<div class="text-left mt-4 relative">
<pre id="details" class="text-xs text-subtext0 font-mono whitespace-pre-wrap break-all bg-mantle p-4 rounded overflow-x-auto">{ details }</pre>
</div>
<button
onclick="copyToClipboard('details')"
id="copyButton"
class="mt-2 bg-mauve text-crust px-3 py-1 rounded text-xs hover:bg-mauve/75 transition hover:cursor-pointer"
title="Copy to clipboard"
>
Copy
</button>
</details>
</div>
}
<a href="/" class="mt-6 inline-block rounded-lg bg-mauve px-5 py-3 text-sm text-crust transition hover:bg-mauve/75">
Go to homepage
</a>
</div>
</div>
if details != "" {
<script>
function copyToClipboard(id) {
var details = document.getElementById(id).innerText;
var button = document.getElementById("copyButton");
navigator.clipboard
.writeText(details)
.then(function () {
button.innerText = "Copied!";
setTimeout(function () {
button.innerText = "Copy";
}, 2000);
})
.catch(function (err) {
console.error("Failed to copy:", err);
button.innerText = "Failed";
});
}
</script>
}
}
}

View File

@@ -0,0 +1,29 @@
package page
import "git.haelnorr.com/h/oslstats/internal/view/layout"
import "git.haelnorr.com/h/oslstats/internal/view/component/form"
// Returns the login page
templ Register(username string) {
@layout.Global("Register") {
<div class="max-w-100 mx-auto px-2">
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
<div class="p-4 sm:p-7">
<div class="text-center">
<h1
class="block text-2xl font-bold"
>Set your display name</h1>
<p
class="mt-2 text-sm text-subtext0"
>
Select your display name. This must be unique, and cannot be changed.
</p>
</div>
<div class="mt-5">
@form.RegisterForm(username)
</div>
</div>
</div>
</div>
}
}

View File

@@ -1,8 +0,0 @@
package contexts
import (
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
)
var CurrentUser hwsauth.ContextLoader[*db.User]

View File

@@ -1,7 +1,7 @@
package contexts
type contextKey string
type Key string
func (c contextKey) String() string {
func (c Key) String() string {
return "oslstats context key " + string(c)
}

View File

@@ -1,19 +1,10 @@
@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap");
@import "tailwindcss";
@source "../../../../internal/view/component/footer/footer.templ";
@source "../../../../internal/view/component/nav/navbarleft.templ";
@source "../../../../internal/view/component/nav/navbarright.templ";
@source "../../../../internal/view/component/nav/navbar.templ";
@source "../../../../internal/view/component/nav/sidenav.templ";
@source "../../../../internal/view/component/popup/error500Popup.templ";
@source "../../../../internal/view/component/popup/error503Popup.templ";
@source "../../../../internal/view/layout/global.templ";
@source "../../../../internal/view/page/error.templ";
@source "../../../../internal/view/page/index.templ";
[x-cloak] {
display: none !important;
}
@theme inline {
--color-rosewater: var(--rosewater);
--color-flamingo: var(--flamingo);
@@ -43,6 +34,7 @@
--color-mantle: var(--mantle);
--color-crust: var(--crust);
}
:root {
--rosewater: hsl(11, 59%, 67%);
--flamingo: hsl(0, 60%, 67%);
@@ -102,6 +94,7 @@
--mantle: hsl(240, 21%, 12%);
--crust: hsl(240, 23%, 9%);
}
.ubuntu-mono-regular {
font-family: "Ubuntu Mono", serif;
font-weight: 400;

View File

@@ -1,4 +1,5 @@
/*! tailwindcss v4.1.18 | MIT License | https://tailwindcss.com */
@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap");
@layer properties;
@layer theme, base, components, utilities;
@layer theme {
@@ -10,7 +11,10 @@
--spacing: 0.25rem;
--breakpoint-xl: 80rem;
--container-md: 28rem;
--container-2xl: 42rem;
--container-7xl: 80rem;
--text-xs: 0.75rem;
--text-xs--line-height: calc(1 / 0.75);
--text-sm: 0.875rem;
--text-sm--line-height: calc(1.25 / 0.875);
--text-lg: 1.125rem;
@@ -28,11 +32,13 @@
--text-9xl: 8rem;
--text-9xl--line-height: 1;
--font-weight-medium: 500;
--font-weight-semibold: 600;
--font-weight-bold: 700;
--tracking-tight: -0.025em;
--leading-relaxed: 1.625;
--radius-sm: 0.25rem;
--radius-lg: 0.5rem;
--radius-xl: 0.75rem;
--default-transition-duration: 150ms;
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
--default-font-family: var(--font-sans);
@@ -208,6 +214,9 @@
.relative {
position: relative;
}
.static {
position: static;
}
.end-0 {
inset-inline-end: calc(var(--spacing) * 0);
}
@@ -244,9 +253,18 @@
.mt-4 {
margin-top: calc(var(--spacing) * 4);
}
.mt-5 {
margin-top: calc(var(--spacing) * 5);
}
.mt-6 {
margin-top: calc(var(--spacing) * 6);
}
.mt-7 {
margin-top: calc(var(--spacing) * 7);
}
.mt-8 {
margin-top: calc(var(--spacing) * 8);
}
.mt-10 {
margin-top: calc(var(--spacing) * 10);
}
@@ -265,6 +283,9 @@
.mb-auto {
margin-bottom: auto;
}
.ml-2 {
margin-left: calc(var(--spacing) * 2);
}
.ml-auto {
margin-left: auto;
}
@@ -321,9 +342,15 @@
.w-full {
width: 100%;
}
.max-w-2xl {
max-width: var(--container-2xl);
}
.max-w-7xl {
max-width: var(--container-7xl);
}
.max-w-100 {
max-width: calc(var(--spacing) * 100);
}
.max-w-md {
max-width: var(--container-md);
}
@@ -344,6 +371,9 @@
.transform {
transform: var(--tw-rotate-x,) var(--tw-rotate-y,) var(--tw-rotate-z,) var(--tw-skew-x,) var(--tw-skew-y,);
}
.cursor-pointer {
cursor: pointer;
}
.flex-col {
flex-direction: column;
}
@@ -381,6 +411,12 @@
margin-block-end: calc(calc(var(--spacing) * 1) * calc(1 - var(--tw-space-y-reverse)));
}
}
.gap-x-2 {
column-gap: calc(var(--spacing) * 2);
}
.gap-y-4 {
row-gap: calc(var(--spacing) * 4);
}
.divide-y {
:where(& > :not(:last-child)) {
--tw-divide-y-reverse: 0;
@@ -398,9 +434,15 @@
.overflow-hidden {
overflow: hidden;
}
.overflow-x-auto {
overflow-x: auto;
}
.overflow-x-hidden {
overflow-x: hidden;
}
.rounded {
border-radius: 0.25rem;
}
.rounded-full {
border-radius: calc(infinity * 1px);
}
@@ -410,19 +452,35 @@
.rounded-sm {
border-radius: var(--radius-sm);
}
.rounded-xl {
border-radius: var(--radius-xl);
}
.border {
border-style: var(--tw-border-style);
border-width: 1px;
}
.border-2 {
border-style: var(--tw-border-style);
border-width: 2px;
}
.border-green {
border-color: var(--green);
}
.border-overlay0 {
border-color: var(--overlay0);
}
.border-red {
border-color: var(--red);
}
.border-surface1 {
border-color: var(--surface1);
}
.border-transparent {
border-color: transparent;
}
.bg-base {
background-color: var(--base);
}
.bg-blue {
background-color: var(--blue);
}
.bg-crust {
background-color: var(--crust);
}
@@ -456,12 +514,21 @@
.p-4 {
padding: calc(var(--spacing) * 4);
}
.px-2 {
padding-inline: calc(var(--spacing) * 2);
}
.px-3 {
padding-inline: calc(var(--spacing) * 3);
}
.px-4 {
padding-inline: calc(var(--spacing) * 4);
}
.px-5 {
padding-inline: calc(var(--spacing) * 5);
}
.py-1 {
padding-block: calc(var(--spacing) * 1);
}
.py-2 {
padding-block: calc(var(--spacing) * 2);
}
@@ -480,6 +547,15 @@
.text-center {
text-align: center;
}
.text-left {
text-align: left;
}
.text-right {
text-align: right;
}
.font-mono {
font-family: var(--font-mono);
}
.text-2xl {
font-size: var(--text-2xl);
line-height: var(--tw-leading, var(--text-2xl--line-height));
@@ -508,6 +584,10 @@
font-size: var(--text-xl);
line-height: var(--tw-leading, var(--text-xl--line-height));
}
.text-xs {
font-size: var(--text-xs);
line-height: var(--tw-leading, var(--text-xs--line-height));
}
.leading-relaxed {
--tw-leading: var(--leading-relaxed);
line-height: var(--leading-relaxed);
@@ -520,10 +600,20 @@
--tw-font-weight: var(--font-weight-medium);
font-weight: var(--font-weight-medium);
}
.font-semibold {
--tw-font-weight: var(--font-weight-semibold);
font-weight: var(--font-weight-semibold);
}
.tracking-tight {
--tw-tracking: var(--tracking-tight);
letter-spacing: var(--tracking-tight);
}
.break-all {
word-break: break-all;
}
.whitespace-pre-wrap {
white-space: pre-wrap;
}
.text-crust {
color: var(--crust);
}
@@ -568,6 +658,14 @@
--tw-duration: 200ms;
transition-duration: 200ms;
}
.outline-none {
--tw-outline-style: none;
outline-style: none;
}
.select-none {
-webkit-user-select: none;
user-select: none;
}
.hover\:cursor-pointer {
&:hover {
@media (hover: hover) {
@@ -575,16 +673,6 @@
}
}
}
.hover\:bg-blue\/75 {
&:hover {
@media (hover: hover) {
background-color: var(--blue);
@supports (color: color-mix(in lab, red, red)) {
background-color: color-mix(in oklab, var(--blue) 75%, transparent);
}
}
}
}
.hover\:bg-crust {
&:hover {
@media (hover: hover) {
@@ -673,6 +761,51 @@
}
}
}
.hover\:text-text {
&:hover {
@media (hover: hover) {
color: var(--text);
}
}
}
.focus\:border-blue {
&:focus {
border-color: var(--blue);
}
}
.focus\:border-green {
&:focus {
border-color: var(--green);
}
}
.focus\:border-red {
&:focus {
border-color: var(--red);
}
}
.disabled\:pointer-events-none {
&:disabled {
pointer-events: none;
}
}
.disabled\:cursor-default {
&:disabled {
cursor: default;
}
}
.disabled\:bg-green\/60 {
&:disabled {
background-color: var(--green);
@supports (color: color-mix(in lab, red, red)) {
background-color: color-mix(in oklab, var(--green) 60%, transparent);
}
}
}
.disabled\:opacity-50 {
&:disabled {
opacity: 50%;
}
}
.sm\:end-6 {
@media (width >= 40rem) {
inset-inline-end: calc(var(--spacing) * 6);
@@ -693,11 +826,6 @@
display: none;
}
}
.sm\:inline {
@media (width >= 40rem) {
display: inline;
}
}
.sm\:justify-between {
@media (width >= 40rem) {
justify-content: space-between;
@@ -708,6 +836,11 @@
gap: calc(var(--spacing) * 2);
}
}
.sm\:p-7 {
@media (width >= 40rem) {
padding: calc(var(--spacing) * 7);
}
}
.sm\:px-6 {
@media (width >= 40rem) {
padding-inline: calc(var(--spacing) * 6);

View File

@@ -1,3 +1,5 @@
// This function prevents the 'flash of unstyled content'
// Include it at the top of <head>
(function() {
let theme = localStorage.getItem("theme") || "system";
if (theme === "system") {

23
pkg/oauth/config.go Normal file
View File

@@ -0,0 +1,23 @@
package oauth
import (
"git.haelnorr.com/h/golib/env"
"github.com/pkg/errors"
)
type Config struct {
PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""),
}
// Check required fields
if cfg.PrivateKey == "" {
return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY")
}
return cfg, nil
}

45
pkg/oauth/cookies.go Normal file
View File

@@ -0,0 +1,45 @@
package oauth
import (
"encoding/base64"
"net/http"
"github.com/pkg/errors"
)
func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) {
encodedUak := base64.RawURLEncoding.EncodeToString(uak)
http.SetCookie(w, &http.Cookie{
Name: "oauth_uak",
Value: encodedUak,
Path: "/",
MaxAge: 300,
HttpOnly: true,
Secure: ssl,
SameSite: http.SameSiteLaxMode,
})
}
func GetStateCookie(r *http.Request) ([]byte, error) {
if r == nil {
return nil, errors.New("Request cannot be nil")
}
cookie, err := r.Cookie("oauth_uak")
if err != nil {
return nil, err
}
uak, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie")
}
return uak, nil
}
func DeleteStateCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: "oauth_uak",
Value: "",
Path: "/",
MaxAge: -1,
})
}

41
pkg/oauth/ezconf.go Normal file
View File

@@ -0,0 +1,41 @@
package oauth
import (
"runtime"
"strings"
)
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct {
configFunc func() (any, error)
name string
}
// PackagePath returns the path to the config package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return e.configFunc()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return strings.ToLower(e.name)
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return e.name
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv}
}

117
pkg/oauth/state.go Normal file
View File

@@ -0,0 +1,117 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"slices"
"strings"
"github.com/pkg/errors"
)
// STATE FLOW:
// data provided at call time to be retrieved later
// random value generated on the spot
// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client
// privateKey - from config
func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) {
// signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey)
// state = data + "." + random + "." + signature
if cfg == nil {
return "", nil, errors.New("cfg cannot be nil")
}
if cfg.PrivateKey == "" {
return "", nil, errors.New("private key cannot be empty")
}
if data == "" {
return "", nil, errors.New("data cannot be empty")
}
// Generate 32 random bytes for random component
randomBytes := make([]byte, 32)
_, err = rand.Read(randomBytes)
if err != nil {
return "", nil, errors.Wrap(err, "failed to generate random bytes")
}
// Generate 32 random bytes for userAgentKey
userAgentKey = make([]byte, 32)
_, err = rand.Read(userAgentKey)
if err != nil {
return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes")
}
// Encode random and userAgentKey to base64
randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes)
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
// Create payload for signing: data + "." + random + userAgentKey + privateKey
// Note: userAgentKey is concatenated directly with privateKey (no separator)
payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey
// Generate signature
hash := sha256.Sum256([]byte(payload))
signature := base64.RawURLEncoding.EncodeToString(hash[:])
// Construct state: data + "." + random + "." + signature
state = data + "." + randomEncoded + "." + signature
return state, userAgentKey, nil
}
func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) {
// Validate inputs
if cfg == nil {
return "", errors.New("cfg cannot be nil")
}
if cfg.PrivateKey == "" {
return "", errors.New("private key cannot be empty")
}
if state == "" {
return "", errors.New("state cannot be empty")
}
if len(userAgentKey) == 0 {
return "", errors.New("userAgentKey cannot be empty")
}
// Split state into parts
parts := strings.Split(state, ".")
if len(parts) != 3 {
return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts))
}
// Check for empty parts
if slices.Contains(parts, "") {
return "", errors.New("state parts cannot be empty")
}
data = parts[0]
random := parts[1]
receivedSignature := parts[2]
// Encode userAgentKey to base64 for payload reconstruction
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
// Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey
payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey
// Generate expected hash
hash := sha256.Sum256([]byte(payload))
// Decode received signature to bytes
receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature)
if err != nil {
return "", errors.Wrap(err, "failed to decode received signature")
}
// Compare hash bytes directly with decoded signature using constant-time comparison
// This is more efficient than encoding hash and then decoding both for comparison
if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 {
return data, nil
}
return "", errors.New("invalid state signature")
}

817
pkg/oauth/state_test.go Normal file
View File

@@ -0,0 +1,817 @@
package oauth
import (
"crypto/sha256"
"encoding/base64"
"strings"
"testing"
)
// Helper function to create a test config
func testConfig() *Config {
return &Config{
PrivateKey: "test_private_key_for_testing_12345",
}
}
// TestGenerateState_Success tests the happy path of state generation
func TestGenerateState_Success(t *testing.T) {
cfg := testConfig()
data := "test_data_payload"
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if state == "" {
t.Error("Expected non-empty state")
}
if len(userAgentKey) != 32 {
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
}
// Verify state format: data.random.signature
parts := strings.Split(state, ".")
if len(parts) != 3 {
t.Errorf("Expected state to have 3 parts, got %d", len(parts))
}
// Verify data is preserved
if parts[0] != data {
t.Errorf("Expected data to be '%s', got '%s'", data, parts[0])
}
// Verify random part is base64 encoded
randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Errorf("Expected random part to be valid base64: %v", err)
}
if len(randomBytes) != 32 {
t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes))
}
// Verify signature part is base64 encoded
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
t.Errorf("Expected signature part to be valid base64: %v", err)
}
if len(sigBytes) != 32 {
t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes))
}
}
// TestGenerateState_NilConfig tests that nil config returns error
func TestGenerateState_NilConfig(t *testing.T) {
_, _, err := GenerateState(nil, "test_data")
if err == nil {
t.Fatal("Expected error for nil config, got nil")
}
if !strings.Contains(err.Error(), "cfg cannot be nil") {
t.Errorf("Expected error message about nil config, got: %v", err)
}
}
// TestGenerateState_EmptyPrivateKey tests that empty private key returns error
func TestGenerateState_EmptyPrivateKey(t *testing.T) {
cfg := &Config{PrivateKey: ""}
_, _, err := GenerateState(cfg, "test_data")
if err == nil {
t.Fatal("Expected error for empty private key, got nil")
}
if !strings.Contains(err.Error(), "private key cannot be empty") {
t.Errorf("Expected error message about empty private key, got: %v", err)
}
}
// TestGenerateState_EmptyData tests that empty data returns error
func TestGenerateState_EmptyData(t *testing.T) {
cfg := testConfig()
_, _, err := GenerateState(cfg, "")
if err == nil {
t.Fatal("Expected error for empty data, got nil")
}
if !strings.Contains(err.Error(), "data cannot be empty") {
t.Errorf("Expected error message about empty data, got: %v", err)
}
}
// TestGenerateState_Randomness tests that multiple calls generate different states
func TestGenerateState_Randomness(t *testing.T) {
cfg := testConfig()
data := "same_data"
state1, _, err1 := GenerateState(cfg, data)
state2, _, err2 := GenerateState(cfg, data)
if err1 != nil || err2 != nil {
t.Fatalf("Unexpected errors: %v, %v", err1, err2)
}
if state1 == state2 {
t.Error("Expected different states for multiple calls, got identical states")
}
}
// TestGenerateState_DifferentData tests states with different data payloads
func TestGenerateState_DifferentData(t *testing.T) {
cfg := testConfig()
testCases := []string{
"simple",
"with-dashes",
"with_underscores",
"123456789",
"MixedCase123",
}
for _, data := range testCases {
t.Run(data, func(t *testing.T) {
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Unexpected error for data '%s': %v", data, err)
}
if !strings.HasPrefix(state, data+".") {
t.Errorf("Expected state to start with '%s.', got: %s", data, state)
}
if len(userAgentKey) != 32 {
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
}
})
}
}
// TestVerifyState_Success tests the happy path of state verification
func TestVerifyState_Success(t *testing.T) {
cfg := testConfig()
data := "test_data"
// Generate state
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify state
extractedData, err := VerifyState(cfg, state, userAgentKey)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if extractedData != data {
t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData)
}
}
// TestVerifyState_NilConfig tests that nil config returns error
func TestVerifyState_NilConfig(t *testing.T) {
_, err := VerifyState(nil, "state", []byte("key"))
if err == nil {
t.Fatal("Expected error for nil config, got nil")
}
if !strings.Contains(err.Error(), "cfg cannot be nil") {
t.Errorf("Expected error message about nil config, got: %v", err)
}
}
// TestVerifyState_EmptyPrivateKey tests that empty private key returns error
func TestVerifyState_EmptyPrivateKey(t *testing.T) {
cfg := &Config{PrivateKey: ""}
_, err := VerifyState(cfg, "state", []byte("key"))
if err == nil {
t.Fatal("Expected error for empty private key, got nil")
}
if !strings.Contains(err.Error(), "private key cannot be empty") {
t.Errorf("Expected error message about empty private key, got: %v", err)
}
}
// TestVerifyState_EmptyState tests that empty state returns error
func TestVerifyState_EmptyState(t *testing.T) {
cfg := testConfig()
_, err := VerifyState(cfg, "", []byte("key"))
if err == nil {
t.Fatal("Expected error for empty state, got nil")
}
if !strings.Contains(err.Error(), "state cannot be empty") {
t.Errorf("Expected error message about empty state, got: %v", err)
}
}
// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error
func TestVerifyState_EmptyUserAgentKey(t *testing.T) {
cfg := testConfig()
_, err := VerifyState(cfg, "data.random.signature", []byte{})
if err == nil {
t.Fatal("Expected error for empty userAgentKey, got nil")
}
if !strings.Contains(err.Error(), "userAgentKey cannot be empty") {
t.Errorf("Expected error message about empty userAgentKey, got: %v", err)
}
}
// TestVerifyState_WrongUserAgentKey tests MITM protection
func TestVerifyState_WrongUserAgentKey(t *testing.T) {
cfg := testConfig()
// Generate first state
state, _, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Generate a different userAgentKey
_, wrongKey, err := GenerateState(cfg, "other_data")
if err != nil {
t.Fatalf("Failed to generate second state: %v", err)
}
// Try to verify with wrong key
_, err = VerifyState(cfg, state, wrongKey)
if err == nil {
t.Error("Expected error for invalid signature")
}
if !strings.Contains(err.Error(), "invalid state signature") {
t.Errorf("Expected error about invalid signature, got: %v", err)
}
}
// TestVerifyState_TamperedData tests tampering detection
func TestVerifyState_TamperedData(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "original_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the data portion
parts := strings.Split(state, ".")
parts[0] = "tampered_data"
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered state")
}
}
// TestVerifyState_TamperedRandom tests tampering with random portion
func TestVerifyState_TamperedRandom(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the random portion
parts := strings.Split(state, ".")
parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12"))
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered state")
}
}
// TestVerifyState_TamperedSignature tests tampering with signature
func TestVerifyState_TamperedSignature(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the signature portion
parts := strings.Split(state, ".")
// Create a different valid base64 string
parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered")))
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered signature")
}
}
// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts
func TestVerifyState_MalformedState_TwoParts(t *testing.T) {
cfg := testConfig()
malformedState := "data.random"
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for malformed state")
}
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
}
}
// TestVerifyState_MalformedState_FourParts tests state with 4 parts
func TestVerifyState_MalformedState_FourParts(t *testing.T) {
cfg := testConfig()
malformedState := "data.random.signature.extra"
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for malformed state")
}
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
}
}
// TestVerifyState_EmptyStateParts tests state with empty parts
func TestVerifyState_EmptyStateParts(t *testing.T) {
cfg := testConfig()
testCases := []struct {
name string
state string
}{
{"empty data", ".random.signature"},
{"empty random", "data..signature"},
{"empty signature", "data.random."},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for state with empty parts")
}
if !strings.Contains(err.Error(), "state parts cannot be empty") {
t.Errorf("Expected error about empty parts, got: %v", err)
}
})
}
}
// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature
func TestVerifyState_InvalidBase64Signature(t *testing.T) {
cfg := testConfig()
invalidState := "data.random.invalid@base64!"
_, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for invalid base64 signature")
}
if !strings.Contains(err.Error(), "failed to decode") {
t.Errorf("Expected error about decoding signature, got: %v", err)
}
}
// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification
func TestVerifyState_DifferentPrivateKey(t *testing.T) {
cfg1 := &Config{PrivateKey: "private_key_1"}
cfg2 := &Config{PrivateKey: "private_key_2"}
// Generate with first config
state, userAgentKey, err := GenerateState(cfg1, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Try to verify with second config
_, err = VerifyState(cfg2, state, userAgentKey)
if err == nil {
t.Error("Expected error for mismatched private key")
}
}
// TestRoundTrip tests complete round trip with various data payloads
func TestRoundTrip(t *testing.T) {
cfg := testConfig()
testCases := []string{
"simple",
"with-dashes-and-numbers-123",
"MixedCaseData",
"user_token_abc123",
"link_resource_xyz789",
}
for _, data := range testCases {
t.Run(data, func(t *testing.T) {
// Generate
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify
extractedData, err := VerifyState(cfg, state, userAgentKey)
if err != nil {
t.Fatalf("Failed to verify state: %v", err)
}
if extractedData != data {
t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData)
}
})
}
}
// TestConcurrentGeneration tests that concurrent state generation works correctly
func TestConcurrentGeneration(t *testing.T) {
cfg := testConfig()
data := "concurrent_test"
const numGoroutines = 10
results := make(chan string, numGoroutines)
errors := make(chan error, numGoroutines)
// Generate states concurrently
for range numGoroutines {
go func() {
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
errors <- err
return
}
// Verify immediately
_, verifyErr := VerifyState(cfg, state, userAgentKey)
if verifyErr != nil {
errors <- verifyErr
return
}
results <- state
}()
}
// Collect results
states := make(map[string]bool)
for range numGoroutines {
select {
case state := <-results:
if states[state] {
t.Errorf("Duplicate state generated: %s", state)
}
states[state] = true
case err := <-errors:
t.Errorf("Concurrent generation error: %v", err)
}
}
if len(states) != numGoroutines {
t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states))
}
}
// TestStateFormatCompatibility ensures state is URL-safe
func TestStateFormatCompatibility(t *testing.T) {
cfg := testConfig()
data := "url_safe_test"
state, _, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Check that state doesn't contain characters that need URL encoding
unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"}
for _, char := range unsafeChars {
if strings.Contains(state, char) {
t.Errorf("State contains URL-unsafe character '%s': %s", char, state)
}
}
}
// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works
// An attacker obtains their own valid state but tries to use it with victim's session
func TestMITM_AttackerCannotSubstituteState(t *testing.T) {
cfg := testConfig()
// Victim generates a state for their login
victimState, victimKey, err := GenerateState(cfg, "victim_data")
if err != nil {
t.Fatalf("Failed to generate victim state: %v", err)
}
// Attacker generates their own valid state (they can request this from the server)
attackerState, attackerKey, err := GenerateState(cfg, "attacker_data")
if err != nil {
t.Fatalf("Failed to generate attacker state: %v", err)
}
// Both states should be valid on their own
_, err = VerifyState(cfg, victimState, victimKey)
if err != nil {
t.Fatalf("Victim state should be valid: err=%v", err)
}
_, err = VerifyState(cfg, attackerState, attackerKey)
if err != nil {
t.Fatalf("Attacker state should be valid: err=%v", err)
}
// MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie
// This should FAIL because attackerState was signed with attackerKey, not victimKey
_, err = VerifyState(cfg, attackerState, victimKey)
if err == nil {
t.Error("Expected error when attacker substitutes state")
}
// MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie
// This should also FAIL
_, err = VerifyState(cfg, victimState, attackerKey)
if err == nil {
t.Error("Expected error when attacker uses victim's state")
}
// The key insight: even though both states are "valid", they are bound to their respective cookies
// An attacker cannot mix and match states and cookies
t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies")
}
// TestCSRF_AttackerCannotForgeState verifies CSRF protection
// An attacker tries to forge a state parameter without knowing the private key
func TestCSRF_AttackerCannotForgeState(t *testing.T) {
cfg := testConfig()
// Attacker doesn't know the private key, but tries to forge a state
// They might try to construct: "malicious_data.random.signature"
// Attempt 1: Use a random signature
randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678"))
forgedState1 := "malicious_data.somefakerandom." + randomSig
// Generate a real userAgentKey (attacker might try to get this)
_, realKey, err := GenerateState(cfg, "legitimate_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Try to verify forged state
_, err = VerifyState(cfg, forgedState1, realKey)
if err == nil {
t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!")
}
// Attempt 2: Attacker tries to compute signature without private key
// They use: SHA256(data + "." + random + userAgentKey) - missing privateKey
attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey)
hash := sha256.Sum256([]byte(attackerPayload))
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
forgedState2 := "malicious_data.fakerandom." + attackerSig
_, err = VerifyState(cfg, forgedState2, realKey)
if err == nil {
t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!")
}
t.Log("✓ CSRF protection verified: Cannot forge state without private key")
}
// TestTampering_SignatureDetectsAllModifications verifies tamper detection
func TestTampering_SignatureDetectsAllModifications(t *testing.T) {
cfg := testConfig()
// Generate a valid state
originalState, userAgentKey, err := GenerateState(cfg, "original_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify original is valid
data, err := VerifyState(cfg, originalState, userAgentKey)
if err != nil || data != "original_data" {
t.Fatalf("Original state should be valid")
}
parts := strings.Split(originalState, ".")
// Test 1: Attacker modifies data but keeps signature
tamperedState := "modified_data." + parts[1] + "." + parts[2]
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Modified data not detected!")
}
// Test 2: Attacker modifies random but keeps signature
newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!"))
tamperedState = parts[0] + "." + newRandom + "." + parts[2]
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Modified random not detected!")
}
// Test 3: Attacker tries to recompute signature but doesn't have privateKey
// They compute: SHA256(modified_data + "." + random + userAgentKey)
attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey)
hash := sha256.Sum256([]byte(attackerPayload))
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
tamperedState = "modified_data." + parts[1] + "." + attackerSig
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!")
}
// Test 4: Single bit flip in signature
sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2])
sigBytes[0] ^= 0x01 // Flip one bit
flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes)
tamperedState = parts[0] + "." + parts[1] + "." + flippedSig
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!")
}
t.Log("✓ Tamper detection verified: All modifications to state are detected")
}
// TestReplay_DifferentSessionsCannotReuseState verifies replay protection
func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) {
cfg := testConfig()
// Session 1: User initiates OAuth flow
state1, key1, err := GenerateState(cfg, "session1_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// State is valid for session 1
_, err = VerifyState(cfg, state1, key1)
if err != nil {
t.Fatalf("State should be valid for session 1")
}
// Session 2: Same user (or attacker) initiates a new OAuth flow
state2, key2, err := GenerateState(cfg, "session1_data") // same data
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Replay Attack: Try to use state1 with key2
_, err = VerifyState(cfg, state1, key2)
if err == nil {
t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!")
}
// Even with same data, each session should have unique state+key binding
if state1 == state2 {
t.Error("REPLAY VULNERABILITY: Same data produces identical states!")
}
t.Log("✓ Replay protection verified: States are bound to specific session cookies")
}
// TestConstantTimeComparison verifies that signature comparison is timing-safe
// This is a behavioral test - we can't easily test timing, but we can verify the function is used
func TestConstantTimeComparison_IsUsed(t *testing.T) {
cfg := testConfig()
// Generate valid state
state, userAgentKey, err := GenerateState(cfg, "test")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Create states with signatures that differ at different positions
parts := strings.Split(state, ".")
originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2])
testCases := []struct {
name string
position int
}{
{"first byte differs", 0},
{"middle byte differs", 16},
{"last byte differs", 31},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create signature that differs at specific position
tamperedSig := make([]byte, len(originalSig))
copy(tamperedSig, originalSig)
tamperedSig[tc.position] ^= 0xFF // Flip all bits
tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig)
tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr
// All should fail verification
_, err := VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Errorf("Tampered signature at position %d should be invalid", tc.position)
}
// If constant-time comparison is NOT used, early differences would return faster
// While we can't easily test timing here, we verify all positions fail equally
})
}
t.Log("✓ Constant-time comparison: All signature positions validated equally")
t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation")
}
// TestPrivateKey_IsCriticalToSecurity verifies private key is essential
func TestPrivateKey_IsCriticalToSecurity(t *testing.T) {
cfg1 := &Config{PrivateKey: "secret_key_1"}
cfg2 := &Config{PrivateKey: "secret_key_2"}
// Generate state with key1
state, userAgentKey, err := GenerateState(cfg1, "data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Should verify with key1
_, err = VerifyState(cfg1, state, userAgentKey)
if err != nil {
t.Fatalf("State should be valid with correct private key")
}
// Should NOT verify with key2 (different private key)
_, err = VerifyState(cfg2, state, userAgentKey)
if err == nil {
t.Error("SECURITY VULNERABILITY: State verified with different private key!")
}
// This proves that the private key is cryptographically involved in the signature
t.Log("✓ Private key security verified: Different keys produce incompatible signatures")
}
// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload
func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) {
cfg := testConfig()
// Generate two states with same data but different userAgentKeys (implicit)
state1, key1, err := GenerateState(cfg, "same_data")
if err != nil {
t.Fatalf("Failed to generate state1: %v", err)
}
state2, key2, err := GenerateState(cfg, "same_data")
if err != nil {
t.Fatalf("Failed to generate state2: %v", err)
}
// The states should be different even with same data (different random and keys)
if state1 == state2 {
t.Error("States should differ due to different random values")
}
// Each state should only verify with its own key
_, err1 := VerifyState(cfg, state1, key1)
_, err2 := VerifyState(cfg, state2, key2)
if err1 != nil || err2 != nil {
t.Fatal("States should be valid with their own keys")
}
// Cross-verification should fail
_, err1 = VerifyState(cfg, state1, key2)
_, err2 = VerifyState(cfg, state2, key1)
if err1 == nil || err2 == nil {
t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!")
}
t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key")
}

View File

@@ -1,32 +0,0 @@
# Scripts
## generate-css-sources.sh
Automatically generates the `pkg/embedfs/files/css/input.css` file with `@source` directives for all `.templ` files in the project.
### Why is this needed?
Tailwind CSS v4 requires explicit `@source` directives to know which files to scan for utility classes. Glob patterns like `**/*.templ` don't work in `@source` directives, so each file must be listed individually.
This script:
1. Finds all `.templ` files in the `internal/` directory
2. Generates `@source` directives with relative paths from the CSS file location
3. Adds your custom theme and utility classes
### When does it run?
The script runs automatically as part of:
- `make build` - Before building the CSS
- `make dev` - Before starting watch mode
### Manual usage
If you need to regenerate the sources manually:
```bash
./scripts/generate-css-sources.sh
```
### Adding new template files
When you add a new `.templ` file, you don't need to do anything special - just run `make build` or `make dev` and the script will automatically pick up the new file.

View File

@@ -1,140 +0,0 @@
#!/bin/bash
# Generate @source directives for all .templ files
# Paths are relative to pkg/embedfs/files/css/input.css
INPUT_CSS="pkg/embedfs/files/css/input.css"
# Start with the base imports
cat > "$INPUT_CSS" <<'CSSHEAD'
@import "tailwindcss";
CSSHEAD
# Find all .templ files and add @source directives
find internal -name "*.templ" -type f | sort | while read -r file; do
# Convert to relative path from pkg/embedfs/files/css/
rel_path="../../../../$file"
echo "@source \"$rel_path\";" >> "$INPUT_CSS"
done
# Add the custom theme and utility classes
cat >> "$INPUT_CSS" <<'CSSBODY'
[x-cloak] {
display: none !important;
}
@theme inline {
--color-rosewater: var(--rosewater);
--color-flamingo: var(--flamingo);
--color-pink: var(--pink);
--color-mauve: var(--mauve);
--color-red: var(--red);
--color-dark-red: var(--dark-red);
--color-maroon: var(--maroon);
--color-peach: var(--peach);
--color-yellow: var(--yellow);
--color-green: var(--green);
--color-teal: var(--teal);
--color-sky: var(--sky);
--color-sapphire: var(--sapphire);
--color-blue: var(--blue);
--color-lavender: var(--lavender);
--color-text: var(--text);
--color-subtext1: var(--subtext1);
--color-subtext0: var(--subtext0);
--color-overlay2: var(--overlay2);
--color-overlay1: var(--overlay1);
--color-overlay0: var(--overlay0);
--color-surface2: var(--surface2);
--color-surface1: var(--surface1);
--color-surface0: var(--surface0);
--color-base: var(--base);
--color-mantle: var(--mantle);
--color-crust: var(--crust);
}
:root {
--rosewater: hsl(11, 59%, 67%);
--flamingo: hsl(0, 60%, 67%);
--pink: hsl(316, 73%, 69%);
--mauve: hsl(266, 85%, 58%);
--red: hsl(347, 87%, 44%);
--dark-red: hsl(343, 50%, 82%);
--maroon: hsl(355, 76%, 59%);
--peach: hsl(22, 99%, 52%);
--yellow: hsl(35, 77%, 49%);
--green: hsl(109, 58%, 40%);
--teal: hsl(183, 74%, 35%);
--sky: hsl(197, 97%, 46%);
--sapphire: hsl(189, 70%, 42%);
--blue: hsl(220, 91%, 54%);
--lavender: hsl(231, 97%, 72%);
--text: hsl(234, 16%, 35%);
--subtext1: hsl(233, 13%, 41%);
--subtext0: hsl(233, 10%, 47%);
--overlay2: hsl(232, 10%, 53%);
--overlay1: hsl(231, 10%, 59%);
--overlay0: hsl(228, 11%, 65%);
--surface2: hsl(227, 12%, 71%);
--surface1: hsl(225, 14%, 77%);
--surface0: hsl(223, 16%, 83%);
--base: hsl(220, 23%, 95%);
--mantle: hsl(220, 22%, 92%);
--crust: hsl(220, 21%, 89%);
}
.dark {
--rosewater: hsl(10, 56%, 91%);
--flamingo: hsl(0, 59%, 88%);
--pink: hsl(316, 72%, 86%);
--mauve: hsl(267, 84%, 81%);
--red: hsl(343, 81%, 75%);
--dark-red: hsl(316, 19%, 27%);
--maroon: hsl(350, 65%, 77%);
--peach: hsl(23, 92%, 75%);
--yellow: hsl(41, 86%, 83%);
--green: hsl(115, 54%, 76%);
--teal: hsl(170, 57%, 73%);
--sky: hsl(189, 71%, 73%);
--sapphire: hsl(199, 76%, 69%);
--blue: hsl(217, 92%, 76%);
--lavender: hsl(232, 97%, 85%);
--text: hsl(226, 64%, 88%);
--subtext1: hsl(227, 35%, 80%);
--subtext0: hsl(228, 24%, 72%);
--overlay2: hsl(228, 17%, 64%);
--overlay1: hsl(230, 13%, 55%);
--overlay0: hsl(231, 11%, 47%);
--surface2: hsl(233, 12%, 39%);
--surface1: hsl(234, 13%, 31%);
--surface0: hsl(237, 16%, 23%);
--base: hsl(240, 21%, 15%);
--mantle: hsl(240, 21%, 12%);
--crust: hsl(240, 23%, 9%);
}
.ubuntu-mono-regular {
font-family: "Ubuntu Mono", serif;
font-weight: 400;
font-style: normal;
}
.ubuntu-mono-bold {
font-family: "Ubuntu Mono", serif;
font-weight: 700;
font-style: normal;
}
.ubuntu-mono-regular-italic {
font-family: "Ubuntu Mono", serif;
font-weight: 400;
font-style: italic;
}
.ubuntu-mono-bold-italic {
font-family: "Ubuntu Mono", serif;
font-weight: 700;
font-style: italic;
}
CSSBODY
echo "Generated $INPUT_CSS with $(grep -c '@source' "$INPUT_CSS") source files"