added discord api limiting

This commit is contained in:
2026-01-24 00:58:31 +11:00
parent af6bec983b
commit ff0f61f534
15 changed files with 1363 additions and 141 deletions

View File

@@ -43,7 +43,7 @@ func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) {
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
}
func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) {
if code == "" {
return nil, errors.New("code cannot be empty")
}
@@ -53,6 +53,9 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
if trustedHost == "" {
return nil, errors.New("trustedHost cannot be empty")
}
if apiClient == nil {
return nil, errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("grant_type", "authorization_code")
@@ -72,9 +75,8 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request
client := &http.Client{}
resp, err := client.Do(req)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
@@ -96,13 +98,16 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
return &tokenResp, nil
}
func RefreshToken(cfg *Config, token *Token) (*Token, error) {
func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) {
if token == nil {
return nil, errors.New("token cannot be nil")
}
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if apiClient == nil {
return nil, errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("grant_type", "refresh_token")
@@ -121,9 +126,8 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request
client := &http.Client{}
resp, err := client.Do(req)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
@@ -145,13 +149,16 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
return &tokenResp, nil
}
func RevokeToken(cfg *Config, token *Token) error {
func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
if token == nil {
return errors.New("token cannot be nil")
}
if cfg == nil {
return errors.New("config cannot be nil")
}
if apiClient == nil {
return errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("token", token.AccessToken)
@@ -170,9 +177,8 @@ func RevokeToken(cfg *Config, token *Token) error {
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request
client := &http.Client{}
resp, err := client.Do(req)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
if err != nil {
return errors.Wrap(err, "failed to execute request")
}

View File

@@ -0,0 +1,235 @@
package discord
import (
"net"
"net/http"
"strconv"
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors"
)
// RateLimitState tracks rate limit information for a specific bucket
type RateLimitState struct {
Remaining int // Requests remaining in current window
Limit int // Total requests allowed in window
Reset time.Time // When the rate limit resets
Bucket string // Discord's bucket identifier
}
// APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct {
client *http.Client
logger *hlog.Logger
mu sync.RWMutex
buckets map[string]*RateLimitState
}
// NewRateLimitedClient creates a new Discord API client with rate limit handling
func NewRateLimitedClient(logger *hlog.Logger) *APIClient {
return &APIClient{
client: &http.Client{Timeout: 30 * time.Second},
logger: logger,
buckets: make(map[string]*RateLimitState),
}
}
// Do executes an HTTP request with automatic rate limit handling
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
if req == nil {
return nil, errors.New("request cannot be nil")
}
// Step 1: Check if we need to wait before making request
bucket := c.getBucketFromRequest(req)
if err := c.waitIfNeeded(bucket); err != nil {
return nil, err
}
// Step 2: Execute request
resp, err := c.client.Do(req)
if err != nil {
// Check if it's a network timeout
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, errors.Wrap(err, "request timed out")
}
return nil, errors.Wrap(err, "http request failed")
}
// Step 3: Update rate limit state from response headers
c.updateRateLimit(resp.Header)
// Step 4: Handle 429 (rate limited)
if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close() // Close original response
retryAfter := c.parseRetryAfter(resp.Header)
// No Retry-After header, can't retry safely
if retryAfter == 0 {
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Msg("Rate limited but no Retry-After header provided")
return nil, errors.New("discord API rate limited but no Retry-After header provided")
}
// Retry-After exceeds 30 second cap
if retryAfter > 30*time.Second {
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Dur("retry_after", retryAfter).
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
return nil, errors.Errorf(
"discord API rate limited (retry after %s exceeds 30s cap)",
retryAfter,
)
}
// Wait and retry
c.logger.Warn().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Dur("retry_after", retryAfter).
Msg("Rate limited, waiting before retry")
time.Sleep(retryAfter)
// Retry the request
resp, err = c.client.Do(req)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, errors.Wrap(err, "retry request timed out")
}
return nil, errors.Wrap(err, "retry request failed")
}
// Update rate limit again after retry
c.updateRateLimit(resp.Header)
// If STILL rate limited after retry, return error
if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close()
c.logger.Error().
Str("bucket", bucket).
Str("method", req.Method).
Str("path", req.URL.Path).
Msg("Still rate limited after retry, Discord may be experiencing issues")
return nil, errors.Errorf(
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
retryAfter,
)
}
}
return resp, nil
}
// getBucketFromRequest extracts or generates bucket ID from request
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
func (c *APIClient) getBucketFromRequest(req *http.Request) string {
return req.Method + ":" + req.URL.Path
}
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
func (c *APIClient) waitIfNeeded(bucket string) error {
c.mu.RLock()
state, exists := c.buckets[bucket]
c.mu.RUnlock()
if !exists {
return nil // No state yet, proceed
}
now := time.Now()
// If we have no remaining requests and reset hasn't occurred, wait
if state.Remaining == 0 && now.Before(state.Reset) {
waitDuration := time.Until(state.Reset)
// Add small buffer (100ms) to ensure reset has occurred
waitDuration += 100 * time.Millisecond
if waitDuration > 0 {
c.logger.Debug().
Str("bucket", bucket).
Dur("wait_duration", waitDuration).
Msg("Proactively waiting for rate limit reset")
time.Sleep(waitDuration)
}
}
return nil
}
// updateRateLimit parses response headers and updates bucket state
func (c *APIClient) updateRateLimit(headers http.Header) {
bucket := headers.Get("X-RateLimit-Bucket")
if bucket == "" {
return // No bucket info, can't track
}
// Parse headers
limit := c.parseInt(headers.Get("X-RateLimit-Limit"))
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining"))
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After"))
state := &RateLimitState{
Bucket: bucket,
Limit: limit,
Remaining: remaining,
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
}
c.mu.Lock()
c.buckets[bucket] = state
c.mu.Unlock()
// Log rate limit state for debugging
c.logger.Debug().
Str("bucket", bucket).
Int("remaining", remaining).
Int("limit", limit).
Dur("reset_in", time.Until(state.Reset)).
Msg("Rate limit state updated")
}
// parseRetryAfter extracts retry delay from Retry-After header
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
retryAfter := headers.Get("Retry-After")
if retryAfter == "" {
return 0
}
// Discord returns seconds as float
seconds := c.parseFloat(retryAfter)
if seconds <= 0 {
return 0
}
return time.Duration(seconds * float64(time.Second))
}
// parseInt parses an integer from a header value, returns 0 on error
func (c *APIClient) parseInt(s string) int {
if s == "" {
return 0
}
i, _ := strconv.Atoi(s)
return i
}
// parseFloat parses a float from a header value, returns 0 on error
func (c *APIClient) parseFloat(s string) float64 {
if s == "" {
return 0
}
f, _ := strconv.ParseFloat(s, 64)
return f
}

View File

@@ -0,0 +1,459 @@
package discord
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"git.haelnorr.com/h/golib/hlog"
)
// testLogger creates a test logger for testing
func testLogger(t *testing.T) *hlog.Logger {
level, _ := hlog.LogLevel("debug")
cfg := &hlog.Config{
LogLevel: level,
LogOutput: "console",
}
logger, err := hlog.NewLogger(cfg, io.Discard)
if err != nil {
t.Fatalf("failed to create test logger: %v", err)
}
return logger
}
func TestNewRateLimitedClient(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
if client == nil {
t.Fatal("NewRateLimitedClient returned nil")
}
if client.client == nil {
t.Error("client.client is nil")
}
if client.logger == nil {
t.Error("client.logger is nil")
}
if client.buckets == nil {
t.Error("client.buckets map is nil")
}
}
func TestAPIClient_Do_Success(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
// Mock server that returns success with rate limit headers
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("X-RateLimit-Limit", "5")
w.Header().Set("X-RateLimit-Remaining", "3")
w.Header().Set("X-RateLimit-Reset-After", "2.5")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Do() returned error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
// Check that rate limit state was updated
client.mu.RLock()
state, exists := client.buckets["test-bucket"]
client.mu.RUnlock()
if !exists {
t.Fatal("rate limit state not stored")
}
if state.Remaining != 3 {
t.Errorf("expected remaining=3, got %d", state.Remaining)
}
if state.Limit != 5 {
t.Errorf("expected limit=5, got %d", state.Limit)
}
}
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
if attemptCount == 1 {
// First request: return 429
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("Retry-After", "0.1") // 100ms
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
"error_description": "You are being rate limited",
})
return
}
// Second request: success
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("X-RateLimit-Limit", "5")
w.Header().Set("X-RateLimit-Remaining", "4")
w.Header().Set("X-RateLimit-Reset-After", "2.5")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
start := time.Now()
resp, err := client.Do(req)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Do() returned error: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
}
if attemptCount != 2 {
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
}
// Should have waited approximately 100ms
if elapsed < 100*time.Millisecond {
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
}
}
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
// Always return 429
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
w.Header().Set("Retry-After", "0.05") // 50ms
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error after failed retry")
}
if !strings.Contains(err.Error(), "still rate limited after retry") {
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
}
if attemptCount != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount)
}
}
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
start := time.Now()
resp, err := client.Do(req)
elapsed := time.Since(start)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error for Retry-After > 30s")
}
if !strings.Contains(err.Error(), "exceeds 30s cap") {
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
}
// Should NOT have waited (immediate error)
if elapsed > 1*time.Second {
t.Errorf("should return immediately, but took %v", elapsed)
}
}
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return 429 but NO Retry-After header
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limited",
})
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
t.Fatal("Do() should have returned error when no Retry-After header")
}
if !strings.Contains(err.Error(), "no Retry-After header") {
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
}
}
func TestAPIClient_UpdateRateLimit(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
headers := http.Header{}
headers.Set("X-RateLimit-Bucket", "global")
headers.Set("X-RateLimit-Limit", "10")
headers.Set("X-RateLimit-Remaining", "7")
headers.Set("X-RateLimit-Reset-After", "5.5")
client.updateRateLimit(headers)
client.mu.RLock()
state, exists := client.buckets["global"]
client.mu.RUnlock()
if !exists {
t.Fatal("bucket state not created")
}
if state.Bucket != "global" {
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
}
if state.Limit != 10 {
t.Errorf("expected limit=10, got %d", state.Limit)
}
if state.Remaining != 7 {
t.Errorf("expected remaining=7, got %d", state.Remaining)
}
// Check reset time is approximately 5.5 seconds from now
resetIn := time.Until(state.Reset)
if resetIn < 5*time.Second || resetIn > 6*time.Second {
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
}
}
func TestAPIClient_WaitIfNeeded(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
// Set up a bucket with 0 remaining and reset in future
bucket := "test-bucket"
client.mu.Lock()
client.buckets[bucket] = &RateLimitState{
Bucket: bucket,
Limit: 5,
Remaining: 0,
Reset: time.Now().Add(200 * time.Millisecond),
}
client.mu.Unlock()
start := time.Now()
err := client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
t.Errorf("waitIfNeeded returned error: %v", err)
}
// Should have waited ~200ms + 100ms buffer
if elapsed < 200*time.Millisecond {
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
}
if elapsed > 500*time.Millisecond {
t.Errorf("waited too long: %v", elapsed)
}
}
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
// Set up a bucket with remaining requests
bucket := "test-bucket"
client.mu.Lock()
client.buckets[bucket] = &RateLimitState{
Bucket: bucket,
Limit: 5,
Remaining: 3,
Reset: time.Now().Add(5 * time.Second),
}
client.mu.Unlock()
start := time.Now()
err := client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
t.Errorf("waitIfNeeded returned error: %v", err)
}
// Should NOT wait (has remaining requests)
if elapsed > 10*time.Millisecond {
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
}
}
func TestAPIClient_Do_Concurrent(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
requestCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
requestCount++
count := requestCount
mu.Unlock()
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
w.Header().Set("X-RateLimit-Limit", "10")
w.Header().Set("X-RateLimit-Remaining", "5")
w.Header().Set("X-RateLimit-Reset-After", "1.0")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
}))
defer server.Close()
// Launch 10 concurrent requests
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
errors <- err
return
}
resp, err := client.Do(req)
if err != nil {
errors <- err
return
}
resp.Body.Close()
}()
}
wg.Wait()
close(errors)
// Check for any errors
for err := range errors {
t.Errorf("concurrent request failed: %v", err)
}
// All requests should have completed
mu.Lock()
finalCount := requestCount
mu.Unlock()
if finalCount != 10 {
t.Errorf("expected 10 requests, got %d", finalCount)
}
// Check rate limit state is consistent (no data races)
client.mu.RLock()
state, exists := client.buckets["concurrent-bucket"]
client.mu.RUnlock()
if !exists {
t.Fatal("bucket state not found after concurrent requests")
}
// State should exist and be valid
if state.Limit != 10 {
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
}
}
func TestAPIClient_ParseRetryAfter(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
tests := []struct {
name string
header string
expected time.Duration
}{
{"integer seconds", "2", 2 * time.Second},
{"float seconds", "2.5", 2500 * time.Millisecond},
{"zero", "0", 0},
{"empty", "", 0},
{"invalid", "abc", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headers := http.Header{}
headers.Set("Retry-After", tt.header)
result := client.parseRetryAfter(headers)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@@ -6,18 +6,49 @@ import (
"time"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/session"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/oauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *session.Store) http.Handler {
func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *store.Store, discordAPI *discord.APIClient) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Track callback redirect attempts
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
if exceeded {
// Build detailed error for logging
err := errors.Errorf(
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
// Clear the tracking entry
store.ClearRedirectTrack(r, "/callback")
// Show error page
throwError(
server,
w,
r,
http.StatusBadRequest,
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
err,
"warn",
)
return
}
state := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")
if state == "" && code == "" {
@@ -41,6 +72,10 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi
}
return
}
// SUCCESS POINT: State verified successfully
// Clear redirect tracking - OAuth callback completed successfully
store.ClearRedirectTrack(r, "/callback")
switch data {
case "login":
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
@@ -51,7 +86,7 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *sessi
return
}
defer tx.Rollback()
redirect, err := login(ctx, tx, cfg, w, r, code, store)
redirect, err := login(ctx, tx, cfg, w, r, code, store, discordAPI)
if err != nil {
throwInternalServiceError(server, w, r, "OAuth login failed", err)
return
@@ -122,9 +157,10 @@ func login(
w http.ResponseWriter,
r *http.Request,
code string,
store *session.Store,
store *store.Store,
discordAPI *discord.APIClient,
) (func(), error) {
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost)
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost, discordAPI)
if err != nil {
return nil, errors.Wrap(err, "discord.AuthorizeWithCode")
}

View File

@@ -4,14 +4,47 @@ import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
func Login(server *hws.Server, cfg *config.Config) http.Handler {
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Track login redirect attempts
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
if exceeded {
// Build detailed error for logging
err := errors.Errorf(
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
// Clear the tracking entry
st.ClearRedirectTrack(r, "/login")
// Show error page
throwError(
server,
w,
r,
http.StatusBadRequest,
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
err,
"warn",
)
return
}
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
if err != nil {
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
@@ -24,6 +57,11 @@ func Login(server *hws.Server, cfg *config.Config) http.Handler {
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
return
}
// SUCCESS POINT: OAuth link generated, redirecting to Discord
// Clear redirect tracking - user successfully initiated OAuth
st.ClearRedirectTrack(r, "/login")
http.Redirect(w, r, link, http.StatusSeeOther)
},
)

View File

@@ -6,31 +6,63 @@ import (
"time"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/session"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/view/page"
"github.com/uptrace/bun"
)
func Register(
server *hws.Server,
conn *bun.DB,
cfg *config.Config,
store *session.Store,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
if exceeded {
err := errors.Errorf(
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
cfg.HWSAuth.SSL,
)
store.ClearRedirectTrack(r, "/register")
throwError(
server,
w,
r,
http.StatusBadRequest,
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
err,
"warn",
)
return
}
sessionCookie, err := r.Cookie("registration_session")
if err != nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
details, ok := store.GetRegistrationSession(sessionCookie.Value)
ok = false
if !ok {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
store.ClearRedirectTrack(r, "/register")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
@@ -65,12 +97,11 @@ func IsUsernameUnique(
server *hws.Server,
conn *bun.DB,
cfg *config.Config,
store *session.Store,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
// check if its unique
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)

View File

@@ -1,46 +0,0 @@
package session
import (
"crypto/rand"
"encoding/base64"
"sync"
"time"
)
type Store struct {
sessions sync.Map
cleanup *time.Ticker
}
func NewStore() *Store {
s := &Store{
cleanup: time.NewTicker(1 * time.Minute),
}
// Background cleanup of expired sessions
go func() {
for range s.cleanup.C {
s.cleanupExpired()
}
}()
return s
}
func (s *Store) Delete(id string) {
s.sessions.Delete(id)
}
func (s *Store) cleanupExpired() {
s.sessions.Range(func(key, value any) bool {
session := value.(*RegistrationSession)
if time.Now().After(session.ExpiresAt) {
s.sessions.Delete(key)
}
return true
})
}
func generateID() string {
b := make([]byte, 32)
rand.Read(b)
return base64.RawURLEncoding.EncodeToString(b)
}

View File

@@ -1,4 +1,4 @@
package session
package store
import (
"errors"

View File

@@ -0,0 +1,95 @@
package store
import (
"net"
"net/http"
"strings"
"time"
)
// getClientIP extracts the client IP address, checking X-Forwarded-For first
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (comma-separated list, first is client)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port")
// Use net.SplitHostPort to properly handle both IPv4 and IPv6
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr)
return r.RemoteAddr
}
return host
}
// TrackRedirect increments the redirect counter for this IP+UA+Path combination
// Returns the current attempt count, whether limit was exceeded, and the track details
func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) {
if r == nil {
return 0, false, nil
}
ip := getClientIP(r)
userAgent := r.UserAgent()
key := redirectKey(ip, userAgent, path)
now := time.Now()
expiresAt := now.Add(5 * time.Minute)
// Try to load existing track
val, exists := s.redirectTracks.Load(key)
if exists {
track = val.(*RedirectTrack)
// Check if expired
if now.After(track.ExpiresAt) {
// Expired, start fresh
track = &RedirectTrack{
IP: ip,
UserAgent: userAgent,
Path: path,
Attempts: 1,
FirstSeen: now,
ExpiresAt: expiresAt,
}
s.redirectTracks.Store(key, track)
return 1, false, track
}
// Increment existing
track.Attempts++
track.ExpiresAt = expiresAt // Extend expiry
exceeded = track.Attempts >= maxAttempts
return track.Attempts, exceeded, track
}
// Create new track
track = &RedirectTrack{
IP: ip,
UserAgent: userAgent,
Path: path,
Attempts: 1,
FirstSeen: now,
ExpiresAt: expiresAt,
}
s.redirectTracks.Store(key, track)
return 1, false, track
}
// ClearRedirectTrack removes a redirect tracking entry (called after successful completion)
func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
if r == nil {
return
}
ip := getClientIP(r)
userAgent := r.UserAgent()
key := redirectKey(ip, userAgent, path)
s.redirectTracks.Delete(key)
}

80
internal/store/store.go Normal file
View File

@@ -0,0 +1,80 @@
package store
import (
"crypto/rand"
"encoding/base64"
"fmt"
"sync"
"time"
)
// RedirectTrack represents a single redirect attempt tracking entry
type RedirectTrack struct {
IP string // Client IP (X-Forwarded-For aware)
UserAgent string // Full User-Agent string for debugging
Path string // Request path (without query params)
Attempts int // Number of redirect attempts
FirstSeen time.Time // When first redirect was tracked
ExpiresAt time.Time // When to clean up this entry
}
type Store struct {
sessions sync.Map // key: string, value: *RegistrationSession
redirectTracks sync.Map // key: string, value: *RedirectTrack
cleanup *time.Ticker
}
func NewStore() *Store {
s := &Store{
cleanup: time.NewTicker(1 * time.Minute),
}
// Background cleanup of expired sessions
go func() {
for range s.cleanup.C {
s.cleanupExpired()
}
}()
return s
}
func (s *Store) Delete(id string) {
s.sessions.Delete(id)
}
func (s *Store) cleanupExpired() {
now := time.Now()
// Clean up expired registration sessions
s.sessions.Range(func(key, value any) bool {
session := value.(*RegistrationSession)
if now.After(session.ExpiresAt) {
s.sessions.Delete(key)
}
return true
})
// Clean up expired redirect tracks
s.redirectTracks.Range(func(key, value any) bool {
track := value.(*RedirectTrack)
if now.After(track.ExpiresAt) {
s.redirectTracks.Delete(key)
}
return true
})
}
func generateID() string {
b := make([]byte, 32)
rand.Read(b)
return base64.RawURLEncoding.EncodeToString(b)
}
// redirectKey generates a unique key for tracking redirects
// Uses IP + first 100 chars of UA + path as key (not hashed for debugging)
func redirectKey(ip, userAgent, path string) string {
ua := userAgent
if len(ua) > 100 {
ua = ua[:100]
}
return fmt.Sprintf("%s:%s:%s", ip, ua, path)
}