added discord api limiting
This commit is contained in:
@@ -43,7 +43,7 @@ func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) {
|
||||
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
||||
}
|
||||
|
||||
func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
||||
func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("code cannot be empty")
|
||||
}
|
||||
@@ -53,6 +53,9 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
||||
if trustedHost == "" {
|
||||
return nil, errors.New("trustedHost cannot be empty")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return nil, errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
@@ -72,9 +75,8 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
@@ -96,13 +98,16 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) {
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
||||
func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) {
|
||||
if token == nil {
|
||||
return nil, errors.New("token cannot be nil")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return nil, errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
@@ -121,9 +126,8 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
@@ -145,13 +149,16 @@ func RefreshToken(cfg *Config, token *Token) (*Token, error) {
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func RevokeToken(cfg *Config, token *Token) error {
|
||||
func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("config cannot be nil")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("token", token.AccessToken)
|
||||
@@ -170,9 +177,8 @@ func RevokeToken(cfg *Config, token *Token) error {
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
|
||||
235
internal/discord/ratelimit.go
Normal file
235
internal/discord/ratelimit.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// RateLimitState tracks rate limit information for a specific bucket
|
||||
type RateLimitState struct {
|
||||
Remaining int // Requests remaining in current window
|
||||
Limit int // Total requests allowed in window
|
||||
Reset time.Time // When the rate limit resets
|
||||
Bucket string // Discord's bucket identifier
|
||||
}
|
||||
|
||||
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||
type APIClient struct {
|
||||
client *http.Client
|
||||
logger *hlog.Logger
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*RateLimitState
|
||||
}
|
||||
|
||||
// NewRateLimitedClient creates a new Discord API client with rate limit handling
|
||||
func NewRateLimitedClient(logger *hlog.Logger) *APIClient {
|
||||
return &APIClient{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
logger: logger,
|
||||
buckets: make(map[string]*RateLimitState),
|
||||
}
|
||||
}
|
||||
|
||||
// Do executes an HTTP request with automatic rate limit handling
|
||||
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
|
||||
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request cannot be nil")
|
||||
}
|
||||
|
||||
// Step 1: Check if we need to wait before making request
|
||||
bucket := c.getBucketFromRequest(req)
|
||||
if err := c.waitIfNeeded(bucket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Step 2: Execute request
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
// Check if it's a network timeout
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil, errors.Wrap(err, "request timed out")
|
||||
}
|
||||
return nil, errors.Wrap(err, "http request failed")
|
||||
}
|
||||
|
||||
// Step 3: Update rate limit state from response headers
|
||||
c.updateRateLimit(resp.Header)
|
||||
|
||||
// Step 4: Handle 429 (rate limited)
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
resp.Body.Close() // Close original response
|
||||
|
||||
retryAfter := c.parseRetryAfter(resp.Header)
|
||||
|
||||
// No Retry-After header, can't retry safely
|
||||
if retryAfter == 0 {
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Msg("Rate limited but no Retry-After header provided")
|
||||
return nil, errors.New("discord API rate limited but no Retry-After header provided")
|
||||
}
|
||||
|
||||
// Retry-After exceeds 30 second cap
|
||||
if retryAfter > 30*time.Second {
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Dur("retry_after", retryAfter).
|
||||
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
|
||||
return nil, errors.Errorf(
|
||||
"discord API rate limited (retry after %s exceeds 30s cap)",
|
||||
retryAfter,
|
||||
)
|
||||
}
|
||||
|
||||
// Wait and retry
|
||||
c.logger.Warn().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Dur("retry_after", retryAfter).
|
||||
Msg("Rate limited, waiting before retry")
|
||||
|
||||
time.Sleep(retryAfter)
|
||||
|
||||
// Retry the request
|
||||
resp, err = c.client.Do(req)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil, errors.Wrap(err, "retry request timed out")
|
||||
}
|
||||
return nil, errors.Wrap(err, "retry request failed")
|
||||
}
|
||||
|
||||
// Update rate limit again after retry
|
||||
c.updateRateLimit(resp.Header)
|
||||
|
||||
// If STILL rate limited after retry, return error
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
resp.Body.Close()
|
||||
c.logger.Error().
|
||||
Str("bucket", bucket).
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Msg("Still rate limited after retry, Discord may be experiencing issues")
|
||||
return nil, errors.Errorf(
|
||||
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
|
||||
retryAfter,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// getBucketFromRequest extracts or generates bucket ID from request
|
||||
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
|
||||
func (c *APIClient) getBucketFromRequest(req *http.Request) string {
|
||||
return req.Method + ":" + req.URL.Path
|
||||
}
|
||||
|
||||
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
|
||||
func (c *APIClient) waitIfNeeded(bucket string) error {
|
||||
c.mu.RLock()
|
||||
state, exists := c.buckets[bucket]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil // No state yet, proceed
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// If we have no remaining requests and reset hasn't occurred, wait
|
||||
if state.Remaining == 0 && now.Before(state.Reset) {
|
||||
waitDuration := time.Until(state.Reset)
|
||||
// Add small buffer (100ms) to ensure reset has occurred
|
||||
waitDuration += 100 * time.Millisecond
|
||||
|
||||
if waitDuration > 0 {
|
||||
c.logger.Debug().
|
||||
Str("bucket", bucket).
|
||||
Dur("wait_duration", waitDuration).
|
||||
Msg("Proactively waiting for rate limit reset")
|
||||
time.Sleep(waitDuration)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateRateLimit parses response headers and updates bucket state
|
||||
func (c *APIClient) updateRateLimit(headers http.Header) {
|
||||
bucket := headers.Get("X-RateLimit-Bucket")
|
||||
if bucket == "" {
|
||||
return // No bucket info, can't track
|
||||
}
|
||||
|
||||
// Parse headers
|
||||
limit := c.parseInt(headers.Get("X-RateLimit-Limit"))
|
||||
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining"))
|
||||
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After"))
|
||||
|
||||
state := &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: limit,
|
||||
Remaining: remaining,
|
||||
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.buckets[bucket] = state
|
||||
c.mu.Unlock()
|
||||
|
||||
// Log rate limit state for debugging
|
||||
c.logger.Debug().
|
||||
Str("bucket", bucket).
|
||||
Int("remaining", remaining).
|
||||
Int("limit", limit).
|
||||
Dur("reset_in", time.Until(state.Reset)).
|
||||
Msg("Rate limit state updated")
|
||||
}
|
||||
|
||||
// parseRetryAfter extracts retry delay from Retry-After header
|
||||
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
|
||||
retryAfter := headers.Get("Retry-After")
|
||||
if retryAfter == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Discord returns seconds as float
|
||||
seconds := c.parseFloat(retryAfter)
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return time.Duration(seconds * float64(time.Second))
|
||||
}
|
||||
|
||||
// parseInt parses an integer from a header value, returns 0 on error
|
||||
func (c *APIClient) parseInt(s string) int {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
i, _ := strconv.Atoi(s)
|
||||
return i
|
||||
}
|
||||
|
||||
// parseFloat parses a float from a header value, returns 0 on error
|
||||
func (c *APIClient) parseFloat(s string) float64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
f, _ := strconv.ParseFloat(s, 64)
|
||||
return f
|
||||
}
|
||||
459
internal/discord/ratelimit_test.go
Normal file
459
internal/discord/ratelimit_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
)
|
||||
|
||||
// testLogger creates a test logger for testing
|
||||
func testLogger(t *testing.T) *hlog.Logger {
|
||||
level, _ := hlog.LogLevel("debug")
|
||||
cfg := &hlog.Config{
|
||||
LogLevel: level,
|
||||
LogOutput: "console",
|
||||
}
|
||||
logger, err := hlog.NewLogger(cfg, io.Discard)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test logger: %v", err)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
func TestNewRateLimitedClient(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("NewRateLimitedClient returned nil")
|
||||
}
|
||||
if client.client == nil {
|
||||
t.Error("client.client is nil")
|
||||
}
|
||||
if client.logger == nil {
|
||||
t.Error("client.logger is nil")
|
||||
}
|
||||
if client.buckets == nil {
|
||||
t.Error("client.buckets map is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_Success(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
// Mock server that returns success with rate limit headers
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "5")
|
||||
w.Header().Set("X-RateLimit-Remaining", "3")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check that rate limit state was updated
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["test-bucket"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("rate limit state not stored")
|
||||
}
|
||||
if state.Remaining != 3 {
|
||||
t.Errorf("expected remaining=3, got %d", state.Remaining)
|
||||
}
|
||||
if state.Limit != 5 {
|
||||
t.Errorf("expected limit=5, got %d", state.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount++
|
||||
if attemptCount == 1 {
|
||||
// First request: return 429
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("Retry-After", "0.1") // 100ms
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
"error_description": "You are being rate limited",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Second request: success
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "5")
|
||||
w.Header().Set("X-RateLimit-Remaining", "4")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "2.5")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Do() returned error: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200 after retry, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if attemptCount != 2 {
|
||||
t.Errorf("expected 2 attempts (initial + retry), got %d", attemptCount)
|
||||
}
|
||||
|
||||
// Should have waited approximately 100ms
|
||||
if elapsed < 100*time.Millisecond {
|
||||
t.Errorf("expected delay of ~100ms, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount++
|
||||
// Always return 429
|
||||
w.Header().Set("X-RateLimit-Bucket", "test-bucket")
|
||||
w.Header().Set("Retry-After", "0.05") // 50ms
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error after failed retry")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "still rate limited after retry") {
|
||||
t.Errorf("expected 'still rate limited after retry' error, got: %v", err)
|
||||
}
|
||||
|
||||
if attemptCount != 2 {
|
||||
t.Errorf("expected 2 attempts, got %d", attemptCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error for Retry-After > 30s")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "exceeds 30s cap") {
|
||||
t.Errorf("expected 'exceeds 30s cap' error, got: %v", err)
|
||||
}
|
||||
|
||||
// Should NOT have waited (immediate error)
|
||||
if elapsed > 1*time.Second {
|
||||
t.Errorf("should return immediately, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return 429 but NO Retry-After header
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limited",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/oauth2/token", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
t.Fatal("Do() should have returned error when no Retry-After header")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no Retry-After header") {
|
||||
t.Errorf("expected 'no Retry-After header' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("X-RateLimit-Bucket", "global")
|
||||
headers.Set("X-RateLimit-Limit", "10")
|
||||
headers.Set("X-RateLimit-Remaining", "7")
|
||||
headers.Set("X-RateLimit-Reset-After", "5.5")
|
||||
|
||||
client.updateRateLimit(headers)
|
||||
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["global"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("bucket state not created")
|
||||
}
|
||||
|
||||
if state.Bucket != "global" {
|
||||
t.Errorf("expected bucket='global', got '%s'", state.Bucket)
|
||||
}
|
||||
if state.Limit != 10 {
|
||||
t.Errorf("expected limit=10, got %d", state.Limit)
|
||||
}
|
||||
if state.Remaining != 7 {
|
||||
t.Errorf("expected remaining=7, got %d", state.Remaining)
|
||||
}
|
||||
|
||||
// Check reset time is approximately 5.5 seconds from now
|
||||
resetIn := time.Until(state.Reset)
|
||||
if resetIn < 5*time.Second || resetIn > 6*time.Second {
|
||||
t.Errorf("expected reset in ~5.5s, got %v", resetIn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
// Set up a bucket with 0 remaining and reset in future
|
||||
bucket := "test-bucket"
|
||||
client.mu.Lock()
|
||||
client.buckets[bucket] = &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: 5,
|
||||
Remaining: 0,
|
||||
Reset: time.Now().Add(200 * time.Millisecond),
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err := client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should have waited ~200ms + 100ms buffer
|
||||
if elapsed < 200*time.Millisecond {
|
||||
t.Errorf("expected wait of ~300ms, but took %v", elapsed)
|
||||
}
|
||||
if elapsed > 500*time.Millisecond {
|
||||
t.Errorf("waited too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
// Set up a bucket with remaining requests
|
||||
bucket := "test-bucket"
|
||||
client.mu.Lock()
|
||||
client.buckets[bucket] = &RateLimitState{
|
||||
Bucket: bucket,
|
||||
Limit: 5,
|
||||
Remaining: 3,
|
||||
Reset: time.Now().Add(5 * time.Second),
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err := client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("waitIfNeeded returned error: %v", err)
|
||||
}
|
||||
|
||||
// Should NOT wait (has remaining requests)
|
||||
if elapsed > 10*time.Millisecond {
|
||||
t.Errorf("should not wait when remaining > 0, but took %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
count := requestCount
|
||||
mu.Unlock()
|
||||
|
||||
w.Header().Set("X-RateLimit-Bucket", "concurrent-bucket")
|
||||
w.Header().Set("X-RateLimit-Limit", "10")
|
||||
w.Header().Set("X-RateLimit-Remaining", "5")
|
||||
w.Header().Set("X-RateLimit-Reset-After", "1.0")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"count": string(rune(count))})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Launch 10 concurrent requests
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errors {
|
||||
t.Errorf("concurrent request failed: %v", err)
|
||||
}
|
||||
|
||||
// All requests should have completed
|
||||
mu.Lock()
|
||||
finalCount := requestCount
|
||||
mu.Unlock()
|
||||
|
||||
if finalCount != 10 {
|
||||
t.Errorf("expected 10 requests, got %d", finalCount)
|
||||
}
|
||||
|
||||
// Check rate limit state is consistent (no data races)
|
||||
client.mu.RLock()
|
||||
state, exists := client.buckets["concurrent-bucket"]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatal("bucket state not found after concurrent requests")
|
||||
}
|
||||
|
||||
// State should exist and be valid
|
||||
if state.Limit != 10 {
|
||||
t.Errorf("expected limit=10, got %d (possible race condition)", state.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_ParseRetryAfter(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"integer seconds", "2", 2 * time.Second},
|
||||
{"float seconds", "2.5", 2500 * time.Millisecond},
|
||||
{"zero", "0", 0},
|
||||
{"empty", "", 0},
|
||||
{"invalid", "abc", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("Retry-After", tt.header)
|
||||
|
||||
result := client.parseRetryAfter(headers)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user