460 lines
11 KiB
Go
460 lines
11 KiB
Go
package discord
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.haelnorr.com/h/golib/hlog"
|
|
)
|
|
|
|
// testLogger creates a test logger for testing
|
|
func testLogger(t *testing.T) *hlog.Logger {
|
|
level, _ := hlog.LogLevel("debug")
|
|
cfg := &hlog.Config{
|
|
LogLevel: level,
|
|
LogOutput: "console",
|
|
}
|
|
logger, err := hlog.NewLogger(cfg, io.Discard)
|
|
if err != nil {
|
|
t.Fatalf("failed to create test logger: %v", err)
|
|
}
|
|
return logger
|
|
}
|
|
|
|
func TestNewRateLimitedClient(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
if client == nil {
|
|
t.Fatal("NewRateLimitedClient returned nil")
|
|
}
|
|
if client.client == nil {
|
|
t.Error("client.client is nil")
|
|
}
|
|
if client.logger == nil {
|
|
t.Error("client.logger is nil")
|
|
}
|
|
if client.buckets == nil {
|
|
t.Error("client.buckets map is nil")
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_Success(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
// Mock server that returns success with rate limit headers
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
|
w.Header().Set("X-RateLimit-Limit", "5")
|
|
w.Header().Set("X-RateLimit-Remaining", "3")
|
|
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Do() returned error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Check that rate limit state was updated
|
|
client.mu.RLock()
|
|
state, exists := client.buckets["test-bucket"]
|
|
client.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Fatal("rate limit state not stored")
|
|
}
|
|
if state.Remaining != 3 {
|
|
t.Errorf("expected remaining=3, got %d", state.Remaining)
|
|
}
|
|
if state.Limit != 5 {
|
|
t.Errorf("expected limit=5, got %d", state.Limit)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
attemptCount := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attemptCount++
|
|
if attemptCount == 1 {
|
|
// First request: return 429
|
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
|
w.Header().Set("Retry-After", "0.1") // 100ms
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "rate_limited",
|
|
"error_description": "You are being rate limited",
|
|
})
|
|
return
|
|
}
|
|
// Second request: success
|
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
|
w.Header().Set("X-RateLimit-Limit", "5")
|
|
w.Header().Set("X-RateLimit-Remaining", "4")
|
|
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
start := time.Now()
|
|
resp, err := client.Do(req)
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Do() returned error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
|
|
}
|
|
|
|
if attemptCount != 2 {
|
|
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
|
|
}
|
|
|
|
// Should have waited approximately 100ms
|
|
if elapsed < 100*time.Millisecond {
|
|
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
attemptCount := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attemptCount++
|
|
// Always return 429
|
|
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
|
w.Header().Set("Retry-After", "0.05") // 50ms
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "rate_limited",
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatal("Do() should have returned error after failed retry")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "still rate limited after retry") {
|
|
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
|
|
}
|
|
|
|
if attemptCount != 2 {
|
|
t.Errorf("expected 2 attempts, got %d", attemptCount)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "rate_limited",
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
start := time.Now()
|
|
resp, err := client.Do(req)
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatal("Do() should have returned error for Retry-After > 30s")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "exceeds 30s cap") {
|
|
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
|
|
}
|
|
|
|
// Should NOT have waited (immediate error)
|
|
if elapsed > 1*time.Second {
|
|
t.Errorf("should return immediately, but took %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Return 429 but NO Retry-After header
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "rate_limited",
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create request: %v", err)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatal("Do() should have returned error when no Retry-After header")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "no Retry-After header") {
|
|
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
headers := http.Header{}
|
|
headers.Set("X-RateLimit-Bucket", "global")
|
|
headers.Set("X-RateLimit-Limit", "10")
|
|
headers.Set("X-RateLimit-Remaining", "7")
|
|
headers.Set("X-RateLimit-Reset-After", "5.5")
|
|
|
|
client.updateRateLimit(headers)
|
|
|
|
client.mu.RLock()
|
|
state, exists := client.buckets["global"]
|
|
client.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Fatal("bucket state not created")
|
|
}
|
|
|
|
if state.Bucket != "global" {
|
|
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
|
|
}
|
|
if state.Limit != 10 {
|
|
t.Errorf("expected limit=10, got %d", state.Limit)
|
|
}
|
|
if state.Remaining != 7 {
|
|
t.Errorf("expected remaining=7, got %d", state.Remaining)
|
|
}
|
|
|
|
// Check reset time is approximately 5.5 seconds from now
|
|
resetIn := time.Until(state.Reset)
|
|
if resetIn < 5*time.Second || resetIn > 6*time.Second {
|
|
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
// Set up a bucket with 0 remaining and reset in future
|
|
bucket := "test-bucket"
|
|
client.mu.Lock()
|
|
client.buckets[bucket] = &RateLimitState{
|
|
Bucket: bucket,
|
|
Limit: 5,
|
|
Remaining: 0,
|
|
Reset: time.Now().Add(200 * time.Millisecond),
|
|
}
|
|
client.mu.Unlock()
|
|
|
|
start := time.Now()
|
|
err := client.waitIfNeeded(bucket)
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Errorf("waitIfNeeded returned error: %v", err)
|
|
}
|
|
|
|
// Should have waited ~200ms + 100ms buffer
|
|
if elapsed < 200*time.Millisecond {
|
|
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
|
|
}
|
|
if elapsed > 500*time.Millisecond {
|
|
t.Errorf("waited too long: %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
// Set up a bucket with remaining requests
|
|
bucket := "test-bucket"
|
|
client.mu.Lock()
|
|
client.buckets[bucket] = &RateLimitState{
|
|
Bucket: bucket,
|
|
Limit: 5,
|
|
Remaining: 3,
|
|
Reset: time.Now().Add(5 * time.Second),
|
|
}
|
|
client.mu.Unlock()
|
|
|
|
start := time.Now()
|
|
err := client.waitIfNeeded(bucket)
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
t.Errorf("waitIfNeeded returned error: %v", err)
|
|
}
|
|
|
|
// Should NOT wait (has remaining requests)
|
|
if elapsed > 10*time.Millisecond {
|
|
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_Do_Concurrent(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
requestCount := 0
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
requestCount++
|
|
count := requestCount
|
|
mu.Unlock()
|
|
|
|
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
|
|
w.Header().Set("X-RateLimit-Limit", "10")
|
|
w.Header().Set("X-RateLimit-Remaining", "5")
|
|
w.Header().Set("X-RateLimit-Reset-After", "1.0")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Launch 10 concurrent requests
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, 10)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
resp.Body.Close()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
// Check for any errors
|
|
for err := range errors {
|
|
t.Errorf("concurrent request failed: %v", err)
|
|
}
|
|
|
|
// All requests should have completed
|
|
mu.Lock()
|
|
finalCount := requestCount
|
|
mu.Unlock()
|
|
|
|
if finalCount != 10 {
|
|
t.Errorf("expected 10 requests, got %d", finalCount)
|
|
}
|
|
|
|
// Check rate limit state is consistent (no data races)
|
|
client.mu.RLock()
|
|
state, exists := client.buckets["concurrent-bucket"]
|
|
client.mu.RUnlock()
|
|
|
|
if !exists {
|
|
t.Fatal("bucket state not found after concurrent requests")
|
|
}
|
|
|
|
// State should exist and be valid
|
|
if state.Limit != 10 {
|
|
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
|
|
}
|
|
}
|
|
|
|
func TestAPIClient_ParseRetryAfter(t *testing.T) {
|
|
logger := testLogger(t)
|
|
client := NewRateLimitedClient(logger)
|
|
|
|
tests := []struct {
|
|
name string
|
|
header string
|
|
expected time.Duration
|
|
}{
|
|
{"integer seconds", "2", 2 * time.Second},
|
|
{"float seconds", "2.5", 2500 * time.Millisecond},
|
|
{"zero", "0", 0},
|
|
{"empty", "", 0},
|
|
{"invalid", "abc", 0},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
headers := http.Header{}
|
|
headers.Set("Retry-After", tt.header)
|
|
|
|
result := client.parseRetryAfter(headers)
|
|
|
|
if result != tt.expected {
|
|
t.Errorf("expected %v, got %v", tt.expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|