217 lines
5.9 KiB
Go
217 lines
5.9 KiB
Go
package discord
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// RateLimitState is a snapshot of the rate limit counters reported for one
// bucket. Instances are built from response headers by updateRateLimit and
// read by waitIfNeeded.
type RateLimitState struct {
	// Remaining is the number of requests still allowed in the current window.
	Remaining int
	// Limit is the total number of requests allowed per window.
	Limit int
	// Reset is the instant at which the current window ends.
	Reset time.Time
	// Bucket is the bucket identifier supplied by Discord.
	Bucket string
}
|
|
|
|
// Do executes an HTTP request with automatic rate limit handling
|
|
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
|
|
func (api *APIClient) Do(req *http.Request) (*http.Response, error) {
|
|
if req == nil {
|
|
return nil, errors.New("request cannot be nil")
|
|
}
|
|
|
|
// Step 1: Check if we need to wait before making request
|
|
bucket := api.getBucketFromRequest(req)
|
|
if err := api.waitIfNeeded(bucket); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Step 2: Execute request
|
|
resp, err := api.client.Do(req)
|
|
if err != nil {
|
|
// Check if it's a network timeout
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
return nil, errors.Wrap(err, "request timed out")
|
|
}
|
|
return nil, errors.Wrap(err, "http request failed")
|
|
}
|
|
|
|
// Step 3: Update rate limit state from response headers
|
|
api.updateRateLimit(resp.Header)
|
|
|
|
// Step 4: Handle 429 (rate limited)
|
|
if resp.StatusCode == http.StatusTooManyRequests {
|
|
resp.Body.Close() // Close original response
|
|
|
|
retryAfter := api.parseRetryAfter(resp.Header)
|
|
|
|
// No Retry-After header, can't retry safely
|
|
if retryAfter == 0 {
|
|
api.logger.Warn().
|
|
Str("bucket", bucket).
|
|
Str("method", req.Method).
|
|
Str("path", req.URL.Path).
|
|
Msg("Rate limited but no Retry-After header provided")
|
|
return nil, errors.New("discord API rate limited but no Retry-After header provided")
|
|
}
|
|
|
|
// Retry-After exceeds 30 second cap
|
|
if retryAfter > 30*time.Second {
|
|
api.logger.Warn().
|
|
Str("bucket", bucket).
|
|
Str("method", req.Method).
|
|
Str("path", req.URL.Path).
|
|
Dur("retry_after", retryAfter).
|
|
Msg("Rate limited with Retry-After exceeding 30s cap, not retrying")
|
|
return nil, errors.Errorf(
|
|
"discord API rate limited (retry after %s exceeds 30s cap)",
|
|
retryAfter,
|
|
)
|
|
}
|
|
|
|
// Wait and retry
|
|
api.logger.Warn().
|
|
Str("bucket", bucket).
|
|
Str("method", req.Method).
|
|
Str("path", req.URL.Path).
|
|
Dur("retry_after", retryAfter).
|
|
Msg("Rate limited, waiting before retry")
|
|
|
|
time.Sleep(retryAfter)
|
|
|
|
// Retry the request
|
|
resp, err = api.client.Do(req)
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
return nil, errors.Wrap(err, "retry request timed out")
|
|
}
|
|
return nil, errors.Wrap(err, "retry request failed")
|
|
}
|
|
|
|
// Update rate limit again after retry
|
|
api.updateRateLimit(resp.Header)
|
|
|
|
// If STILL rate limited after retry, return error
|
|
if resp.StatusCode == http.StatusTooManyRequests {
|
|
resp.Body.Close()
|
|
api.logger.Error().
|
|
Str("bucket", bucket).
|
|
Str("method", req.Method).
|
|
Str("path", req.URL.Path).
|
|
Msg("Still rate limited after retry, Discord may be experiencing issues")
|
|
return nil, errors.Errorf(
|
|
"discord API still rate limited after retry (waited %s), Discord may be experiencing issues",
|
|
retryAfter,
|
|
)
|
|
}
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// getBucketFromRequest extracts or generates bucket ID from request
|
|
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
|
|
func (api *APIClient) getBucketFromRequest(req *http.Request) string {
|
|
return req.Method + ":" + req.URL.Path
|
|
}
|
|
|
|
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
|
|
func (api *APIClient) waitIfNeeded(bucket string) error {
|
|
api.mu.RLock()
|
|
state, exists := api.buckets[bucket]
|
|
api.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return nil // No state yet, proceed
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
// If we have no remaining requests and reset hasn't occurred, wait
|
|
if state.Remaining == 0 && now.Before(state.Reset) {
|
|
waitDuration := time.Until(state.Reset)
|
|
// Add small buffer (100ms) to ensure reset has occurred
|
|
waitDuration += 100 * time.Millisecond
|
|
|
|
if waitDuration > 0 {
|
|
api.logger.Debug().
|
|
Str("bucket", bucket).
|
|
Dur("wait_duration", waitDuration).
|
|
Msg("Proactively waiting for rate limit reset")
|
|
time.Sleep(waitDuration)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// updateRateLimit parses response headers and updates bucket state
|
|
func (api *APIClient) updateRateLimit(headers http.Header) {
|
|
bucket := headers.Get("X-RateLimit-Bucket")
|
|
if bucket == "" {
|
|
return // No bucket info, can't track
|
|
}
|
|
|
|
// Parse headers
|
|
limit := api.parseInt(headers.Get("X-RateLimit-Limit"))
|
|
remaining := api.parseInt(headers.Get("X-RateLimit-Remaining"))
|
|
resetAfter := api.parseFloat(headers.Get("X-RateLimit-Reset-After"))
|
|
|
|
state := &RateLimitState{
|
|
Bucket: bucket,
|
|
Limit: limit,
|
|
Remaining: remaining,
|
|
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
|
|
}
|
|
|
|
api.mu.Lock()
|
|
api.buckets[bucket] = state
|
|
api.mu.Unlock()
|
|
|
|
// Log rate limit state for debugging
|
|
api.logger.Debug().
|
|
Str("bucket", bucket).
|
|
Int("remaining", remaining).
|
|
Int("limit", limit).
|
|
Dur("reset_in", time.Until(state.Reset)).
|
|
Msg("Rate limit state updated")
|
|
}
|
|
|
|
// parseRetryAfter extracts retry delay from Retry-After header
|
|
func (api *APIClient) parseRetryAfter(headers http.Header) time.Duration {
|
|
retryAfter := headers.Get("Retry-After")
|
|
if retryAfter == "" {
|
|
return 0
|
|
}
|
|
|
|
// Discord returns seconds as float
|
|
seconds := api.parseFloat(retryAfter)
|
|
if seconds <= 0 {
|
|
return 0
|
|
}
|
|
|
|
return time.Duration(seconds * float64(time.Second))
|
|
}
|
|
|
|
// parseInt parses an integer from a header value, returns 0 on error
|
|
func (api *APIClient) parseInt(s string) int {
|
|
if s == "" {
|
|
return 0
|
|
}
|
|
i, _ := strconv.Atoi(s)
|
|
return i
|
|
}
|
|
|
|
// parseFloat parses a float from a header value, returns 0 on error
|
|
func (api *APIClient) parseFloat(s string) float64 {
|
|
if s == "" {
|
|
return 0
|
|
}
|
|
f, _ := strconv.ParseFloat(s, 64)
|
|
return f
|
|
}
|