package discord import ( "encoding/json" "io" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "git.haelnorr.com/h/golib/hlog" ) // testLogger creates a test logger for testing func testLogger(t *testing.T) *hlog.Logger { level, _ := hlog.LogLevel("debug") cfg := &hlog.Config{ LogLevel: level, LogOutput: "console", } logger, err := hlog.NewLogger(cfg, io.Discard) if err != nil { t.Fatalf("failed to create test logger: %v", err) } return logger } // testConfig creates a test config for testing func testConfig() *Config { return &Config{ ClientID: "test-client-id", ClientSecret: "test-client-secret", OAuthScopes: "identify+email", RedirectPath: "/oauth/callback", } } func TestNewRateLimitedClient(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } if client == nil { t.Fatal("NewAPIClient returned nil") } if client.client == nil { t.Error("client.client is nil") } if client.logger == nil { t.Error("client.logger is nil") } if client.buckets == nil { t.Error("client.buckets map is nil") } if client.cfg == nil { t.Error("client.cfg is nil") } if client.trustedHost != "trusted-host.example.com" { t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost) } } func TestAPIClient_Do_Success(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } // Mock server that returns success with rate limit headers server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-RateLimit-Bucket", "test-bucket") w.Header().Set("X-RateLimit-Limit", "5") w.Header().Set("X-RateLimit-Remaining", "3") w.Header().Set("X-RateLimit-Reset-After", "2.5") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) })) defer server.Close() req, err := http.NewRequest("GET", server.URL+"/test", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err := client.Do(req) if err != nil { t.Fatalf("Do() returned error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200, got %d", resp.StatusCode) } // Check that rate limit state was updated client.mu.RLock() state, exists := client.buckets["test-bucket"] client.mu.RUnlock() if !exists { t.Fatal("rate limit state not stored") } if state.Remaining != 3 { t.Errorf("expected remaining=3, got %d", state.Remaining) } if state.Limit != 5 { t.Errorf("expected limit=5, got %d", state.Limit) } } func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } attemptCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attemptCount++ if attemptCount == 1 { // First request: return 429 w.Header().Set("X-RateLimit-Bucket", "test-bucket") w.Header().Set("Retry-After", "0.1") // 100ms w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{ "error": "rate_limited", "error_description": "You are being rate limited", }) return } // Second request: success w.Header().Set("X-RateLimit-Bucket", "test-bucket") w.Header().Set("X-RateLimit-Limit", "5") w.Header().Set("X-RateLimit-Remaining", "4") w.Header().Set("X-RateLimit-Reset-After", "2.5") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) })) defer server.Close() req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } start := time.Now() resp, err := client.Do(req) elapsed := time.Since(start) if err != nil { t.Fatalf("Do() returned error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected status 200 after retry, got %d", resp.StatusCode) } if attemptCount != 2 { t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount) } // Should have waited approximately 100ms if elapsed < 100*time.Millisecond { t.Errorf("expected delay of ~100ms, but took %v", elapsed) } } func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } attemptCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attemptCount++ // Always return 429 w.Header().Set("X-RateLimit-Bucket", "test-bucket") w.Header().Set("Retry-After", "0.05") // 50ms w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{ "error": "rate_limited", }) })) defer server.Close() req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err := client.Do(req) if err == nil { resp.Body.Close() t.Fatal("Do() should have returned error after failed retry") } if !strings.Contains(err.Error(), "still rate limited after retry") { t.Errorf("expected 'still rate limited after retry' error, got: %v", err) } if attemptCount != 2 { t.Errorf("expected 2 attempts, got %d", attemptCount) } } func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{ "error": "rate_limited", }) })) defer server.Close() req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } start := time.Now() resp, err := client.Do(req) elapsed := time.Since(start) if err == nil { resp.Body.Close() t.Fatal("Do() should have returned error for Retry-After > 30s") } if !strings.Contains(err.Error(), "exceeds 30s cap") { t.Errorf("expected 'exceeds 30s cap' error, got: %v", err) } // Should NOT have waited (immediate error) if elapsed > 1*time.Second { t.Errorf("should return immediately, but took %v", elapsed) } } func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Return 429 but NO Retry-After header w.WriteHeader(http.StatusTooManyRequests) json.NewEncoder(w).Encode(map[string]string{ "error": "rate_limited", }) })) defer server.Close() req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil) if err != nil { t.Fatalf("failed to create request: %v", err) } resp, err := client.Do(req) if err == nil { resp.Body.Close() t.Fatal("Do() should have returned error when no Retry-After header") } if !strings.Contains(err.Error(), "no Retry-After header") { t.Errorf("expected 'no Retry-After header' error, got: %v", err) } } func TestAPIClient_UpdateRateLimit(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } headers := http.Header{} headers.Set("X-RateLimit-Bucket", "global") headers.Set("X-RateLimit-Limit", "10") headers.Set("X-RateLimit-Remaining", "7") headers.Set("X-RateLimit-Reset-After", "5.5") client.updateRateLimit(headers) client.mu.RLock() state, exists := client.buckets["global"] client.mu.RUnlock() if !exists { t.Fatal("bucket state not created") } if state.Bucket != "global" { t.Errorf("expected bucket='global', got '%s'", state.Bucket) } if state.Limit != 10 { t.Errorf("expected limit=10, got %d", state.Limit) } if state.Remaining != 7 { t.Errorf("expected remaining=7, got %d", state.Remaining) } // Check reset time is approximately 5.5 seconds from now resetIn := time.Until(state.Reset) if resetIn < 5*time.Second || resetIn > 6*time.Second { t.Errorf("expected reset in ~5.5s, got %v", resetIn) } } func TestAPIClient_WaitIfNeeded(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } // Set up a bucket with 0 remaining and reset in future bucket := "test-bucket" client.mu.Lock() client.buckets[bucket] = &RateLimitState{ Bucket: bucket, Limit: 5, Remaining: 0, Reset: time.Now().Add(200 * time.Millisecond), } client.mu.Unlock() start := time.Now() err = client.waitIfNeeded(bucket) elapsed := time.Since(start) if err != nil { t.Errorf("waitIfNeeded returned error: %v", err) } // Should have waited ~200ms + 100ms buffer if elapsed < 200*time.Millisecond { t.Errorf("expected wait of ~300ms, but took %v", elapsed) } if elapsed > 500*time.Millisecond { t.Errorf("waited too long: %v", elapsed) } } func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } // Set up a bucket with remaining requests bucket := "test-bucket" client.mu.Lock() client.buckets[bucket] = &RateLimitState{ Bucket: bucket, Limit: 5, Remaining: 3, Reset: time.Now().Add(5 * time.Second), } client.mu.Unlock() start := time.Now() err = client.waitIfNeeded(bucket) elapsed := time.Since(start) if err != nil { t.Errorf("waitIfNeeded returned error: %v", err) } // Should NOT wait (has remaining requests) if elapsed > 10*time.Millisecond { t.Errorf("should not wait when remaining > 0, but took %v", elapsed) } } func TestAPIClient_Do_Concurrent(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } requestCount := 0 var mu sync.Mutex server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() requestCount++ count := requestCount mu.Unlock() w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket") w.Header().Set("X-RateLimit-Limit", "10") w.Header().Set("X-RateLimit-Remaining", "5") w.Header().Set("X-RateLimit-Reset-After", "1.0") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))}) })) defer server.Close() // Launch 10 concurrent requests var wg sync.WaitGroup errors := make(chan error, 10) for range 10 { wg.Go( func() { req, err := http.NewRequest("GET", server.URL+"/test", nil) if err != nil { errors <- err return } resp, err := client.Do(req) if err != nil { errors <- err return } resp.Body.Close() }) } wg.Wait() close(errors) // Check for any errors for err := range errors { t.Errorf("concurrent request failed: %v", err) } // All requests should have completed mu.Lock() finalCount := requestCount mu.Unlock() if finalCount != 10 { t.Errorf("expected 10 requests, got %d", finalCount) } // Check rate limit state is consistent (no data races) client.mu.RLock() state, exists := client.buckets["concurrent-bucket"] client.mu.RUnlock() if !exists { t.Fatal("bucket state not found after concurrent requests") } // State should exist and be valid if state.Limit != 10 { t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit) } } func TestAPIClient_ParseRetryAfter(t *testing.T) { logger := testLogger(t) cfg := testConfig() client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") if err != nil { t.Fatalf("NewAPIClient returned error: %v", err) } tests := []struct { name string header string expected time.Duration }{ {"integer seconds", "2", 2 * time.Second}, {"float seconds", "2.5", 2500 * time.Millisecond}, {"zero", "0", 0}, {"empty", "", 0}, {"invalid", "abc", 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { headers := http.Header{} headers.Set("Retry-After", tt.header) result := client.parseRetryAfter(headers) if result != tt.expected { t.Errorf("expected %v, got %v", tt.expected, result) } }) } }