Merge branch 'discord-oauth'
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
*.db*
|
||||
.logs/
|
||||
server.log
|
||||
keys/
|
||||
bin/
|
||||
tmp/
|
||||
static/css/output.css
|
||||
|
||||
327
AGENTS.md
Normal file
327
AGENTS.md
Normal file
@@ -0,0 +1,327 @@
|
||||
# AGENTS.md - Developer Guide for oslstats
|
||||
|
||||
This document provides guidelines for AI coding agents and developers working on the oslstats codebase.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Module**: `git.haelnorr.com/h/oslstats`
|
||||
**Language**: Go 1.25.5
|
||||
**Architecture**: Web application with Discord OAuth, PostgreSQL database, templ templates
|
||||
**Key Technologies**: Bun ORM, templ, TailwindCSS, custom golib libraries
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
|
||||
### Building
|
||||
```bash
|
||||
# Full production build (tailwind → templ → go generate → go build)
|
||||
make build
|
||||
|
||||
# Build and run
|
||||
make run
|
||||
|
||||
# Clean build artifacts
|
||||
make clean
|
||||
```
|
||||
|
||||
### Development Mode
|
||||
```bash
|
||||
# Watch mode with hot reload (templ, air, tailwindcss in parallel)
|
||||
make dev
|
||||
|
||||
# Development server runs on:
|
||||
# - Proxy: http://localhost:3000 (use this)
|
||||
# - App: http://localhost:3333 (internal)
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run tests for a specific package
|
||||
go test ./pkg/oauth
|
||||
|
||||
# Run a single test function
|
||||
go test ./pkg/oauth -run TestGenerateState_Success
|
||||
|
||||
# Run tests with verbose output
|
||||
go test -v ./pkg/oauth
|
||||
|
||||
# Run tests with coverage
|
||||
go test -cover ./...
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Database
|
||||
```bash
|
||||
# Run migrations
|
||||
make migrate
|
||||
# OR
|
||||
./bin/oslstats --migrate
|
||||
```
|
||||
|
||||
### Configuration Management
|
||||
```bash
|
||||
# Generate .env template file
|
||||
make genenv
|
||||
# OR with custom output: make genenv OUT=.env.example
|
||||
|
||||
# Show environment variable documentation
|
||||
make envdoc
|
||||
|
||||
# Show current environment values
|
||||
make showenv
|
||||
```
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
### Import Organization
|
||||
Organize imports in **3 groups** separated by blank lines:
|
||||
|
||||
```go
|
||||
import (
|
||||
// 1. Standard library
|
||||
"context"
|
||||
"net/http"
|
||||
"fmt"
|
||||
|
||||
// 2. External dependencies
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
// 3. Internal packages
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
```
|
||||
|
||||
### Naming Conventions
|
||||
|
||||
**Variables**:
|
||||
- Local: `camelCase` (userAgentKey, httpServer, dbConn)
|
||||
- Exported: `PascalCase` (Config, User, Token)
|
||||
- Common abbreviations: `cfg`, `ctx`, `tx`, `db`, `err`, `w`, `r`
|
||||
|
||||
**Functions**:
|
||||
- Exported: `PascalCase` (GetConfig, NewStore, GenerateState)
|
||||
- Private: `camelCase` (throwError, shouldShowDetails, loadModels)
|
||||
- HTTP handlers: Return `http.Handler`, use dependency injection pattern
|
||||
- Database functions: Use `bun.Tx` as parameter for transactions
|
||||
|
||||
**Types**:
|
||||
- Structs/Interfaces: `PascalCase` (Config, User, OAuthSession)
|
||||
- Use `-er` suffix for interfaces (implied from usage)
|
||||
|
||||
**Files**:
|
||||
- Prefer single word: `config.go`, `oauth.go`, `errors.go`
|
||||
- Use snake_case if needed: `discord_tokens.go`, `state_test.go`
|
||||
- Test files: `*_test.go` alongside source files
|
||||
|
||||
### Error Handling
|
||||
|
||||
**Always wrap errors** with context using `github.com/pkg/errors`:
|
||||
|
||||
```go
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "operation_name")
|
||||
}
|
||||
```
|
||||
|
||||
**Validate inputs at function start**:
|
||||
```go
|
||||
func DoSomething(cfg *Config, data string) error {
|
||||
if cfg == nil {
|
||||
return errors.New("cfg cannot be nil")
|
||||
}
|
||||
if data == "" {
|
||||
return errors.New("data cannot be empty")
|
||||
}
|
||||
// ... rest of function
|
||||
}
|
||||
```
|
||||
|
||||
**HTTP error helpers** (in handlers package):
|
||||
- `throwInternalServiceError(s, w, r, msg, err)` - 500 errors
|
||||
- `throwBadRequest(s, w, r, msg, err)` - 400 errors
|
||||
- `throwForbidden(s, w, r, msg, err)` - 403 errors (normal)
|
||||
- `throwForbiddenSecurity(s, w, r, msg, err)` - 403 security violations (WARN level)
|
||||
- `throwUnauthorized(s, w, r, msg, err)` - 401 errors (normal)
|
||||
- `throwUnauthorizedSecurity(s, w, r, msg, err)` - 401 security violations (WARN level)
|
||||
- `throwNotFound(s, w, r, path)` - 404 errors
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**HTTP Handler Pattern**:
|
||||
```go
|
||||
func HandlerName(server *hws.Server, deps ...) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handler logic here
|
||||
},
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Database Operation Pattern**:
|
||||
```go
|
||||
func GetSomething(ctx context.Context, tx bun.Tx, id int) (*Result, error) {
|
||||
result := new(Result)
|
||||
err := tx.NewSelect().
|
||||
Model(result).
|
||||
Where("id = ?", id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
return nil, nil // Return nil, nil for not found
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Setup Function Pattern** (returns instance, cleanup func, error):
|
||||
```go
|
||||
func setupSomething(ctx context.Context, cfg *Config) (*Type, func() error, error) {
|
||||
instance := newInstance()
|
||||
|
||||
err := configure(instance)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "configure")
|
||||
}
|
||||
|
||||
return instance, instance.Close, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Configuration Pattern** (using ezconf):
|
||||
```go
|
||||
type Config struct {
|
||||
Field string // ENV FIELD_NAME: Description (required/default: value)
|
||||
}
|
||||
|
||||
func ConfigFromEnv() (any, error) {
|
||||
cfg := &Config{
|
||||
Field: env.String("FIELD_NAME", "default"),
|
||||
}
|
||||
// Validation here
|
||||
return cfg, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Formatting & Types
|
||||
|
||||
**Formatting**:
|
||||
- Use `gofmt` (standard Go formatting)
|
||||
- No tabs vs spaces debate - Go uses tabs
|
||||
|
||||
**Types**:
|
||||
- Prefer explicit types over inference when it improves clarity
|
||||
- Use struct tags for ORM and JSON marshaling:
|
||||
```go
|
||||
type User struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
ID int `bun:"id,pk,autoincrement"`
|
||||
Username string `bun:"username,unique"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
```
|
||||
|
||||
**Comments**:
|
||||
- Document exported functions and types
|
||||
- Use inline comments for ENV var documentation in Config structs
|
||||
- Explain security-critical code flows
|
||||
|
||||
### Testing
|
||||
|
||||
**Test File Location**: Place `*_test.go` files alongside source files
|
||||
|
||||
**Test Naming**:
|
||||
```go
|
||||
func TestFunctionName_Scenario(t *testing.T)
|
||||
func TestGenerateState_Success(t *testing.T)
|
||||
func TestVerifyState_WrongUserAgentKey(t *testing.T)
|
||||
```
|
||||
|
||||
**Test Structure**:
|
||||
- Use subtests with `t.Run()` for related scenarios
|
||||
- Use table-driven tests for multiple similar cases
|
||||
- Create helper functions for common setup (e.g., `testConfig()`)
|
||||
- Test happy paths, error cases, edge cases, and security properties
|
||||
|
||||
**Test Categories** (from pkg/oauth/state_test.go example):
|
||||
1. Happy path tests
|
||||
2. Error handling (nil params, empty fields, malformed input)
|
||||
3. Security tests (MITM, CSRF, replay attacks, tampering)
|
||||
4. Edge cases (concurrency, constant-time comparison)
|
||||
5. Integration tests (round-trip verification)
|
||||
|
||||
### Security
|
||||
|
||||
**Critical Practices**:
|
||||
- Use `crypto/subtle.ConstantTimeCompare` for cryptographic comparisons
|
||||
- Implement CSRF protection via state tokens
|
||||
- Store sensitive cookies as HttpOnly
|
||||
- Use separate logging levels for security violations (WARN)
|
||||
- Validate all inputs at function boundaries
|
||||
- Use parameterized queries (Bun ORM handles this)
|
||||
- Never commit secrets (.env, keys/ are gitignored)
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
oslstats/
|
||||
├── cmd/oslstats/ # Application entry point
|
||||
│ ├── main.go # Entry point with flag parsing
|
||||
│ ├── run.go # Server initialization & graceful shutdown
|
||||
│ ├── httpserver.go # HTTP server setup
|
||||
│ ├── routes.go # Route registration
|
||||
│ ├── middleware.go # Middleware registration
|
||||
│ ├── auth.go # Authentication setup
|
||||
│ └── db.go # Database connection & migrations
|
||||
├── internal/ # Private application code
|
||||
│ ├── config/ # Configuration aggregation
|
||||
│ ├── db/ # Database models & queries (Bun ORM)
|
||||
│ ├── discord/ # Discord OAuth integration
|
||||
│ ├── handlers/ # HTTP request handlers
|
||||
│ ├── session/ # Session store (in-memory)
|
||||
│ └── view/ # Templ templates
|
||||
│ ├── component/ # Reusable UI components
|
||||
│ ├── layout/ # Page layouts
|
||||
│ └── page/ # Full pages
|
||||
├── pkg/ # Reusable packages
|
||||
│ ├── contexts/ # Context key definitions
|
||||
│ ├── embedfs/ # Embedded static files
|
||||
│ └── oauth/ # OAuth state management
|
||||
├── bin/ # Compiled binaries (gitignored)
|
||||
├── keys/ # Private keys (gitignored)
|
||||
├── tmp/ # Air hot reload temp files (gitignored)
|
||||
├── Makefile # Build automation
|
||||
├── .air.toml # Hot reload configuration
|
||||
└── go.mod # Go module definition
|
||||
```
|
||||
|
||||
## Key Dependencies
|
||||
|
||||
- **git.haelnorr.com/h/golib/*** - Custom libraries (env, ezconf, hlog, hws, hwsauth, cookies, jwt)
|
||||
- **github.com/a-h/templ** - Type-safe HTML templating
|
||||
- **github.com/uptrace/bun** - PostgreSQL ORM
|
||||
- **github.com/bwmarrin/discordgo** - Discord API client
|
||||
- **github.com/pkg/errors** - Error wrapping (use this, not fmt.Errorf)
|
||||
- **github.com/joho/godotenv** - .env file loading
|
||||
|
||||
## Notes for AI Agents
|
||||
|
||||
1. **Never commit** .env files, keys/, or generated files (*_templ.go, output.css)
|
||||
2. **Database operations** should use `bun.Tx` for transaction safety
|
||||
3. **Templates** are written in templ, not Go html/template - run `templ generate` after changes
|
||||
4. **Static files** are embedded via `//go:embed` - check pkg/embedfs/
|
||||
5. **Error messages** should be descriptive and use errors.Wrap for context
|
||||
6. **Security is critical** - especially in OAuth flows (see pkg/oauth/state_test.go for examples)
|
||||
7. **Air proxy** runs on port 3000 during development; app runs on 3333
|
||||
8. **Test coverage** is currently limited - prioritize testing security-critical code
|
||||
9. **Configuration** uses ezconf pattern - see internal/*/ezconf.go files for examples
|
||||
10. **Graceful shutdown** is implemented in cmd/oslstats/run.go - follow this pattern
|
||||
11. When in plan mode, always use the interactive question tool if available
|
||||
5
Makefile
5
Makefile
@@ -4,7 +4,6 @@
|
||||
BINARY_NAME=oslstats
|
||||
|
||||
build:
|
||||
./scripts/generate-css-sources.sh && \
|
||||
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
|
||||
go mod tidy && \
|
||||
templ generate && \
|
||||
@@ -16,7 +15,6 @@ run:
|
||||
./bin/${BINARY_NAME}${SUFFIX}
|
||||
|
||||
dev:
|
||||
./scripts/generate-css-sources.sh && \
|
||||
templ generate --watch &\
|
||||
air &\
|
||||
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
|
||||
@@ -36,3 +34,6 @@ showenv:
|
||||
make build
|
||||
./bin/${BINARY_NAME} --showenv
|
||||
|
||||
migrate:
|
||||
make build
|
||||
./bin/${BINARY_NAME}${SUFFIX} --migrate
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||
"git.haelnorr.com/h/oslstats/pkg/contexts"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
@@ -31,6 +30,7 @@ func setupAuth(
|
||||
beginTx,
|
||||
logger,
|
||||
handlers.ErrorPage,
|
||||
conn.DB,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||
@@ -38,7 +38,9 @@ func setupAuth(
|
||||
|
||||
auth.IgnorePaths(ignoredPaths...)
|
||||
|
||||
contexts.CurrentUser = auth.CurrentModel
|
||||
db.CurrentUser = auth.CurrentModel
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh
|
||||
|
||||
@@ -20,7 +20,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
|
||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
||||
close = sqldb.Close
|
||||
|
||||
err = loadModels(ctx, conn, cfg.Flags.ResetDB)
|
||||
err = loadModels(ctx, conn, cfg.Flags.MigrateDB)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "loadModels")
|
||||
}
|
||||
@@ -31,6 +31,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
|
||||
func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error {
|
||||
models := []any{
|
||||
(*db.User)(nil),
|
||||
(*db.DiscordToken)(nil),
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
|
||||
@@ -4,13 +4,15 @@ import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
)
|
||||
|
||||
func setupHttpServer(
|
||||
@@ -18,6 +20,8 @@ func setupHttpServer(
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
bun *bun.DB,
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) (server *hws.Server, err error) {
|
||||
if staticFS == nil {
|
||||
return nil, errors.New("No filesystem provided")
|
||||
@@ -53,7 +57,7 @@ func setupHttpServer(
|
||||
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||
}
|
||||
|
||||
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
|
||||
err = addRoutes(httpServer, &fs, config, bun, auth, store, discordAPI)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "addRoutes")
|
||||
}
|
||||
|
||||
@@ -29,6 +29,16 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if flags.MigrateDB {
|
||||
_, closedb, err := setupBun(ctx, cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
closedb()
|
||||
return
|
||||
}
|
||||
|
||||
if err := run(ctx, os.Stdout, cfg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -5,22 +5,24 @@ import (
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
)
|
||||
|
||||
func addRoutes(
|
||||
server *hws.Server,
|
||||
staticFS *http.FileSystem,
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
cfg *config.Config,
|
||||
conn *bun.DB,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) error {
|
||||
// Create the routes
|
||||
routes := []hws.Route{
|
||||
@@ -34,10 +36,38 @@ func addRoutes(
|
||||
Method: hws.MethodGET,
|
||||
Handler: handlers.Index(server),
|
||||
},
|
||||
{
|
||||
Path: "/login",
|
||||
Method: hws.MethodGET,
|
||||
Handler: auth.LogoutReq(handlers.Login(server, cfg, store, discordAPI)),
|
||||
},
|
||||
{
|
||||
Path: "/auth/callback",
|
||||
Method: hws.MethodGET,
|
||||
Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)),
|
||||
},
|
||||
{
|
||||
Path: "/register",
|
||||
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||
Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)),
|
||||
},
|
||||
{
|
||||
Path: "/logout",
|
||||
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
|
||||
Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)),
|
||||
},
|
||||
}
|
||||
|
||||
htmxRoutes := []hws.Route{
|
||||
{
|
||||
Path: "/htmx/isusernameunique",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handlers.IsUsernameUnique(server, conn, cfg, store),
|
||||
},
|
||||
}
|
||||
|
||||
// Register the routes with the server
|
||||
err := server.AddRoutes(routes...)
|
||||
err := server.AddRoutes(append(routes, htmxRoutes...)...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "server.AddRoutes")
|
||||
}
|
||||
|
||||
@@ -9,18 +9,21 @@ import (
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/pkg/embedfs"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/pkg/embedfs"
|
||||
)
|
||||
|
||||
// Initializes and runs the server
|
||||
func run(ctx context.Context, w io.Writer, config *config.Config) error {
|
||||
func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
|
||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
// Setup the logger
|
||||
logger, err := hlog.NewLogger(config.HLOG, w)
|
||||
logger, err := hlog.NewLogger(cfg.HLOG, w)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "hlog.NewLogger")
|
||||
}
|
||||
@@ -28,7 +31,7 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error {
|
||||
// Setup the database connection
|
||||
logger.Debug().Msg("Config loaded and logger started")
|
||||
logger.Debug().Msg("Connecting to database")
|
||||
bun, closedb, err := setupBun(ctx, config)
|
||||
bun, closedb, err := setupBun(ctx, cfg)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupDBConn")
|
||||
}
|
||||
@@ -41,8 +44,19 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
// Setup session store
|
||||
logger.Debug().Msg("Setting up session store")
|
||||
store := store.NewStore()
|
||||
|
||||
// Setup Discord API client
|
||||
logger.Debug().Msg("Setting up Discord API client")
|
||||
discordAPI, err := discord.NewAPIClient(cfg.Discord, logger, cfg.HWSAuth.TrustedHost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "discord.NewAPIClient")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
httpServer, err := setupHttpServer(&staticFS, config, logger, bun)
|
||||
httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupHttpServer")
|
||||
}
|
||||
|
||||
15
go.mod
15
go.mod
@@ -6,20 +6,25 @@ require (
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
git.haelnorr.com/h/golib/ezconf v0.1.1
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||
git.haelnorr.com/h/golib/hws v0.2.3
|
||||
git.haelnorr.com/h/golib/hwsauth v0.3.4
|
||||
git.haelnorr.com/h/golib/hws v0.3.1
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.2
|
||||
github.com/a-h/templ v0.3.977
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/uptrace/bun v1.2.16
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
||||
github.com/uptrace/bun/driver/pgdriver v1.2.16
|
||||
golang.org/x/crypto v0.45.0
|
||||
)
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0 // indirect
|
||||
git.haelnorr.com/h/golib/jwt v0.10.0 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
|
||||
22
go.sum
22
go.sum
@@ -6,16 +6,18 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI
|
||||
git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||
git.haelnorr.com/h/golib/hws v0.2.3 h1:gZQkBciXKh3jYw05vZncSR2lvIqi0H2MVfIWySySsmw=
|
||||
git.haelnorr.com/h/golib/hws v0.2.3/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.3.4 h1:wwYBb6cQQ+x9hxmYuZBF4mVmCv/n4PjJV//e1+SgPOo=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.3.4/go.mod h1:LI7Qz68GPNIW8732Zwptb//ybjiFJOoXf4tgUuUEqHI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI=
|
||||
git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
||||
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -28,6 +30,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
@@ -66,13 +70,19 @@ go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -15,6 +17,8 @@ type Config struct {
|
||||
HWS *hws.Config
|
||||
HWSAuth *hwsauth.Config
|
||||
HLOG *hlog.Config
|
||||
Discord *discord.Config
|
||||
OAuth *oauth.Config
|
||||
Flags *Flags
|
||||
}
|
||||
|
||||
@@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
|
||||
hws.NewEZConfIntegration(),
|
||||
hwsauth.NewEZConfIntegration(),
|
||||
db.NewEZConfIntegration(),
|
||||
discord.NewEZConfIntegration(),
|
||||
oauth.NewEZConfIntegration(),
|
||||
)
|
||||
if err := loader.ParseEnvVars(); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
|
||||
@@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
|
||||
return nil, nil, errors.New("DB Config not loaded")
|
||||
}
|
||||
|
||||
discordcfg, ok := loader.GetConfig("discord")
|
||||
if !ok {
|
||||
return nil, nil, errors.New("Dicord Config not loaded")
|
||||
}
|
||||
|
||||
oauthcfg, ok := loader.GetConfig("oauth")
|
||||
if !ok {
|
||||
return nil, nil, errors.New("OAuth Config not loaded")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
DB: dbcfg.(*db.Config),
|
||||
HWS: hwscfg.(*hws.Config),
|
||||
HWSAuth: hwsauthcfg.(*hwsauth.Config),
|
||||
HLOG: hlogcfg.(*hlog.Config),
|
||||
Discord: discordcfg.(*discord.Config),
|
||||
OAuth: oauthcfg.(*oauth.Config),
|
||||
Flags: flags,
|
||||
}
|
||||
|
||||
|
||||
@@ -5,16 +5,16 @@ import (
|
||||
)
|
||||
|
||||
type Flags struct {
|
||||
ResetDB bool
|
||||
EnvDoc bool
|
||||
ShowEnv bool
|
||||
GenEnv string
|
||||
EnvFile string
|
||||
MigrateDB bool
|
||||
EnvDoc bool
|
||||
ShowEnv bool
|
||||
GenEnv string
|
||||
EnvFile string
|
||||
}
|
||||
|
||||
func SetupFlags() *Flags {
|
||||
// Parse commandline args
|
||||
resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models")
|
||||
migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models")
|
||||
envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation")
|
||||
showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation")
|
||||
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
|
||||
@@ -22,11 +22,11 @@ func SetupFlags() *Flags {
|
||||
flag.Parse()
|
||||
|
||||
flags := &Flags{
|
||||
ResetDB: *resetDB,
|
||||
EnvDoc: *envDoc,
|
||||
ShowEnv: *showEnv,
|
||||
GenEnv: *genEnv,
|
||||
EnvFile: *envfile,
|
||||
MigrateDB: *migrateDB,
|
||||
EnvDoc: *envDoc,
|
||||
ShowEnv: *showEnv,
|
||||
GenEnv: *genEnv,
|
||||
EnvFile: *envfile,
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
95
internal/db/discord_tokens.go
Normal file
95
internal/db/discord_tokens.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type DiscordToken struct {
|
||||
bun.BaseModel `bun:"table:discord_tokens,alias:dt"`
|
||||
|
||||
DiscordID string `bun:"discord_id,pk,notnull"`
|
||||
AccessToken string `bun:"access_token,notnull"`
|
||||
RefreshToken string `bun:"refresh_token,notnull"`
|
||||
ExpiresAt int64 `bun:"expires_at,notnull"`
|
||||
Scope string `bun:"scope,notnull"`
|
||||
TokenType string `bun:"token_type,notnull"`
|
||||
}
|
||||
|
||||
// UpdateDiscordToken adds the provided discord token to the database.
|
||||
// If the user already has a token stored, it will replace that token instead.
|
||||
func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
|
||||
|
||||
discordToken := &DiscordToken{
|
||||
DiscordID: user.DiscordID,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: token.Scope,
|
||||
TokenType: token.TokenType,
|
||||
}
|
||||
|
||||
_, err := tx.NewInsert().
|
||||
Model(discordToken).
|
||||
On("CONFLICT (discord_id) DO UPDATE").
|
||||
Set("access_token = EXCLUDED.access_token").
|
||||
Set("refresh_token = EXCLUDED.refresh_token").
|
||||
Set("expires_at = EXCLUDED.expires_at").
|
||||
Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.NewInsert")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDiscordTokens deletes a users discord OAuth tokens from the database.
|
||||
// It returns the DiscordToken so that it can be revoked via the discord API
|
||||
func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
|
||||
token, err := user.GetDiscordToken(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.GetDiscordToken")
|
||||
}
|
||||
_, err = tx.NewDelete().
|
||||
Model((*DiscordToken)(nil)).
|
||||
Where("discord_id = ?", user.DiscordID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewDelete")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetDiscordToken retrieves the users discord token from the database
|
||||
func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
|
||||
token := new(DiscordToken)
|
||||
err := tx.NewSelect().
|
||||
Model(token).
|
||||
Where("discord_id = ?", user.DiscordID).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Convert reverts the token back into a *discord.Token
|
||||
func (t *DiscordToken) Convert() *discord.Token {
|
||||
token := &discord.Token{
|
||||
AccessToken: t.AccessToken,
|
||||
RefreshToken: t.RefreshToken,
|
||||
ExpiresIn: int(t.ExpiresAt - time.Now().Unix()),
|
||||
Scope: t.Scope,
|
||||
TokenType: t.TokenType,
|
||||
}
|
||||
return token
|
||||
}
|
||||
@@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string {
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{name: "db", configFunc: ConfigFromEnv}
|
||||
return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv}
|
||||
}
|
||||
|
||||
@@ -2,65 +2,30 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var CurrentUser hwsauth.ContextLoader[*User]
|
||||
|
||||
type User struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
|
||||
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
|
||||
Username string `bun:"username,unique"` // Username (unique)
|
||||
PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON)
|
||||
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
||||
Bio string `bun:"bio"` // Short byline set by the user
|
||||
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
|
||||
Username string `bun:"username,unique"` // Username (unique)
|
||||
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
||||
DiscordID string `bun:"discord_id,unique"`
|
||||
}
|
||||
|
||||
func (user *User) GetID() int {
|
||||
return user.ID
|
||||
}
|
||||
|
||||
// Uses bcrypt to set the users password_hash from the given password
|
||||
func (user *User) SetPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
newPassword := string(hashedPassword)
|
||||
|
||||
_, err = tx.NewUpdate().
|
||||
Model(user).
|
||||
Set("password_hash = ?", newPassword).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Update")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uses bcrypt to check if the given password matches the users password_hash
|
||||
func (user *User) CheckPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||
var hashedPassword string
|
||||
err := tx.NewSelect().
|
||||
Table("users").
|
||||
Column("password_hash").
|
||||
Where("id = ?", user.ID).
|
||||
Limit(1).
|
||||
Scan(ctx, &hashedPassword)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Username or password incorrect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's username
|
||||
func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error {
|
||||
_, err := tx.NewUpdate().
|
||||
@@ -75,35 +40,18 @@ func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername str
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's bio
|
||||
func (user *User) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error {
|
||||
_, err := tx.NewUpdate().
|
||||
Model(user).
|
||||
Set("bio = ?", newBio).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Update")
|
||||
}
|
||||
user.Bio = newBio
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user with the given username and password
|
||||
func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*User, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) {
|
||||
if discorduser == nil {
|
||||
return nil, errors.New("user cannot be nil")
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Username: username,
|
||||
PasswordHash: string(hashedPassword),
|
||||
CreatedAt: 0, // You may want to set this to time.Now().Unix()
|
||||
Bio: "",
|
||||
Username: username,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
DiscordID: discorduser.ID,
|
||||
}
|
||||
|
||||
_, err = tx.NewInsert().
|
||||
_, err := tx.NewInsert().
|
||||
Model(user).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
@@ -116,6 +64,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*Use
|
||||
// GetUserByID queries the database for a user matching the given ID
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
||||
fmt.Printf("user id requested: %v", id)
|
||||
user := new(User)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
@@ -149,6 +98,24 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByDiscordID queries the database for a user matching the given discord id
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
|
||||
user := new(User)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("discord_id = ?", discordID).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// IsUsernameUnique checks if the given username is unique (not already taken)
|
||||
// Returns true if the username is available, false if it's taken
|
||||
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
|
||||
|
||||
61
internal/discord/api.go
Normal file
61
internal/discord/api.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OAuthSession struct {
|
||||
*discordgo.Session
|
||||
}
|
||||
|
||||
func NewOAuthSession(token *Token) (*OAuthSession, error) {
|
||||
session, err := discordgo.New("Bearer " + token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discordgo.New")
|
||||
}
|
||||
return &OAuthSession{Session: session}, nil
|
||||
}
|
||||
|
||||
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
|
||||
user, err := s.User("@me")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "s.User")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||
type APIClient struct {
|
||||
cfg *Config
|
||||
client *http.Client
|
||||
logger *hlog.Logger
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*RateLimitState
|
||||
trustedHost string
|
||||
}
|
||||
|
||||
// NewAPIClient creates a new Discord API client with rate limit handling
|
||||
func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
if logger == nil {
|
||||
return nil, errors.New("logger cannot be nil")
|
||||
}
|
||||
if trustedhost == "" {
|
||||
return nil, errors.New("trustedhost cannot be empty")
|
||||
}
|
||||
return &APIClient{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
logger: logger,
|
||||
buckets: make(map[string]*RateLimitState),
|
||||
cfg: cfg,
|
||||
trustedHost: trustedhost,
|
||||
}, nil
|
||||
}
|
||||
50
internal/discord/config.go
Normal file
50
internal/discord/config.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required)
|
||||
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
|
||||
OAuthScopes string // Authorisation scopes for OAuth
|
||||
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
|
||||
}
|
||||
|
||||
func ConfigFromEnv() (any, error) {
|
||||
cfg := &Config{
|
||||
ClientID: env.String("DISCORD_CLIENT_ID", ""),
|
||||
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
|
||||
OAuthScopes: getOAuthScopes(),
|
||||
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if cfg.ClientID == "" {
|
||||
return nil, errors.New("Envar not set: DISCORD_CLIENT_ID")
|
||||
}
|
||||
if cfg.ClientSecret == "" {
|
||||
return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET")
|
||||
}
|
||||
if cfg.RedirectPath == "" {
|
||||
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getOAuthScopes() string {
|
||||
list := []string{
|
||||
"connections",
|
||||
"email",
|
||||
"guilds",
|
||||
"gdm.join",
|
||||
"guilds.members.read",
|
||||
"identify",
|
||||
}
|
||||
scopes := strings.Join(list, "+")
|
||||
return scopes
|
||||
}
|
||||
41
internal/discord/ezconf.go
Normal file
41
internal/discord/ezconf.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||
type EZConfIntegration struct {
|
||||
configFunc func() (any, error)
|
||||
name string
|
||||
}
|
||||
|
||||
// PackagePath returns the path to the config package for source parsing
|
||||
func (e EZConfIntegration) PackagePath() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
// Return directory of this file
|
||||
return filename[:len(filename)-len("/ezconf.go")]
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||
return func() (any, error) {
|
||||
return e.configFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name to use when registering with ezconf
|
||||
func (e EZConfIntegration) Name() string {
|
||||
return strings.ToLower(e.name)
|
||||
}
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
func (e EZConfIntegration) GroupName() string {
|
||||
return e.name
|
||||
}
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv}
|
||||
}
|
||||
148
internal/discord/oauth.go
Normal file
148
internal/discord/oauth.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Token represents a response from the Discord OAuth API after a successful authorization request
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
const oauthurl string = "https://discord.com/oauth2/authorize"
|
||||
const apiurl string = "https://discord.com/api/v10"
|
||||
|
||||
// GetOAuthLink generates a new Discord OAuth2 link for user authentication
|
||||
func (api *APIClient) GetOAuthLink(state string) (string, error) {
|
||||
if state == "" {
|
||||
return "", errors.New("state cannot be empty")
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Add("response_type", "code")
|
||||
values.Add("client_id", api.cfg.ClientID)
|
||||
values.Add("scope", api.cfg.OAuthScopes)
|
||||
values.Add("state", state)
|
||||
values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
|
||||
values.Add("prompt", "none")
|
||||
|
||||
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
||||
}
|
||||
|
||||
// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for
|
||||
// making requests to the API on behalf of the user
|
||||
func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("code cannot be empty")
|
||||
}
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token",
|
||||
strings.NewReader(data.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
var tokenResp Token
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse token response")
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken uses the refresh token to generate a new token pair
|
||||
func (api *APIClient) RefreshToken(token *Token) (*Token, error) {
|
||||
if token == nil {
|
||||
return nil, errors.New("token cannot be nil")
|
||||
}
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", token.RefreshToken)
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token",
|
||||
strings.NewReader(data.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
var tokenResp Token
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse token response")
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RevokeToken sends a request to the Discord API to revoke the token pair
|
||||
func (api *APIClient) RevokeToken(token *Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
data := url.Values{}
|
||||
data.Set("token", token.AccessToken)
|
||||
data.Set("token_type_hint", "access_token")
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token/revoke",
|
||||
strings.NewReader(data.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.Errorf("discord API returned status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
216
internal/discord/ratelimit.go
Normal file
216
internal/discord/ratelimit.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// RateLimitState tracks rate limit information for a specific bucket
|
||||
type RateLimitState struct {
|
||||
Remaining int // Requests remaining in current window
|
||||
Limit int // Total requests allowed in window
|
||||
Reset time.Time // When the rate limit resets
|
||||
Bucket string // Discord's bucket identifier
|
||||
}
|
||||
|
||||
// Do executes an HTTP request with automatic rate limit handling
|
||||
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
|
||||
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request cannot be nil")
|
||||
}
|
||||
|
||||
// Step 1: Check if we need to wait before making request
|
||||
bucket := c.getBucketFromRequest(req)
|
||||
if err := c.waitIfNeeded(bucket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Step 2: Execute request
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
// Check if it's a network timeout
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil, errors.Wrap(err, "request timed out")
|
||||
}
|
||||
return nil, errors.Wrap(err, "http request failed")
|
||||
}
|
||||
|
||||
// Step 3: Update rate limit state from response headers
|
||||
c.updateRateLimit(resp.Header)
|
||||
|
||||
// Step 4: Handle 429 (rate limited)
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
resp.Body.Close() // Close original response
|
||||
|
||||
retryAfter := c.parseRetryAfter(resp.Header)
|
||||
|
||||
// No Retry-After header, can't retry safely
|
||||
if retryAfter == 0 {
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Msg("Rate limited but no Retry-After header provided")
|
||||
return nil, errors.New("discord API rate limited but no Retry-After header provided")
|
||||
}
|
||||
|
||||
// Retry-After exceeds 30 second cap
|
||||
if retryAfter > 30*time.Second {
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Dur("retry_after", retryAfter).
|
||||
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
|
||||
return nil, errors.Errorf(
|
||||
"discord API rate limited (retry after %s exceeds 30s cap)",
|
||||
retryAfter,
|
||||
)
|
||||
}
|
||||
|
||||
// Wait and retry
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Dur("retry_after", retryAfter).
|
||||
Msg("Rate limited, waiting before retry")
|
||||
|
||||
time.Sleep(retryAfter)
|
||||
|
||||
// Retry the request
|
||||
resp, err = c.client.Do(req)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil, errors.Wrap(err, "retry request timed out")
|
||||
}
|
||||
return nil, errors.Wrap(err, "retry request failed")
|
||||
}
|
||||
|
||||
// Update rate limit again after retry
|
||||
c.updateRateLimit(resp.Header)
|
||||
|
||||
// If STILL rate limited after retry, return error
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
resp.Body.Close()
|
||||
c.logger.Error().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Msg("Still rate limited after retry, Discord may be experiencing issues")
|
||||
return nil, errors.Errorf(
|
||||
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
|
||||
retryAfter,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getBucketFromRequest extracts or generates bucket ID from request
|
||||
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
|
||||
func (c *APIClient) getBucketFromRequest(req *http.Request) string {
|
||||
return req.Method + ":" + req.URL.Path
|
||||
}
|
||||
|
||||
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
|
||||
func (c *APIClient) waitIfNeeded(bucket string) error {
|
||||
c.mu.RLock()
|
||||
state, exists := c.buckets[bucket]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil // No state yet, proceed
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// If we have no remaining requests and reset hasn't occurred, wait
|
||||
if state.Remaining == 0 && now.Before(state.Reset) {
|
||||
waitDuration := time.Until(state.Reset)
|
||||
// Add small buffer (100ms) to ensure reset has occurred
|
||||
waitDuration += 100 * time.Millisecond
|
||||
|
||||
if waitDuration > 0 {
|
||||
c.logger.Debug().
|
||||
Str("bucket", bucket).
|
||||
Dur("wait_duration", waitDuration).
|
||||
Msg("Proactively waiting for rate limit reset")
|
||||
time.Sleep(waitDuration)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateRateLimit parses response headers and updates bucket state
|
||||
func (c *APIClient) updateRateLimit(headers http.Header) {
|
||||
bucket := headers.Get("X-RateLimit-Bucket")
|
||||
if bucket == "" {
|
||||
return // No bucket info, can't track
|
||||
}
|
||||
|
||||
// Parse headers
|
||||
limit := c.parseInt(headers.Get("X-RateLimit-Limit"))
|
||||
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining"))
|
||||
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After"))
|
||||
|
||||
state := &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: limit,
|
||||
Remaining: remaining,
|
||||
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.buckets[bucket] = state
|
||||
c.mu.Unlock()
|
||||
|
||||
// Log rate limit state for debugging
|
||||
c.logger.Debug().
|
||||
Str("bucket", bucket).
|
||||
Int("remaining", remaining).
|
||||
Int("limit", limit).
|
||||
Dur("reset_in", time.Until(state.Reset)).
|
||||
Msg("Rate limit state updated")
|
||||
}
|
||||
|
||||
// parseRetryAfter extracts retry delay from Retry-After header
|
||||
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
|
||||
retryAfter := headers.Get("Retry-After")
|
||||
if retryAfter == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Discord returns seconds as float
|
||||
seconds := c.parseFloat(retryAfter)
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return time.Duration(seconds * float64(time.Second))
|
||||
}
|
||||
|
||||
// parseInt parses an integer from a header value, returns 0 on error
|
||||
func (c *APIClient) parseInt(s string) int {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
i, _ := strconv.Atoi(s)
|
||||
return i
|
||||
}
|
||||
|
||||
// parseFloat parses a float from a header value, returns 0 on error
|
||||
func (c *APIClient) parseFloat(s string) float64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
f, _ := strconv.ParseFloat(s, 64)
|
||||
return f
|
||||
}
|
||||
517
internal/discord/ratelimit_test.go
Normal file
517
internal/discord/ratelimit_test.go
Normal file
@@ -0,0 +1,517 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
)
|
||||
|
||||
// testLogger creates a test logger for testing
|
||||
func testLogger(t *testing.T) *hlog.Logger {
|
||||
level, _ := hlog.LogLevel("debug")
|
||||
cfg := &hlog.Config{
|
||||
LogLevel: level,
|
||||
LogOutput: "console",
|
||||
}
|
||||
logger, err := hlog.NewLogger(cfg, io.Discard)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test logger: %v", err)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
// testConfig creates a test config for testing
|
||||
func testConfig() *Config {
|
||||
return &Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
OAuthScopes: "identify+email",
|
||||
RedirectPath: "/oauth/callback",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimitedClient(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewAPIClient returned nil")
|
||||
}
|
||||
if client.client == nil {
|
||||
t.Error("client.client is nil")
|
||||
}
|
||||
if client.logger == nil {
|
||||
t.Error("client.logger is nil")
|
||||
}
|
||||
if client.buckets == nil {
|
||||
t.Error("client.buckets map is nil")
|
||||
}
|
||||
if client.cfg == nil {
|
||||
t.Error("client.cfg is nil")
|
||||
}
|
||||
if client.trustedHost != "trusted-host.example.com" {
|
||||
t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_Success(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Mock server that returns success with rate limit headers
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "5")
|
||||
w.Header().Set("X-RateLimit-Remaining", "3")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that rate limit state was updated
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["test-bucket"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("rate limit state not stored")
|
||||
}
|
||||
if state.Remaining != 3 {
|
||||
t.Errorf("expected remaining=3, got %d", state.Remaining)
|
||||
}
|
||||
if state.Limit != 5 {
|
||||
t.Errorf("expected limit=5, got %d", state.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount++
|
||||
if attemptCount == 1 {
|
||||
// First request: return 429
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("Retry-After", "0.1") // 100ms
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
"error_description": "You are being rate limited",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Second request: success
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "5")
|
||||
w.Header().Set("X-RateLimit-Remaining", "4")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if attemptCount != 2 {
|
||||
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
|
||||
}
|
||||
|
||||
// Should have waited approximately 100ms
|
||||
if elapsed < 100*time.Millisecond {
|
||||
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount++
|
||||
// Always return 429
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("Retry-After", "0.05") // 50ms
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error after failed retry")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "still rate limited after retry") {
|
||||
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
|
||||
}
|
||||
|
||||
if attemptCount != 2 {
|
||||
t.Errorf("expected 2 attempts, got %d", attemptCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error for Retry-After > 30s")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "exceeds 30s cap") {
|
||||
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
|
||||
}
|
||||
|
||||
// Should NOT have waited (immediate error)
|
||||
if elapsed > 1*time.Second {
|
||||
t.Errorf("should return immediately, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return 429 but NO Retry-After header
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error when no Retry-After header")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no Retry-After header") {
|
||||
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("X-RateLimit-Bucket", "global")
|
||||
headers.Set("X-RateLimit-Limit", "10")
|
||||
headers.Set("X-RateLimit-Remaining", "7")
|
||||
headers.Set("X-RateLimit-Reset-After", "5.5")
|
||||
|
||||
client.updateRateLimit(headers)
|
||||
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["global"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("bucket state not created")
|
||||
}
|
||||
|
||||
if state.Bucket != "global" {
|
||||
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
|
||||
}
|
||||
if state.Limit != 10 {
|
||||
t.Errorf("expected limit=10, got %d", state.Limit)
|
||||
}
|
||||
if state.Remaining != 7 {
|
||||
t.Errorf("expected remaining=7, got %d", state.Remaining)
|
||||
}
|
||||
|
||||
// Check reset time is approximately 5.5 seconds from now
|
||||
resetIn := time.Until(state.Reset)
|
||||
if resetIn < 5*time.Second || resetIn > 6*time.Second {
|
||||
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Set up a bucket with 0 remaining and reset in future
|
||||
bucket := "test-bucket"
|
||||
client.mu.Lock()
|
||||
client.buckets[bucket] = &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: 5,
|
||||
Remaining: 0,
|
||||
Reset: time.Now().Add(200 * time.Millisecond),
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err = client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should have waited ~200ms + 100ms buffer
|
||||
if elapsed < 200*time.Millisecond {
|
||||
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
|
||||
}
|
||||
if elapsed > 500*time.Millisecond {
|
||||
t.Errorf("waited too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Set up a bucket with remaining requests
|
||||
bucket := "test-bucket"
|
||||
client.mu.Lock()
|
||||
client.buckets[bucket] = &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: 5,
|
||||
Remaining: 3,
|
||||
Reset: time.Now().Add(5 * time.Second),
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err = client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should NOT wait (has remaining requests)
|
||||
if elapsed > 10*time.Millisecond {
|
||||
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
count := requestCount
|
||||
mu.Unlock()
|
||||
|
||||
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "10")
|
||||
w.Header().Set("X-RateLimit-Remaining", "5")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "1.0")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Launch 10 concurrent requests
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 10)
|
||||
|
||||
for range 10 {
|
||||
wg.Go(
|
||||
func() {
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errors {
|
||||
t.Errorf("concurrent request failed: %v", err)
|
||||
}
|
||||
|
||||
// All requests should have completed
|
||||
mu.Lock()
|
||||
finalCount := requestCount
|
||||
mu.Unlock()
|
||||
|
||||
if finalCount != 10 {
|
||||
t.Errorf("expected 10 requests, got %d", finalCount)
|
||||
}
|
||||
|
||||
// Check rate limit state is consistent (no data races)
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["concurrent-bucket"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("bucket state not found after concurrent requests")
|
||||
}
|
||||
|
||||
// State should exist and be valid
|
||||
if state.Limit != 10 {
|
||||
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_ParseRetryAfter(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"integer seconds", "2", 2 * time.Second},
|
||||
{"float seconds", "2.5", 2500 * time.Millisecond},
|
||||
{"zero", "0", 0},
|
||||
{"empty", "", 0},
|
||||
{"invalid", "abc", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("Retry-After", tt.header)
|
||||
|
||||
result := client.parseRetryAfter(headers)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
205
internal/handlers/callback.go
Normal file
205
internal/handlers/callback.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
|
||||
func Callback(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
store.ClearRedirectTrack(r, "/callback")
|
||||
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.URL.Query().Get("state")
|
||||
code := r.URL.Query().Get("code")
|
||||
if state == "" && code == "" {
|
||||
http.Redirect(w, r, "/", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
data, err := verifyState(cfg.OAuth, w, r, state)
|
||||
if err != nil {
|
||||
if vsErr, ok := err.(*verifyStateError); ok {
|
||||
if vsErr.IsCookieError() {
|
||||
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
||||
} else {
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
} else {
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
store.ClearRedirectTrack(r, "/callback")
|
||||
|
||||
switch data {
|
||||
case "login":
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
redirect()
|
||||
return
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
type verifyStateError struct {
|
||||
err error
|
||||
cookieError bool
|
||||
}
|
||||
|
||||
func (e *verifyStateError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *verifyStateError) IsCookieError() bool {
|
||||
return e.cookieError
|
||||
}
|
||||
|
||||
func verifyState(
|
||||
cfg *oauth.Config,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
state string,
|
||||
) (string, error) {
|
||||
if r == nil {
|
||||
return "", errors.New("request cannot be nil")
|
||||
}
|
||||
if state == "" {
|
||||
return "", errors.New("state param field is empty")
|
||||
}
|
||||
|
||||
uak, err := oauth.GetStateCookie(r)
|
||||
if err != nil {
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.GetStateCookie"),
|
||||
cookieError: true,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := oauth.VerifyState(cfg, state, uak)
|
||||
if err != nil {
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.VerifyState"),
|
||||
cookieError: false,
|
||||
}
|
||||
}
|
||||
|
||||
oauth.DeleteStateCookie(w)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func login(
|
||||
ctx context.Context,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
tx bun.Tx,
|
||||
cfg *config.Config,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
code string,
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) (func(), error) {
|
||||
token, err := discordAPI.AuthorizeWithCode(code)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode")
|
||||
}
|
||||
session, err := discord.NewOAuthSession(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discord.NewOAuthSession")
|
||||
}
|
||||
discorduser, err := session.GetUser()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "session.GetUser")
|
||||
}
|
||||
|
||||
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserByDiscordID")
|
||||
}
|
||||
var redirect string
|
||||
if user == nil {
|
||||
sessionID, err := store.CreateRegistrationSession(discorduser, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "store.CreateRegistrationSession")
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "registration_session",
|
||||
Path: "/",
|
||||
Value: sessionID,
|
||||
MaxAge: 300, // 5 minutes
|
||||
HttpOnly: true,
|
||||
Secure: cfg.HWSAuth.SSL,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
redirect = "/register"
|
||||
} else {
|
||||
err = user.UpdateDiscordToken(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.UpdateDiscordToken")
|
||||
}
|
||||
err := auth.Login(w, r, user, true)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.Login")
|
||||
}
|
||||
redirect = cookies.CheckPageFrom(w, r)
|
||||
}
|
||||
return func() {
|
||||
http.Redirect(w, r, redirect, http.StatusSeeOther)
|
||||
}, nil
|
||||
}
|
||||
@@ -5,24 +5,93 @@ import (
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func ErrorPage(
|
||||
errorCode int,
|
||||
) (hws.ErrorPage, error) {
|
||||
// func ErrorPage(
|
||||
// error hws.HWSError,
|
||||
// ) (hws.ErrorPage, error) {
|
||||
// messages := map[int]string{
|
||||
// 400: "The request you made was malformed or unexpected.",
|
||||
// 401: "You need to login to view this page.",
|
||||
// 403: "You do not have permission to view this page.",
|
||||
// 404: "The page or resource you have requested does not exist.",
|
||||
// 500: `An error occured on the server. Please try again, and if this
|
||||
// continues to happen contact an administrator.`,
|
||||
// 503: "The server is currently down for maintenance and should be back soon. =)",
|
||||
// }
|
||||
// msg, exists := messages[error.StatusCode]
|
||||
// if !exists {
|
||||
// return nil, errors.New("No valid message for the given code")
|
||||
// }
|
||||
// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil
|
||||
// }
|
||||
|
||||
func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
|
||||
// Determine if this status code should show technical details
|
||||
showDetails := shouldShowDetails(hwsError.StatusCode)
|
||||
|
||||
// Get the user-friendly message
|
||||
message := hwsError.Message
|
||||
if message == "" {
|
||||
// Fallback to default messages if no custom message provided
|
||||
message = getDefaultMessage(hwsError.StatusCode)
|
||||
}
|
||||
|
||||
// Get technical details if applicable
|
||||
var details string
|
||||
if showDetails && hwsError.Error != nil {
|
||||
details = hwsError.Error.Error()
|
||||
}
|
||||
|
||||
// Render appropriate template
|
||||
if details != "" {
|
||||
return page.ErrorWithDetails(
|
||||
hwsError.StatusCode,
|
||||
http.StatusText(hwsError.StatusCode),
|
||||
message,
|
||||
details,
|
||||
), nil
|
||||
}
|
||||
|
||||
return page.Error(
|
||||
hwsError.StatusCode,
|
||||
http.StatusText(hwsError.StatusCode),
|
||||
message,
|
||||
), nil
|
||||
}
|
||||
|
||||
// shouldShowDetails determines if a status code should display technical details
|
||||
func shouldShowDetails(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable
|
||||
return true
|
||||
case 401, 403, 404: // Unauthorized, Forbidden, Not Found
|
||||
return false
|
||||
default:
|
||||
// For unknown codes, show details for 5xx errors
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultMessage provides fallback messages for status codes
|
||||
func getDefaultMessage(statusCode int) string {
|
||||
messages := map[int]string{
|
||||
400: "The request you made was malformed or unexpected.",
|
||||
401: "You need to login to view this page.",
|
||||
403: "You do not have permission to view this page.",
|
||||
404: "The page or resource you have requested does not exist.",
|
||||
500: `An error occured on the server. Please try again, and if this
|
||||
continues to happen contact an administrator.`,
|
||||
500: `An error occurred on the server. Please try again, and if this
|
||||
continues to happen contact an administrator.`,
|
||||
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||
}
|
||||
msg, exists := messages[errorCode]
|
||||
|
||||
msg, exists := messages[statusCode]
|
||||
if !exists {
|
||||
return nil, errors.New("No valid message for the given code")
|
||||
if statusCode >= 500 {
|
||||
return "A server error occurred. Please try again later."
|
||||
}
|
||||
return "An error occurred while processing your request."
|
||||
}
|
||||
return page.Error(errorCode, http.StatusText(errorCode), msg), nil
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
109
internal/handlers/errors.go
Normal file
109
internal/handlers/errors.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// throwError is a generic helper that all throw* functions use internally
|
||||
func throwError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
statusCode int,
|
||||
msg string,
|
||||
err error,
|
||||
level string,
|
||||
) {
|
||||
err = s.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel(level),
|
||||
RenderErrorPage: true, // throw* family always renders error pages
|
||||
})
|
||||
if err != nil {
|
||||
s.ThrowFatal(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// throwInternalServiceError handles 500 errors (server failures)
|
||||
func throwInternalServiceError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
|
||||
}
|
||||
|
||||
// throwBadRequest handles 400 errors (malformed requests)
|
||||
func throwBadRequest(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusBadRequest, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwForbidden handles 403 errors (normal permission denials)
|
||||
func throwForbidden(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
||||
func throwForbiddenSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, "warn")
|
||||
}
|
||||
|
||||
// throwUnauthorized handles 401 errors (not authenticated)
|
||||
func throwUnauthorized(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
||||
func throwUnauthorizedSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn")
|
||||
}
|
||||
|
||||
// throwNotFound handles 404 errors
|
||||
func throwNotFound(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
) {
|
||||
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
||||
err := errors.New("Resource not found")
|
||||
throwError(s, w, r, http.StatusNotFound, msg, err, "debug")
|
||||
}
|
||||
@@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
page, err := ErrorPage(http.StatusNotFound)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to generate the error page",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: false,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = page.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to render the error page",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: false,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
throwNotFound(server, w, r, r.URL.Path)
|
||||
}
|
||||
page.Index().Render(r.Context(), w)
|
||||
},
|
||||
|
||||
45
internal/handlers/isusernameunique.go
Normal file
45
internal/handlers/isusernameunique.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func IsUsernameUnique(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
63
internal/handlers/login.go
Normal file
63
internal/handlers/login.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
|
||||
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
||||
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
|
||||
return
|
||||
}
|
||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||
|
||||
link, err := discordAPI.GetOAuthLink(state)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
||||
return
|
||||
}
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
http.Redirect(w, r, link, http.StatusSeeOther)
|
||||
},
|
||||
)
|
||||
}
|
||||
59
internal/handlers/logout.go
Normal file
59
internal/handlers/logout.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func Logout(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
conn *bun.DB,
|
||||
discordAPI *discord.APIClient,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
user := db.CurrentUser(r.Context())
|
||||
if user == nil {
|
||||
// JIC - should be impossible to get here if route is protected by LoginReq
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
return
|
||||
}
|
||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
||||
return
|
||||
}
|
||||
err = discordAPI.RevokeToken(token.Convert())
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||
return
|
||||
}
|
||||
err = auth.Logout(tx, w, r)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Logout failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
},
|
||||
)
|
||||
}
|
||||
129
internal/handlers/register.go
Normal file
129
internal/handlers/register.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
)
|
||||
|
||||
func Register(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
cfg.HWSAuth.SSL,
|
||||
)
|
||||
|
||||
store.ClearRedirectTrack(r, "/register")
|
||||
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
sessionCookie, err := r.Cookie("registration_session")
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
details, ok := store.GetRegistrationSession(sessionCookie.Value)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
store.ClearRedirectTrack(r, "/register")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
method := r.Method
|
||||
if method == "GET" {
|
||||
tx.Commit()
|
||||
page.Register(details.DiscordUser.Username).Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
if method == "POST" {
|
||||
username := r.FormValue("username")
|
||||
user, err := registerUser(ctx, tx, username, details)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Registration failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if user == nil {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
err = auth.Login(w, r, user, true)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Login failed", err)
|
||||
return
|
||||
}
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
}
|
||||
return
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func registerUser(
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
username string,
|
||||
details *store.RegistrationSession,
|
||||
) (*db.User, error) {
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.IsUsernameUnique")
|
||||
}
|
||||
if !unique {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CreateUser")
|
||||
}
|
||||
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
@@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
|
||||
if err != nil {
|
||||
// If we can't create the file server, return a handler that always errors
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to load the file system",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
46
internal/store/newlogin.go
Normal file
46
internal/store/newlogin.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
type RegistrationSession struct {
|
||||
DiscordUser *discordgo.User
|
||||
Token *discord.Token
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) {
|
||||
if user == nil {
|
||||
return "", errors.New("user cannot be nil")
|
||||
}
|
||||
if token == nil {
|
||||
return "", errors.New("token cannot be nil")
|
||||
}
|
||||
id := generateID()
|
||||
s.sessions.Store(id, &RegistrationSession{
|
||||
DiscordUser: user,
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) {
|
||||
val, ok := s.sessions.Load(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
session := val.(*RegistrationSession)
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
s.sessions.Delete(id)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
95
internal/store/redirects.go
Normal file
95
internal/store/redirects.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (comma-separated list, first is client)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||
// Use net.SplitHostPort to properly handle both IPv4 and IPv6
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
// If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr)
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// TrackRedirect increments the redirect counter for this IP+UA+Path combination
|
||||
// Returns the current attempt count, whether limit was exceeded, and the track details
|
||||
func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) {
|
||||
if r == nil {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
ip := getClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
key := redirectKey(ip, userAgent, path)
|
||||
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(5 * time.Minute)
|
||||
|
||||
// Try to load existing track
|
||||
val, exists := s.redirectTracks.Load(key)
|
||||
if exists {
|
||||
track = val.(*RedirectTrack)
|
||||
|
||||
// Check if expired
|
||||
if now.After(track.ExpiresAt) {
|
||||
// Expired, start fresh
|
||||
track = &RedirectTrack{
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
Path: path,
|
||||
Attempts: 1,
|
||||
FirstSeen: now,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
s.redirectTracks.Store(key, track)
|
||||
return 1, false, track
|
||||
}
|
||||
|
||||
// Increment existing
|
||||
track.Attempts++
|
||||
track.ExpiresAt = expiresAt // Extend expiry
|
||||
exceeded = track.Attempts >= maxAttempts
|
||||
return track.Attempts, exceeded, track
|
||||
}
|
||||
|
||||
// Create new track
|
||||
track = &RedirectTrack{
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
Path: path,
|
||||
Attempts: 1,
|
||||
FirstSeen: now,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
s.redirectTracks.Store(key, track)
|
||||
return 1, false, track
|
||||
}
|
||||
|
||||
// ClearRedirectTrack removes a redirect tracking entry (called after successful completion)
|
||||
func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ip := getClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
key := redirectKey(ip, userAgent, path)
|
||||
s.redirectTracks.Delete(key)
|
||||
}
|
||||
80
internal/store/store.go
Normal file
80
internal/store/store.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedirectTrack represents a single redirect attempt tracking entry
|
||||
type RedirectTrack struct {
|
||||
IP string // Client IP (X-Forwarded-For aware)
|
||||
UserAgent string // Full User-Agent string for debugging
|
||||
Path string // Request path (without query params)
|
||||
Attempts int // Number of redirect attempts
|
||||
FirstSeen time.Time // When first redirect was tracked
|
||||
ExpiresAt time.Time // When to clean up this entry
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
sessions sync.Map // key: string, value: *RegistrationSession
|
||||
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
||||
cleanup *time.Ticker
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
s := &Store{
|
||||
cleanup: time.NewTicker(1 * time.Minute),
|
||||
}
|
||||
|
||||
// Background cleanup of expired sessions
|
||||
go func() {
|
||||
for range s.cleanup.C {
|
||||
s.cleanupExpired()
|
||||
}
|
||||
}()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id string) {
|
||||
s.sessions.Delete(id)
|
||||
}
|
||||
func (s *Store) cleanupExpired() {
|
||||
now := time.Now()
|
||||
|
||||
// Clean up expired registration sessions
|
||||
s.sessions.Range(func(key, value any) bool {
|
||||
session := value.(*RegistrationSession)
|
||||
if now.After(session.ExpiresAt) {
|
||||
s.sessions.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up expired redirect tracks
|
||||
s.redirectTracks.Range(func(key, value any) bool {
|
||||
track := value.(*RedirectTrack)
|
||||
if now.After(track.ExpiresAt) {
|
||||
s.redirectTracks.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
func generateID() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// redirectKey generates a unique key for tracking redirects
|
||||
// Uses IP + first 100 chars of UA + path as key (not hashed for debugging)
|
||||
func redirectKey(ip, userAgent, path string) string {
|
||||
ua := userAgent
|
||||
if len(ua) > 100 {
|
||||
ua = ua[:100]
|
||||
}
|
||||
return fmt.Sprintf("%s:%s:%s", ip, ua, path)
|
||||
}
|
||||
89
internal/view/component/form/register.templ
Normal file
89
internal/view/component/form/register.templ
Normal file
@@ -0,0 +1,89 @@
|
||||
package form
|
||||
|
||||
templ RegisterForm(username string) {
|
||||
<form
|
||||
hx-post="/register"
|
||||
hx-swap="none"
|
||||
x-data={ templ.JSFuncCall("registerFormData").CallInline }
|
||||
@submit="handleSubmit()"
|
||||
@htmx:after-request="if(submitTimeout) clearTimeout(submitTimeout); const redirect = $event.detail.xhr.getResponseHeader('HX-Redirect'); if(redirect) return; if(!$event.detail.successful) { isSubmitting=false; buttontext='Register'; if($event.detail.xhr.status === 409) { errorMessage='Username is already taken'; isUnique=false; } else { errorMessage='An error occurred. Please try again.'; } }"
|
||||
>
|
||||
<script>
|
||||
function registerFormData() {
|
||||
return {
|
||||
canSubmit: false,
|
||||
buttontext: "Register",
|
||||
errorMessage: "",
|
||||
isChecking: false,
|
||||
isUnique: false,
|
||||
isEmpty: true,
|
||||
isSubmitting: false,
|
||||
submitTimeout: null,
|
||||
resetErr() {
|
||||
this.errorMessage = "";
|
||||
this.isChecking = false;
|
||||
this.isUnique = false;
|
||||
},
|
||||
enableSubmit() {
|
||||
this.canSubmit = true;
|
||||
},
|
||||
handleSubmit() {
|
||||
this.isSubmitting = true;
|
||||
this.buttontext = 'Loading...';
|
||||
// Set timeout for 10 seconds
|
||||
this.submitTimeout = setTimeout(() => {
|
||||
this.isSubmitting = false;
|
||||
this.buttontext = 'Register';
|
||||
this.errorMessage = 'Request timed out. Please try again.';
|
||||
}, 10000);
|
||||
}
|
||||
};
|
||||
}
|
||||
</script>
|
||||
<div
|
||||
class="grid gap-y-4"
|
||||
>
|
||||
<div>
|
||||
<div class="relative">
|
||||
<input
|
||||
type="text"
|
||||
id="username"
|
||||
name="username"
|
||||
x-bind:class="{
|
||||
'py-3 px-4 block w-full rounded-lg text-sm bg-base disabled:opacity-50 disabled:pointer-events-none border-2 outline-none': true,
|
||||
'border-overlay0 focus:border-blue': !isUnique && !errorMessage,
|
||||
'border-green focus:border-green': isUnique && !isChecking && !errorMessage,
|
||||
'border-red focus:border-red': errorMessage && !isChecking && !isSubmitting
|
||||
}"
|
||||
required
|
||||
aria-describedby="username-error"
|
||||
value={ username }
|
||||
@input="resetErr(); isEmpty = $el.value.trim() === ''; if(isEmpty) { errorMessage='Username is required'; isUnique=false; }"
|
||||
hx-post="/htmx/isusernameunique"
|
||||
hx-trigger="load delay:100ms, input changed delay:500ms"
|
||||
hx-swap="none"
|
||||
@htmx:before-request="if($el.value.trim() === '') { isEmpty=true; return; } isEmpty=false; isChecking=true; isUnique=false; errorMessage=''"
|
||||
@htmx:after-request="isChecking=false; if($event.detail.successful) { isUnique=true; canSubmit=true; } else if($event.detail.xhr.status === 409) { errorMessage='Username is already taken'; isUnique=false; canSubmit=false; }"
|
||||
/>
|
||||
|
||||
<p
|
||||
class="text-center text-xs text-red mt-2"
|
||||
id="username-error"
|
||||
x-show="errorMessage && !isSubmitting"
|
||||
x-cloak
|
||||
x-text="errorMessage"
|
||||
></p>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
x-bind:disabled="isEmpty || !isUnique || isChecking || isSubmitting"
|
||||
x-text="buttontext"
|
||||
type="submit"
|
||||
class="w-full py-3 px-4 inline-flex justify-center items-center
|
||||
gap-x-2 rounded-lg border border-transparent transition
|
||||
bg-green hover:bg-green/75 text-mantle hover:cursor-pointer
|
||||
disabled:bg-green/60 disabled:cursor-default"
|
||||
></button>
|
||||
</div>
|
||||
</form>
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package nav
|
||||
|
||||
import "git.haelnorr.com/h/oslstats/pkg/contexts"
|
||||
import "git.haelnorr.com/h/oslstats/internal/db"
|
||||
|
||||
type ProfileItem struct {
|
||||
name string // Label to display
|
||||
@@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem {
|
||||
|
||||
// Returns the right portion of the navbar
|
||||
templ navRight() {
|
||||
{{ user := contexts.CurrentUser(ctx) }}
|
||||
{{ user := db.CurrentUser(ctx) }}
|
||||
{{ items := getProfileItems() }}
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="sm:flex sm:gap-2">
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package nav
|
||||
|
||||
import "git.haelnorr.com/h/oslstats/pkg/contexts"
|
||||
import "git.haelnorr.com/h/oslstats/internal/db"
|
||||
|
||||
// Returns the mobile version of the navbar thats only visible when activated
|
||||
templ sideNav(navItems []NavItem) {
|
||||
{{ user := contexts.CurrentUser(ctx) }}
|
||||
{{ user := db.CurrentUser(ctx) }}
|
||||
<div
|
||||
x-show="open"
|
||||
x-transition
|
||||
|
||||
@@ -18,14 +18,11 @@ templ Global(title string) {
|
||||
window.matchMedia('(prefers-color-scheme: dark)').matches)}"
|
||||
>
|
||||
<head>
|
||||
<!-- <script src="/static/js/theme.js"></script> -->
|
||||
<script src="/static/js/theme.js"></script>
|
||||
<meta charset="UTF-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>{ title }</title>
|
||||
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com"/>
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin="anonymous"/>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap" rel="stylesheet"/>
|
||||
<link href="/static/css/output.css" rel="stylesheet"/>
|
||||
<script src="https://unpkg.com/htmx.org@2.0.4" integrity="sha384-HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+" crossorigin="anonymous"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/persist@3.x.x/dist/cdn.min.js"></script>
|
||||
@@ -38,19 +35,6 @@ templ Global(title string) {
|
||||
const bodyData = {
|
||||
showError500: false,
|
||||
showError503: false,
|
||||
showConfirmPasswordModal: false,
|
||||
handleHtmxBeforeOnLoad(event) {
|
||||
const requestPath = event.detail.pathInfo.requestPath;
|
||||
if (requestPath === "/reauthenticate") {
|
||||
// handle password incorrect on refresh attempt
|
||||
if (event.detail.xhr.status === 445) {
|
||||
event.detail.shouldSwap = true;
|
||||
event.detail.isError = false;
|
||||
} else if (event.detail.xhr.status === 200) {
|
||||
this.showConfirmPasswordModal = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
// handle errors from the server on HTMX requests
|
||||
handleHtmxError(event) {
|
||||
const errorCode = event.detail.errorInfo.error;
|
||||
@@ -65,11 +49,6 @@ templ Global(title string) {
|
||||
this.showError503 = true;
|
||||
setTimeout(() => (this.showError503 = false), 6000);
|
||||
}
|
||||
|
||||
// user is authorized but needs to refresh their login
|
||||
if (errorCode.includes("Code 444")) {
|
||||
this.showConfirmPasswordModal = true;
|
||||
}
|
||||
},
|
||||
};
|
||||
</script>
|
||||
@@ -78,7 +57,6 @@ templ Global(title string) {
|
||||
class="bg-base text-text ubuntu-mono-regular overflow-x-hidden"
|
||||
x-data="bodyData"
|
||||
x-on:htmx:error="handleHtmxError($event)"
|
||||
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
|
||||
>
|
||||
@popup.Error500Popup()
|
||||
@popup.Error503Popup()
|
||||
|
||||
@@ -3,32 +3,66 @@ package page
|
||||
import "git.haelnorr.com/h/oslstats/internal/view/layout"
|
||||
import "strconv"
|
||||
|
||||
// Page template for Error pages. Error code should be a HTTP status code as
|
||||
// a string, and err should be the corresponding response title.
|
||||
// Message is a custom error message displayed below the code and error.
|
||||
// Original Error template (keep for backwards compatibility where needed)
|
||||
templ Error(code int, err string, message string) {
|
||||
@ErrorWithDetails(code, err, message, "")
|
||||
}
|
||||
|
||||
// Enhanced Error template with optional details section
|
||||
templ ErrorWithDetails(code int, err string, message string, details string) {
|
||||
@layout.Global(err) {
|
||||
<div
|
||||
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
||||
place-content-center bg-base px-4"
|
||||
>
|
||||
<div class="text-center">
|
||||
<h1
|
||||
class="text-9xl text-text"
|
||||
>{ strconv.Itoa(code) }</h1>
|
||||
<p
|
||||
class="text-2xl font-bold tracking-tight text-subtext1
|
||||
sm:text-4xl"
|
||||
>{ err }</p>
|
||||
<p
|
||||
class="mt-4 text-subtext0"
|
||||
>{ message }</p>
|
||||
<a
|
||||
href="/"
|
||||
class="mt-6 inline-block rounded-lg bg-mauve px-5 py-3
|
||||
text-sm text-crust transition hover:bg-mauve/75"
|
||||
>Go to homepage</a>
|
||||
<div class="grid mt-24 left-0 right-0 top-0 bottom-0 place-content-center bg-base px-4">
|
||||
<div class="text-center max-w-2xl mx-auto">
|
||||
<h1 class="text-9xl text-text">{ strconv.Itoa(code) }</h1>
|
||||
<p class="text-2xl font-bold tracking-tight text-subtext1 sm:text-4xl">{ err }</p>
|
||||
// Always show the message from hws.HWSError.Message
|
||||
<p class="mt-4 text-subtext0">{ message }</p>
|
||||
// Conditionally show technical details in dropdown
|
||||
if details != "" {
|
||||
<div class="mt-8 text-left">
|
||||
<details class="bg-surface0 rounded-lg p-4 text-right">
|
||||
<summary class="text-left cursor-pointer text-subtext1 font-semibold select-none hover:text-text">
|
||||
Details
|
||||
<span class="text-xs text-subtext0 ml-2">(click to expand)</span>
|
||||
</summary>
|
||||
<div class="text-left mt-4 relative">
|
||||
<pre id="details" class="text-xs text-subtext0 font-mono whitespace-pre-wrap break-all bg-mantle p-4 rounded overflow-x-auto">{ details }</pre>
|
||||
</div>
|
||||
<button
|
||||
onclick="copyToClipboard('details')"
|
||||
id="copyButton"
|
||||
class="mt-2 bg-mauve text-crust px-3 py-1 rounded text-xs hover:bg-mauve/75 transition hover:cursor-pointer"
|
||||
title="Copy to clipboard"
|
||||
>
|
||||
Copy
|
||||
</button>
|
||||
</details>
|
||||
</div>
|
||||
}
|
||||
<a href="/" class="mt-6 inline-block rounded-lg bg-mauve px-5 py-3 text-sm text-crust transition hover:bg-mauve/75">
|
||||
Go to homepage
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
if details != "" {
|
||||
<script>
|
||||
function copyToClipboard(id) {
|
||||
var details = document.getElementById(id).innerText;
|
||||
var button = document.getElementById("copyButton");
|
||||
navigator.clipboard
|
||||
.writeText(details)
|
||||
.then(function () {
|
||||
button.innerText = "Copied!";
|
||||
setTimeout(function () {
|
||||
button.innerText = "Copy";
|
||||
}, 2000);
|
||||
})
|
||||
.catch(function (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
button.innerText = "Failed";
|
||||
});
|
||||
}
|
||||
</script>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
29
internal/view/page/register.templ
Normal file
29
internal/view/page/register.templ
Normal file
@@ -0,0 +1,29 @@
|
||||
package page
|
||||
|
||||
import "git.haelnorr.com/h/oslstats/internal/view/layout"
|
||||
import "git.haelnorr.com/h/oslstats/internal/view/component/form"
|
||||
|
||||
// Returns the login page
|
||||
templ Register(username string) {
|
||||
@layout.Global("Register") {
|
||||
<div class="max-w-100 mx-auto px-2">
|
||||
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
||||
<div class="p-4 sm:p-7">
|
||||
<div class="text-center">
|
||||
<h1
|
||||
class="block text-2xl font-bold"
|
||||
>Set your display name</h1>
|
||||
<p
|
||||
class="mt-2 text-sm text-subtext0"
|
||||
>
|
||||
Select your display name. This must be unique, and cannot be changed.
|
||||
</p>
|
||||
</div>
|
||||
<div class="mt-5">
|
||||
@form.RegisterForm(username)
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package contexts
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
)
|
||||
|
||||
var CurrentUser hwsauth.ContextLoader[*db.User]
|
||||
@@ -1,7 +1,7 @@
|
||||
package contexts
|
||||
|
||||
type contextKey string
|
||||
type Key string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
func (c Key) String() string {
|
||||
return "oslstats context key " + string(c)
|
||||
}
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap");
|
||||
@import "tailwindcss";
|
||||
|
||||
@source "../../../../internal/view/component/footer/footer.templ";
|
||||
@source "../../../../internal/view/component/nav/navbarleft.templ";
|
||||
@source "../../../../internal/view/component/nav/navbarright.templ";
|
||||
@source "../../../../internal/view/component/nav/navbar.templ";
|
||||
@source "../../../../internal/view/component/nav/sidenav.templ";
|
||||
@source "../../../../internal/view/component/popup/error500Popup.templ";
|
||||
@source "../../../../internal/view/component/popup/error503Popup.templ";
|
||||
@source "../../../../internal/view/layout/global.templ";
|
||||
@source "../../../../internal/view/page/error.templ";
|
||||
@source "../../../../internal/view/page/index.templ";
|
||||
|
||||
[x-cloak] {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
@theme inline {
|
||||
--color-rosewater: var(--rosewater);
|
||||
--color-flamingo: var(--flamingo);
|
||||
@@ -43,6 +34,7 @@
|
||||
--color-mantle: var(--mantle);
|
||||
--color-crust: var(--crust);
|
||||
}
|
||||
|
||||
:root {
|
||||
--rosewater: hsl(11, 59%, 67%);
|
||||
--flamingo: hsl(0, 60%, 67%);
|
||||
@@ -102,6 +94,7 @@
|
||||
--mantle: hsl(240, 21%, 12%);
|
||||
--crust: hsl(240, 23%, 9%);
|
||||
}
|
||||
|
||||
.ubuntu-mono-regular {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 400;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
/*! tailwindcss v4.1.18 | MIT License | https://tailwindcss.com */
|
||||
@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap");
|
||||
@layer properties;
|
||||
@layer theme, base, components, utilities;
|
||||
@layer theme {
|
||||
@@ -10,7 +11,10 @@
|
||||
--spacing: 0.25rem;
|
||||
--breakpoint-xl: 80rem;
|
||||
--container-md: 28rem;
|
||||
--container-2xl: 42rem;
|
||||
--container-7xl: 80rem;
|
||||
--text-xs: 0.75rem;
|
||||
--text-xs--line-height: calc(1 / 0.75);
|
||||
--text-sm: 0.875rem;
|
||||
--text-sm--line-height: calc(1.25 / 0.875);
|
||||
--text-lg: 1.125rem;
|
||||
@@ -28,11 +32,13 @@
|
||||
--text-9xl: 8rem;
|
||||
--text-9xl--line-height: 1;
|
||||
--font-weight-medium: 500;
|
||||
--font-weight-semibold: 600;
|
||||
--font-weight-bold: 700;
|
||||
--tracking-tight: -0.025em;
|
||||
--leading-relaxed: 1.625;
|
||||
--radius-sm: 0.25rem;
|
||||
--radius-lg: 0.5rem;
|
||||
--radius-xl: 0.75rem;
|
||||
--default-transition-duration: 150ms;
|
||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
--default-font-family: var(--font-sans);
|
||||
@@ -208,6 +214,9 @@
|
||||
.relative {
|
||||
position: relative;
|
||||
}
|
||||
.static {
|
||||
position: static;
|
||||
}
|
||||
.end-0 {
|
||||
inset-inline-end: calc(var(--spacing) * 0);
|
||||
}
|
||||
@@ -244,9 +253,18 @@
|
||||
.mt-4 {
|
||||
margin-top: calc(var(--spacing) * 4);
|
||||
}
|
||||
.mt-5 {
|
||||
margin-top: calc(var(--spacing) * 5);
|
||||
}
|
||||
.mt-6 {
|
||||
margin-top: calc(var(--spacing) * 6);
|
||||
}
|
||||
.mt-7 {
|
||||
margin-top: calc(var(--spacing) * 7);
|
||||
}
|
||||
.mt-8 {
|
||||
margin-top: calc(var(--spacing) * 8);
|
||||
}
|
||||
.mt-10 {
|
||||
margin-top: calc(var(--spacing) * 10);
|
||||
}
|
||||
@@ -265,6 +283,9 @@
|
||||
.mb-auto {
|
||||
margin-bottom: auto;
|
||||
}
|
||||
.ml-2 {
|
||||
margin-left: calc(var(--spacing) * 2);
|
||||
}
|
||||
.ml-auto {
|
||||
margin-left: auto;
|
||||
}
|
||||
@@ -321,9 +342,15 @@
|
||||
.w-full {
|
||||
width: 100%;
|
||||
}
|
||||
.max-w-2xl {
|
||||
max-width: var(--container-2xl);
|
||||
}
|
||||
.max-w-7xl {
|
||||
max-width: var(--container-7xl);
|
||||
}
|
||||
.max-w-100 {
|
||||
max-width: calc(var(--spacing) * 100);
|
||||
}
|
||||
.max-w-md {
|
||||
max-width: var(--container-md);
|
||||
}
|
||||
@@ -344,6 +371,9 @@
|
||||
.transform {
|
||||
transform: var(--tw-rotate-x,) var(--tw-rotate-y,) var(--tw-rotate-z,) var(--tw-skew-x,) var(--tw-skew-y,);
|
||||
}
|
||||
.cursor-pointer {
|
||||
cursor: pointer;
|
||||
}
|
||||
.flex-col {
|
||||
flex-direction: column;
|
||||
}
|
||||
@@ -381,6 +411,12 @@
|
||||
margin-block-end: calc(calc(var(--spacing) * 1) * calc(1 - var(--tw-space-y-reverse)));
|
||||
}
|
||||
}
|
||||
.gap-x-2 {
|
||||
column-gap: calc(var(--spacing) * 2);
|
||||
}
|
||||
.gap-y-4 {
|
||||
row-gap: calc(var(--spacing) * 4);
|
||||
}
|
||||
.divide-y {
|
||||
:where(& > :not(:last-child)) {
|
||||
--tw-divide-y-reverse: 0;
|
||||
@@ -398,9 +434,15 @@
|
||||
.overflow-hidden {
|
||||
overflow: hidden;
|
||||
}
|
||||
.overflow-x-auto {
|
||||
overflow-x: auto;
|
||||
}
|
||||
.overflow-x-hidden {
|
||||
overflow-x: hidden;
|
||||
}
|
||||
.rounded {
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
.rounded-full {
|
||||
border-radius: calc(infinity * 1px);
|
||||
}
|
||||
@@ -410,19 +452,35 @@
|
||||
.rounded-sm {
|
||||
border-radius: var(--radius-sm);
|
||||
}
|
||||
.rounded-xl {
|
||||
border-radius: var(--radius-xl);
|
||||
}
|
||||
.border {
|
||||
border-style: var(--tw-border-style);
|
||||
border-width: 1px;
|
||||
}
|
||||
.border-2 {
|
||||
border-style: var(--tw-border-style);
|
||||
border-width: 2px;
|
||||
}
|
||||
.border-green {
|
||||
border-color: var(--green);
|
||||
}
|
||||
.border-overlay0 {
|
||||
border-color: var(--overlay0);
|
||||
}
|
||||
.border-red {
|
||||
border-color: var(--red);
|
||||
}
|
||||
.border-surface1 {
|
||||
border-color: var(--surface1);
|
||||
}
|
||||
.border-transparent {
|
||||
border-color: transparent;
|
||||
}
|
||||
.bg-base {
|
||||
background-color: var(--base);
|
||||
}
|
||||
.bg-blue {
|
||||
background-color: var(--blue);
|
||||
}
|
||||
.bg-crust {
|
||||
background-color: var(--crust);
|
||||
}
|
||||
@@ -456,12 +514,21 @@
|
||||
.p-4 {
|
||||
padding: calc(var(--spacing) * 4);
|
||||
}
|
||||
.px-2 {
|
||||
padding-inline: calc(var(--spacing) * 2);
|
||||
}
|
||||
.px-3 {
|
||||
padding-inline: calc(var(--spacing) * 3);
|
||||
}
|
||||
.px-4 {
|
||||
padding-inline: calc(var(--spacing) * 4);
|
||||
}
|
||||
.px-5 {
|
||||
padding-inline: calc(var(--spacing) * 5);
|
||||
}
|
||||
.py-1 {
|
||||
padding-block: calc(var(--spacing) * 1);
|
||||
}
|
||||
.py-2 {
|
||||
padding-block: calc(var(--spacing) * 2);
|
||||
}
|
||||
@@ -480,6 +547,15 @@
|
||||
.text-center {
|
||||
text-align: center;
|
||||
}
|
||||
.text-left {
|
||||
text-align: left;
|
||||
}
|
||||
.text-right {
|
||||
text-align: right;
|
||||
}
|
||||
.font-mono {
|
||||
font-family: var(--font-mono);
|
||||
}
|
||||
.text-2xl {
|
||||
font-size: var(--text-2xl);
|
||||
line-height: var(--tw-leading, var(--text-2xl--line-height));
|
||||
@@ -508,6 +584,10 @@
|
||||
font-size: var(--text-xl);
|
||||
line-height: var(--tw-leading, var(--text-xl--line-height));
|
||||
}
|
||||
.text-xs {
|
||||
font-size: var(--text-xs);
|
||||
line-height: var(--tw-leading, var(--text-xs--line-height));
|
||||
}
|
||||
.leading-relaxed {
|
||||
--tw-leading: var(--leading-relaxed);
|
||||
line-height: var(--leading-relaxed);
|
||||
@@ -520,10 +600,20 @@
|
||||
--tw-font-weight: var(--font-weight-medium);
|
||||
font-weight: var(--font-weight-medium);
|
||||
}
|
||||
.font-semibold {
|
||||
--tw-font-weight: var(--font-weight-semibold);
|
||||
font-weight: var(--font-weight-semibold);
|
||||
}
|
||||
.tracking-tight {
|
||||
--tw-tracking: var(--tracking-tight);
|
||||
letter-spacing: var(--tracking-tight);
|
||||
}
|
||||
.break-all {
|
||||
word-break: break-all;
|
||||
}
|
||||
.whitespace-pre-wrap {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
.text-crust {
|
||||
color: var(--crust);
|
||||
}
|
||||
@@ -568,6 +658,14 @@
|
||||
--tw-duration: 200ms;
|
||||
transition-duration: 200ms;
|
||||
}
|
||||
.outline-none {
|
||||
--tw-outline-style: none;
|
||||
outline-style: none;
|
||||
}
|
||||
.select-none {
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
.hover\:cursor-pointer {
|
||||
&:hover {
|
||||
@media (hover: hover) {
|
||||
@@ -575,16 +673,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.hover\:bg-blue\/75 {
|
||||
&:hover {
|
||||
@media (hover: hover) {
|
||||
background-color: var(--blue);
|
||||
@supports (color: color-mix(in lab, red, red)) {
|
||||
background-color: color-mix(in oklab, var(--blue) 75%, transparent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.hover\:bg-crust {
|
||||
&:hover {
|
||||
@media (hover: hover) {
|
||||
@@ -673,6 +761,51 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.hover\:text-text {
|
||||
&:hover {
|
||||
@media (hover: hover) {
|
||||
color: var(--text);
|
||||
}
|
||||
}
|
||||
}
|
||||
.focus\:border-blue {
|
||||
&:focus {
|
||||
border-color: var(--blue);
|
||||
}
|
||||
}
|
||||
.focus\:border-green {
|
||||
&:focus {
|
||||
border-color: var(--green);
|
||||
}
|
||||
}
|
||||
.focus\:border-red {
|
||||
&:focus {
|
||||
border-color: var(--red);
|
||||
}
|
||||
}
|
||||
.disabled\:pointer-events-none {
|
||||
&:disabled {
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
.disabled\:cursor-default {
|
||||
&:disabled {
|
||||
cursor: default;
|
||||
}
|
||||
}
|
||||
.disabled\:bg-green\/60 {
|
||||
&:disabled {
|
||||
background-color: var(--green);
|
||||
@supports (color: color-mix(in lab, red, red)) {
|
||||
background-color: color-mix(in oklab, var(--green) 60%, transparent);
|
||||
}
|
||||
}
|
||||
}
|
||||
.disabled\:opacity-50 {
|
||||
&:disabled {
|
||||
opacity: 50%;
|
||||
}
|
||||
}
|
||||
.sm\:end-6 {
|
||||
@media (width >= 40rem) {
|
||||
inset-inline-end: calc(var(--spacing) * 6);
|
||||
@@ -693,11 +826,6 @@
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
.sm\:inline {
|
||||
@media (width >= 40rem) {
|
||||
display: inline;
|
||||
}
|
||||
}
|
||||
.sm\:justify-between {
|
||||
@media (width >= 40rem) {
|
||||
justify-content: space-between;
|
||||
@@ -708,6 +836,11 @@
|
||||
gap: calc(var(--spacing) * 2);
|
||||
}
|
||||
}
|
||||
.sm\:p-7 {
|
||||
@media (width >= 40rem) {
|
||||
padding: calc(var(--spacing) * 7);
|
||||
}
|
||||
}
|
||||
.sm\:px-6 {
|
||||
@media (width >= 40rem) {
|
||||
padding-inline: calc(var(--spacing) * 6);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// This function prevents the 'flash of unstyled content'
|
||||
// Include it at the top of <head>
|
||||
(function() {
|
||||
let theme = localStorage.getItem("theme") || "system";
|
||||
if (theme === "system") {
|
||||
|
||||
23
pkg/oauth/config.go
Normal file
23
pkg/oauth/config.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required)
|
||||
}
|
||||
|
||||
func ConfigFromEnv() (any, error) {
|
||||
cfg := &Config{
|
||||
PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""),
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if cfg.PrivateKey == "" {
|
||||
return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
45
pkg/oauth/cookies.go
Normal file
45
pkg/oauth/cookies.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) {
|
||||
encodedUak := base64.RawURLEncoding.EncodeToString(uak)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_uak",
|
||||
Value: encodedUak,
|
||||
Path: "/",
|
||||
MaxAge: 300,
|
||||
HttpOnly: true,
|
||||
Secure: ssl,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func GetStateCookie(r *http.Request) ([]byte, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("Request cannot be nil")
|
||||
}
|
||||
cookie, err := r.Cookie("oauth_uak")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
uak, err := base64.RawURLEncoding.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie")
|
||||
}
|
||||
return uak, nil
|
||||
}
|
||||
|
||||
func DeleteStateCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_uak",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
41
pkg/oauth/ezconf.go
Normal file
41
pkg/oauth/ezconf.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||
type EZConfIntegration struct {
|
||||
configFunc func() (any, error)
|
||||
name string
|
||||
}
|
||||
|
||||
// PackagePath returns the path to the config package for source parsing
|
||||
func (e EZConfIntegration) PackagePath() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
// Return directory of this file
|
||||
return filename[:len(filename)-len("/ezconf.go")]
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||
return func() (any, error) {
|
||||
return e.configFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name to use when registering with ezconf
|
||||
func (e EZConfIntegration) Name() string {
|
||||
return strings.ToLower(e.name)
|
||||
}
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
func (e EZConfIntegration) GroupName() string {
|
||||
return e.name
|
||||
}
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv}
|
||||
}
|
||||
117
pkg/oauth/state.go
Normal file
117
pkg/oauth/state.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// STATE FLOW:
|
||||
// data provided at call time to be retrieved later
|
||||
// random value generated on the spot
|
||||
// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client
|
||||
// privateKey - from config
|
||||
|
||||
func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) {
|
||||
// signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey)
|
||||
// state = data + "." + random + "." + signature
|
||||
if cfg == nil {
|
||||
return "", nil, errors.New("cfg cannot be nil")
|
||||
}
|
||||
if cfg.PrivateKey == "" {
|
||||
return "", nil, errors.New("private key cannot be empty")
|
||||
}
|
||||
if data == "" {
|
||||
return "", nil, errors.New("data cannot be empty")
|
||||
}
|
||||
|
||||
// Generate 32 random bytes for random component
|
||||
randomBytes := make([]byte, 32)
|
||||
_, err = rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "failed to generate random bytes")
|
||||
}
|
||||
|
||||
// Generate 32 random bytes for userAgentKey
|
||||
userAgentKey = make([]byte, 32)
|
||||
_, err = rand.Read(userAgentKey)
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes")
|
||||
}
|
||||
|
||||
// Encode random and userAgentKey to base64
|
||||
randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||
|
||||
// Create payload for signing: data + "." + random + userAgentKey + privateKey
|
||||
// Note: userAgentKey is concatenated directly with privateKey (no separator)
|
||||
payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey
|
||||
|
||||
// Generate signature
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
signature := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Construct state: data + "." + random + "." + signature
|
||||
state = data + "." + randomEncoded + "." + signature
|
||||
|
||||
return state, userAgentKey, nil
|
||||
}
|
||||
|
||||
func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) {
|
||||
// Validate inputs
|
||||
if cfg == nil {
|
||||
return "", errors.New("cfg cannot be nil")
|
||||
}
|
||||
if cfg.PrivateKey == "" {
|
||||
return "", errors.New("private key cannot be empty")
|
||||
}
|
||||
if state == "" {
|
||||
return "", errors.New("state cannot be empty")
|
||||
}
|
||||
if len(userAgentKey) == 0 {
|
||||
return "", errors.New("userAgentKey cannot be empty")
|
||||
}
|
||||
|
||||
// Split state into parts
|
||||
parts := strings.Split(state, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts))
|
||||
}
|
||||
|
||||
// Check for empty parts
|
||||
if slices.Contains(parts, "") {
|
||||
return "", errors.New("state parts cannot be empty")
|
||||
}
|
||||
|
||||
data = parts[0]
|
||||
random := parts[1]
|
||||
receivedSignature := parts[2]
|
||||
|
||||
// Encode userAgentKey to base64 for payload reconstruction
|
||||
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||
|
||||
// Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey
|
||||
payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey
|
||||
|
||||
// Generate expected hash
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
|
||||
// Decode received signature to bytes
|
||||
receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to decode received signature")
|
||||
}
|
||||
|
||||
// Compare hash bytes directly with decoded signature using constant-time comparison
|
||||
// This is more efficient than encoding hash and then decoding both for comparison
|
||||
if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return "", errors.New("invalid state signature")
|
||||
}
|
||||
817
pkg/oauth/state_test.go
Normal file
817
pkg/oauth/state_test.go
Normal file
@@ -0,0 +1,817 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper function to create a test config
|
||||
func testConfig() *Config {
|
||||
return &Config{
|
||||
PrivateKey: "test_private_key_for_testing_12345",
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_Success tests the happy path of state generation
|
||||
func TestGenerateState_Success(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
data := "test_data_payload"
|
||||
|
||||
state, userAgentKey, err := GenerateState(cfg, data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
t.Error("Expected non-empty state")
|
||||
}
|
||||
|
||||
if len(userAgentKey) != 32 {
|
||||
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
|
||||
}
|
||||
|
||||
// Verify state format: data.random.signature
|
||||
parts := strings.Split(state, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Expected state to have 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Verify data is preserved
|
||||
if parts[0] != data {
|
||||
t.Errorf("Expected data to be '%s', got '%s'", data, parts[0])
|
||||
}
|
||||
|
||||
// Verify random part is base64 encoded
|
||||
randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Errorf("Expected random part to be valid base64: %v", err)
|
||||
}
|
||||
if len(randomBytes) != 32 {
|
||||
t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes))
|
||||
}
|
||||
|
||||
// Verify signature part is base64 encoded
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
t.Errorf("Expected signature part to be valid base64: %v", err)
|
||||
}
|
||||
if len(sigBytes) != 32 {
|
||||
t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_NilConfig tests that nil config returns error
|
||||
func TestGenerateState_NilConfig(t *testing.T) {
|
||||
_, _, err := GenerateState(nil, "test_data")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nil config, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "cfg cannot be nil") {
|
||||
t.Errorf("Expected error message about nil config, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_EmptyPrivateKey tests that empty private key returns error
|
||||
func TestGenerateState_EmptyPrivateKey(t *testing.T) {
|
||||
cfg := &Config{PrivateKey: ""}
|
||||
_, _, err := GenerateState(cfg, "test_data")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty private key, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "private key cannot be empty") {
|
||||
t.Errorf("Expected error message about empty private key, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_EmptyData tests that empty data returns error
|
||||
func TestGenerateState_EmptyData(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
_, _, err := GenerateState(cfg, "")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty data, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "data cannot be empty") {
|
||||
t.Errorf("Expected error message about empty data, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_Randomness tests that multiple calls generate different states
|
||||
func TestGenerateState_Randomness(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
data := "same_data"
|
||||
|
||||
state1, _, err1 := GenerateState(cfg, data)
|
||||
state2, _, err2 := GenerateState(cfg, data)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("Unexpected errors: %v, %v", err1, err2)
|
||||
}
|
||||
|
||||
if state1 == state2 {
|
||||
t.Error("Expected different states for multiple calls, got identical states")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateState_DifferentData tests states with different data payloads
|
||||
func TestGenerateState_DifferentData(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
testCases := []string{
|
||||
"simple",
|
||||
"with-dashes",
|
||||
"with_underscores",
|
||||
"123456789",
|
||||
"MixedCase123",
|
||||
}
|
||||
|
||||
for _, data := range testCases {
|
||||
t.Run(data, func(t *testing.T) {
|
||||
state, userAgentKey, err := GenerateState(cfg, data)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for data '%s': %v", data, err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(state, data+".") {
|
||||
t.Errorf("Expected state to start with '%s.', got: %s", data, state)
|
||||
}
|
||||
|
||||
if len(userAgentKey) != 32 {
|
||||
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_Success tests the happy path of state verification
|
||||
func TestVerifyState_Success(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
data := "test_data"
|
||||
|
||||
// Generate state
|
||||
state, userAgentKey, err := GenerateState(cfg, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Verify state
|
||||
extractedData, err := VerifyState(cfg, state, userAgentKey)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extractedData != data {
|
||||
t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_NilConfig tests that nil config returns error
|
||||
func TestVerifyState_NilConfig(t *testing.T) {
|
||||
_, err := VerifyState(nil, "state", []byte("key"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nil config, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "cfg cannot be nil") {
|
||||
t.Errorf("Expected error message about nil config, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_EmptyPrivateKey tests that empty private key returns error
|
||||
func TestVerifyState_EmptyPrivateKey(t *testing.T) {
|
||||
cfg := &Config{PrivateKey: ""}
|
||||
_, err := VerifyState(cfg, "state", []byte("key"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty private key, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "private key cannot be empty") {
|
||||
t.Errorf("Expected error message about empty private key, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_EmptyState tests that empty state returns error
|
||||
func TestVerifyState_EmptyState(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
_, err := VerifyState(cfg, "", []byte("key"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty state, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "state cannot be empty") {
|
||||
t.Errorf("Expected error message about empty state, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error
|
||||
func TestVerifyState_EmptyUserAgentKey(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
_, err := VerifyState(cfg, "data.random.signature", []byte{})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty userAgentKey, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "userAgentKey cannot be empty") {
|
||||
t.Errorf("Expected error message about empty userAgentKey, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_WrongUserAgentKey tests MITM protection
|
||||
func TestVerifyState_WrongUserAgentKey(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate first state
|
||||
state, _, err := GenerateState(cfg, "test_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Generate a different userAgentKey
|
||||
_, wrongKey, err := GenerateState(cfg, "other_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate second state: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify with wrong key
|
||||
_, err = VerifyState(cfg, state, wrongKey)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid signature")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid state signature") {
|
||||
t.Errorf("Expected error about invalid signature, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_TamperedData tests tampering detection
|
||||
func TestVerifyState_TamperedData(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate state
|
||||
state, userAgentKey, err := GenerateState(cfg, "original_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Tamper with the data portion
|
||||
parts := strings.Split(state, ".")
|
||||
parts[0] = "tampered_data"
|
||||
tamperedState := strings.Join(parts, ".")
|
||||
|
||||
// Try to verify tampered state
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for tampered state")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_TamperedRandom tests tampering with random portion
|
||||
func TestVerifyState_TamperedRandom(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate state
|
||||
state, userAgentKey, err := GenerateState(cfg, "test_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Tamper with the random portion
|
||||
parts := strings.Split(state, ".")
|
||||
parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12"))
|
||||
tamperedState := strings.Join(parts, ".")
|
||||
|
||||
// Try to verify tampered state
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for tampered state")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_TamperedSignature tests tampering with signature
|
||||
func TestVerifyState_TamperedSignature(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate state
|
||||
state, userAgentKey, err := GenerateState(cfg, "test_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Tamper with the signature portion
|
||||
parts := strings.Split(state, ".")
|
||||
// Create a different valid base64 string
|
||||
parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered")))
|
||||
tamperedState := strings.Join(parts, ".")
|
||||
|
||||
// Try to verify tampered state
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for tampered signature")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts
|
||||
func TestVerifyState_MalformedState_TwoParts(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
malformedState := "data.random"
|
||||
|
||||
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for malformed state")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
|
||||
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_MalformedState_FourParts tests state with 4 parts
|
||||
func TestVerifyState_MalformedState_FourParts(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
malformedState := "data.random.signature.extra"
|
||||
|
||||
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for malformed state")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
|
||||
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_EmptyStateParts tests state with empty parts
|
||||
func TestVerifyState_EmptyStateParts(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
testCases := []struct {
|
||||
name string
|
||||
state string
|
||||
}{
|
||||
{"empty data", ".random.signature"},
|
||||
{"empty random", "data..signature"},
|
||||
{"empty signature", "data.random."},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for state with empty parts")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "state parts cannot be empty") {
|
||||
t.Errorf("Expected error about empty parts, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature
|
||||
func TestVerifyState_InvalidBase64Signature(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
invalidState := "data.random.invalid@base64!"
|
||||
|
||||
_, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890"))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid base64 signature")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to decode") {
|
||||
t.Errorf("Expected error about decoding signature, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification
|
||||
func TestVerifyState_DifferentPrivateKey(t *testing.T) {
|
||||
cfg1 := &Config{PrivateKey: "private_key_1"}
|
||||
cfg2 := &Config{PrivateKey: "private_key_2"}
|
||||
|
||||
// Generate with first config
|
||||
state, userAgentKey, err := GenerateState(cfg1, "test_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify with second config
|
||||
_, err = VerifyState(cfg2, state, userAgentKey)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for mismatched private key")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundTrip tests complete round trip with various data payloads
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
testCases := []string{
|
||||
"simple",
|
||||
"with-dashes-and-numbers-123",
|
||||
"MixedCaseData",
|
||||
"user_token_abc123",
|
||||
"link_resource_xyz789",
|
||||
}
|
||||
|
||||
for _, data := range testCases {
|
||||
t.Run(data, func(t *testing.T) {
|
||||
// Generate
|
||||
state, userAgentKey, err := GenerateState(cfg, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
extractedData, err := VerifyState(cfg, state, userAgentKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to verify state: %v", err)
|
||||
}
|
||||
|
||||
if extractedData != data {
|
||||
t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentGeneration tests that concurrent state generation works correctly
|
||||
func TestConcurrentGeneration(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
data := "concurrent_test"
|
||||
|
||||
const numGoroutines = 10
|
||||
results := make(chan string, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
// Generate states concurrently
|
||||
for range numGoroutines {
|
||||
go func() {
|
||||
state, userAgentKey, err := GenerateState(cfg, data)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Verify immediately
|
||||
_, verifyErr := VerifyState(cfg, state, userAgentKey)
|
||||
if verifyErr != nil {
|
||||
errors <- verifyErr
|
||||
return
|
||||
}
|
||||
|
||||
results <- state
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
states := make(map[string]bool)
|
||||
for range numGoroutines {
|
||||
select {
|
||||
case state := <-results:
|
||||
if states[state] {
|
||||
t.Errorf("Duplicate state generated: %s", state)
|
||||
}
|
||||
states[state] = true
|
||||
case err := <-errors:
|
||||
t.Errorf("Concurrent generation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(states) != numGoroutines {
|
||||
t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStateFormatCompatibility ensures state is URL-safe
|
||||
func TestStateFormatCompatibility(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
data := "url_safe_test"
|
||||
|
||||
state, _, err := GenerateState(cfg, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Check that state doesn't contain characters that need URL encoding
|
||||
unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"}
|
||||
for _, char := range unsafeChars {
|
||||
if strings.Contains(state, char) {
|
||||
t.Errorf("State contains URL-unsafe character '%s': %s", char, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works
|
||||
// An attacker obtains their own valid state but tries to use it with victim's session
|
||||
func TestMITM_AttackerCannotSubstituteState(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Victim generates a state for their login
|
||||
victimState, victimKey, err := GenerateState(cfg, "victim_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate victim state: %v", err)
|
||||
}
|
||||
|
||||
// Attacker generates their own valid state (they can request this from the server)
|
||||
attackerState, attackerKey, err := GenerateState(cfg, "attacker_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate attacker state: %v", err)
|
||||
}
|
||||
|
||||
// Both states should be valid on their own
|
||||
_, err = VerifyState(cfg, victimState, victimKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Victim state should be valid: err=%v", err)
|
||||
}
|
||||
|
||||
_, err = VerifyState(cfg, attackerState, attackerKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Attacker state should be valid: err=%v", err)
|
||||
}
|
||||
|
||||
// MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie
|
||||
// This should FAIL because attackerState was signed with attackerKey, not victimKey
|
||||
_, err = VerifyState(cfg, attackerState, victimKey)
|
||||
if err == nil {
|
||||
t.Error("Expected error when attacker substitutes state")
|
||||
}
|
||||
|
||||
// MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie
|
||||
// This should also FAIL
|
||||
_, err = VerifyState(cfg, victimState, attackerKey)
|
||||
if err == nil {
|
||||
t.Error("Expected error when attacker uses victim's state")
|
||||
}
|
||||
|
||||
// The key insight: even though both states are "valid", they are bound to their respective cookies
|
||||
// An attacker cannot mix and match states and cookies
|
||||
t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies")
|
||||
}
|
||||
|
||||
// TestCSRF_AttackerCannotForgeState verifies CSRF protection
|
||||
// An attacker tries to forge a state parameter without knowing the private key
|
||||
func TestCSRF_AttackerCannotForgeState(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Attacker doesn't know the private key, but tries to forge a state
|
||||
// They might try to construct: "malicious_data.random.signature"
|
||||
|
||||
// Attempt 1: Use a random signature
|
||||
randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678"))
|
||||
forgedState1 := "malicious_data.somefakerandom." + randomSig
|
||||
|
||||
// Generate a real userAgentKey (attacker might try to get this)
|
||||
_, realKey, err := GenerateState(cfg, "legitimate_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify forged state
|
||||
_, err = VerifyState(cfg, forgedState1, realKey)
|
||||
if err == nil {
|
||||
t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!")
|
||||
}
|
||||
|
||||
// Attempt 2: Attacker tries to compute signature without private key
|
||||
// They use: SHA256(data + "." + random + userAgentKey) - missing privateKey
|
||||
attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey)
|
||||
hash := sha256.Sum256([]byte(attackerPayload))
|
||||
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
forgedState2 := "malicious_data.fakerandom." + attackerSig
|
||||
|
||||
_, err = VerifyState(cfg, forgedState2, realKey)
|
||||
if err == nil {
|
||||
t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!")
|
||||
}
|
||||
|
||||
t.Log("✓ CSRF protection verified: Cannot forge state without private key")
|
||||
}
|
||||
|
||||
// TestTampering_SignatureDetectsAllModifications verifies tamper detection
|
||||
func TestTampering_SignatureDetectsAllModifications(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate a valid state
|
||||
originalState, userAgentKey, err := GenerateState(cfg, "original_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Verify original is valid
|
||||
data, err := VerifyState(cfg, originalState, userAgentKey)
|
||||
if err != nil || data != "original_data" {
|
||||
t.Fatalf("Original state should be valid")
|
||||
}
|
||||
|
||||
parts := strings.Split(originalState, ".")
|
||||
|
||||
// Test 1: Attacker modifies data but keeps signature
|
||||
tamperedState := "modified_data." + parts[1] + "." + parts[2]
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
if err == nil {
|
||||
t.Error("TAMPER VULNERABILITY: Modified data not detected!")
|
||||
}
|
||||
|
||||
// Test 2: Attacker modifies random but keeps signature
|
||||
newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!"))
|
||||
tamperedState = parts[0] + "." + newRandom + "." + parts[2]
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
if err == nil {
|
||||
t.Error("TAMPER VULNERABILITY: Modified random not detected!")
|
||||
}
|
||||
|
||||
// Test 3: Attacker tries to recompute signature but doesn't have privateKey
|
||||
// They compute: SHA256(modified_data + "." + random + userAgentKey)
|
||||
attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||
hash := sha256.Sum256([]byte(attackerPayload))
|
||||
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
tamperedState = "modified_data." + parts[1] + "." + attackerSig
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
if err == nil {
|
||||
t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!")
|
||||
}
|
||||
|
||||
// Test 4: Single bit flip in signature
|
||||
sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
sigBytes[0] ^= 0x01 // Flip one bit
|
||||
flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes)
|
||||
tamperedState = parts[0] + "." + parts[1] + "." + flippedSig
|
||||
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||
if err == nil {
|
||||
t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!")
|
||||
}
|
||||
|
||||
t.Log("✓ Tamper detection verified: All modifications to state are detected")
|
||||
}
|
||||
|
||||
// TestReplay_DifferentSessionsCannotReuseState verifies replay protection
|
||||
func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Session 1: User initiates OAuth flow
|
||||
state1, key1, err := GenerateState(cfg, "session1_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// State is valid for session 1
|
||||
_, err = VerifyState(cfg, state1, key1)
|
||||
if err != nil {
|
||||
t.Fatalf("State should be valid for session 1")
|
||||
}
|
||||
|
||||
// Session 2: Same user (or attacker) initiates a new OAuth flow
|
||||
state2, key2, err := GenerateState(cfg, "session1_data") // same data
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Replay Attack: Try to use state1 with key2
|
||||
_, err = VerifyState(cfg, state1, key2)
|
||||
if err == nil {
|
||||
t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!")
|
||||
}
|
||||
|
||||
// Even with same data, each session should have unique state+key binding
|
||||
if state1 == state2 {
|
||||
t.Error("REPLAY VULNERABILITY: Same data produces identical states!")
|
||||
}
|
||||
|
||||
t.Log("✓ Replay protection verified: States are bound to specific session cookies")
|
||||
}
|
||||
|
||||
// TestConstantTimeComparison verifies that signature comparison is timing-safe
|
||||
// This is a behavioral test - we can't easily test timing, but we can verify the function is used
|
||||
func TestConstantTimeComparison_IsUsed(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate valid state
|
||||
state, userAgentKey, err := GenerateState(cfg, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Create states with signatures that differ at different positions
|
||||
parts := strings.Split(state, ".")
|
||||
originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
position int
|
||||
}{
|
||||
{"first byte differs", 0},
|
||||
{"middle byte differs", 16},
|
||||
{"last byte differs", 31},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create signature that differs at specific position
|
||||
tamperedSig := make([]byte, len(originalSig))
|
||||
copy(tamperedSig, originalSig)
|
||||
tamperedSig[tc.position] ^= 0xFF // Flip all bits
|
||||
|
||||
tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig)
|
||||
tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr
|
||||
|
||||
// All should fail verification
|
||||
_, err := VerifyState(cfg, tamperedState, userAgentKey)
|
||||
if err == nil {
|
||||
t.Errorf("Tampered signature at position %d should be invalid", tc.position)
|
||||
}
|
||||
|
||||
// If constant-time comparison is NOT used, early differences would return faster
|
||||
// While we can't easily test timing here, we verify all positions fail equally
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("✓ Constant-time comparison: All signature positions validated equally")
|
||||
t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation")
|
||||
}
|
||||
|
||||
// TestPrivateKey_IsCriticalToSecurity verifies private key is essential
|
||||
func TestPrivateKey_IsCriticalToSecurity(t *testing.T) {
|
||||
cfg1 := &Config{PrivateKey: "secret_key_1"}
|
||||
cfg2 := &Config{PrivateKey: "secret_key_2"}
|
||||
|
||||
// Generate state with key1
|
||||
state, userAgentKey, err := GenerateState(cfg1, "data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state: %v", err)
|
||||
}
|
||||
|
||||
// Should verify with key1
|
||||
_, err = VerifyState(cfg1, state, userAgentKey)
|
||||
if err != nil {
|
||||
t.Fatalf("State should be valid with correct private key")
|
||||
}
|
||||
|
||||
// Should NOT verify with key2 (different private key)
|
||||
_, err = VerifyState(cfg2, state, userAgentKey)
|
||||
if err == nil {
|
||||
t.Error("SECURITY VULNERABILITY: State verified with different private key!")
|
||||
}
|
||||
|
||||
// This proves that the private key is cryptographically involved in the signature
|
||||
t.Log("✓ Private key security verified: Different keys produce incompatible signatures")
|
||||
}
|
||||
|
||||
// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload
|
||||
func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
// Generate two states with same data but different userAgentKeys (implicit)
|
||||
state1, key1, err := GenerateState(cfg, "same_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state1: %v", err)
|
||||
}
|
||||
|
||||
state2, key2, err := GenerateState(cfg, "same_data")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate state2: %v", err)
|
||||
}
|
||||
|
||||
// The states should be different even with same data (different random and keys)
|
||||
if state1 == state2 {
|
||||
t.Error("States should differ due to different random values")
|
||||
}
|
||||
|
||||
// Each state should only verify with its own key
|
||||
_, err1 := VerifyState(cfg, state1, key1)
|
||||
_, err2 := VerifyState(cfg, state2, key2)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatal("States should be valid with their own keys")
|
||||
}
|
||||
|
||||
// Cross-verification should fail
|
||||
_, err1 = VerifyState(cfg, state1, key2)
|
||||
_, err2 = VerifyState(cfg, state2, key1)
|
||||
|
||||
if err1 == nil || err2 == nil {
|
||||
t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!")
|
||||
}
|
||||
|
||||
t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key")
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
# Scripts
|
||||
|
||||
## generate-css-sources.sh
|
||||
|
||||
Automatically generates the `pkg/embedfs/files/css/input.css` file with `@source` directives for all `.templ` files in the project.
|
||||
|
||||
### Why is this needed?
|
||||
|
||||
Tailwind CSS v4 requires explicit `@source` directives to know which files to scan for utility classes. Glob patterns like `**/*.templ` don't work in `@source` directives, so each file must be listed individually.
|
||||
|
||||
This script:
|
||||
1. Finds all `.templ` files in the `internal/` directory
|
||||
2. Generates `@source` directives with relative paths from the CSS file location
|
||||
3. Adds your custom theme and utility classes
|
||||
|
||||
### When does it run?
|
||||
|
||||
The script runs automatically as part of:
|
||||
- `make build` - Before building the CSS
|
||||
- `make dev` - Before starting watch mode
|
||||
|
||||
### Manual usage
|
||||
|
||||
If you need to regenerate the sources manually:
|
||||
|
||||
```bash
|
||||
./scripts/generate-css-sources.sh
|
||||
```
|
||||
|
||||
### Adding new template files
|
||||
|
||||
When you add a new `.templ` file, you don't need to do anything special - just run `make build` or `make dev` and the script will automatically pick up the new file.
|
||||
@@ -1,140 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Generate @source directives for all .templ files
|
||||
# Paths are relative to pkg/embedfs/files/css/input.css
|
||||
|
||||
INPUT_CSS="pkg/embedfs/files/css/input.css"
|
||||
|
||||
# Start with the base imports
|
||||
cat > "$INPUT_CSS" <<'CSSHEAD'
|
||||
@import "tailwindcss";
|
||||
|
||||
CSSHEAD
|
||||
|
||||
# Find all .templ files and add @source directives
|
||||
find internal -name "*.templ" -type f | sort | while read -r file; do
|
||||
# Convert to relative path from pkg/embedfs/files/css/
|
||||
rel_path="../../../../$file"
|
||||
echo "@source \"$rel_path\";" >> "$INPUT_CSS"
|
||||
done
|
||||
|
||||
# Add the custom theme and utility classes
|
||||
cat >> "$INPUT_CSS" <<'CSSBODY'
|
||||
|
||||
[x-cloak] {
|
||||
display: none !important;
|
||||
}
|
||||
@theme inline {
|
||||
--color-rosewater: var(--rosewater);
|
||||
--color-flamingo: var(--flamingo);
|
||||
--color-pink: var(--pink);
|
||||
--color-mauve: var(--mauve);
|
||||
--color-red: var(--red);
|
||||
--color-dark-red: var(--dark-red);
|
||||
--color-maroon: var(--maroon);
|
||||
--color-peach: var(--peach);
|
||||
--color-yellow: var(--yellow);
|
||||
--color-green: var(--green);
|
||||
--color-teal: var(--teal);
|
||||
--color-sky: var(--sky);
|
||||
--color-sapphire: var(--sapphire);
|
||||
--color-blue: var(--blue);
|
||||
--color-lavender: var(--lavender);
|
||||
--color-text: var(--text);
|
||||
--color-subtext1: var(--subtext1);
|
||||
--color-subtext0: var(--subtext0);
|
||||
--color-overlay2: var(--overlay2);
|
||||
--color-overlay1: var(--overlay1);
|
||||
--color-overlay0: var(--overlay0);
|
||||
--color-surface2: var(--surface2);
|
||||
--color-surface1: var(--surface1);
|
||||
--color-surface0: var(--surface0);
|
||||
--color-base: var(--base);
|
||||
--color-mantle: var(--mantle);
|
||||
--color-crust: var(--crust);
|
||||
}
|
||||
:root {
|
||||
--rosewater: hsl(11, 59%, 67%);
|
||||
--flamingo: hsl(0, 60%, 67%);
|
||||
--pink: hsl(316, 73%, 69%);
|
||||
--mauve: hsl(266, 85%, 58%);
|
||||
--red: hsl(347, 87%, 44%);
|
||||
--dark-red: hsl(343, 50%, 82%);
|
||||
--maroon: hsl(355, 76%, 59%);
|
||||
--peach: hsl(22, 99%, 52%);
|
||||
--yellow: hsl(35, 77%, 49%);
|
||||
--green: hsl(109, 58%, 40%);
|
||||
--teal: hsl(183, 74%, 35%);
|
||||
--sky: hsl(197, 97%, 46%);
|
||||
--sapphire: hsl(189, 70%, 42%);
|
||||
--blue: hsl(220, 91%, 54%);
|
||||
--lavender: hsl(231, 97%, 72%);
|
||||
--text: hsl(234, 16%, 35%);
|
||||
--subtext1: hsl(233, 13%, 41%);
|
||||
--subtext0: hsl(233, 10%, 47%);
|
||||
--overlay2: hsl(232, 10%, 53%);
|
||||
--overlay1: hsl(231, 10%, 59%);
|
||||
--overlay0: hsl(228, 11%, 65%);
|
||||
--surface2: hsl(227, 12%, 71%);
|
||||
--surface1: hsl(225, 14%, 77%);
|
||||
--surface0: hsl(223, 16%, 83%);
|
||||
--base: hsl(220, 23%, 95%);
|
||||
--mantle: hsl(220, 22%, 92%);
|
||||
--crust: hsl(220, 21%, 89%);
|
||||
}
|
||||
|
||||
.dark {
|
||||
--rosewater: hsl(10, 56%, 91%);
|
||||
--flamingo: hsl(0, 59%, 88%);
|
||||
--pink: hsl(316, 72%, 86%);
|
||||
--mauve: hsl(267, 84%, 81%);
|
||||
--red: hsl(343, 81%, 75%);
|
||||
--dark-red: hsl(316, 19%, 27%);
|
||||
--maroon: hsl(350, 65%, 77%);
|
||||
--peach: hsl(23, 92%, 75%);
|
||||
--yellow: hsl(41, 86%, 83%);
|
||||
--green: hsl(115, 54%, 76%);
|
||||
--teal: hsl(170, 57%, 73%);
|
||||
--sky: hsl(189, 71%, 73%);
|
||||
--sapphire: hsl(199, 76%, 69%);
|
||||
--blue: hsl(217, 92%, 76%);
|
||||
--lavender: hsl(232, 97%, 85%);
|
||||
--text: hsl(226, 64%, 88%);
|
||||
--subtext1: hsl(227, 35%, 80%);
|
||||
--subtext0: hsl(228, 24%, 72%);
|
||||
--overlay2: hsl(228, 17%, 64%);
|
||||
--overlay1: hsl(230, 13%, 55%);
|
||||
--overlay0: hsl(231, 11%, 47%);
|
||||
--surface2: hsl(233, 12%, 39%);
|
||||
--surface1: hsl(234, 13%, 31%);
|
||||
--surface0: hsl(237, 16%, 23%);
|
||||
--base: hsl(240, 21%, 15%);
|
||||
--mantle: hsl(240, 21%, 12%);
|
||||
--crust: hsl(240, 23%, 9%);
|
||||
}
|
||||
.ubuntu-mono-regular {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 400;
|
||||
font-style: normal;
|
||||
}
|
||||
|
||||
.ubuntu-mono-bold {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 700;
|
||||
font-style: normal;
|
||||
}
|
||||
|
||||
.ubuntu-mono-regular-italic {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 400;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.ubuntu-mono-bold-italic {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 700;
|
||||
font-style: italic;
|
||||
}
|
||||
CSSBODY
|
||||
|
||||
echo "Generated $INPUT_CSS with $(grep -c '@source' "$INPUT_CSS") source files"
|
||||
Reference in New Issue
Block a user