added logout

This commit is contained in:
2026-01-24 15:23:28 +11:00
parent 73a5c9726b
commit 4dec97def8
14 changed files with 327 additions and 191 deletions

View File

@@ -1,6 +1,11 @@
package discord
import (
"net/http"
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
)
@@ -24,3 +29,33 @@ func (s *OAuthSession) GetUser() (*discordgo.User, error) {
}
return user, nil
}
// APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct {
cfg *Config
client *http.Client
logger *hlog.Logger
mu sync.RWMutex
buckets map[string]*RateLimitState
trustedHost string
}
// NewAPIClient creates a new Discord API client with rate limit handling
func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) {
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if logger == nil {
return nil, errors.New("logger cannot be nil")
}
if trustedhost == "" {
return nil, errors.New("trustedhost cannot be empty")
}
return &APIClient{
client: &http.Client{Timeout: 30 * time.Second},
logger: logger,
buckets: make(map[string]*RateLimitState),
cfg: cfg,
trustedHost: trustedhost,
}, nil
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"
)
// Token represents a response from the Discord OAuth API after a successful authorization request
type Token struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
@@ -22,46 +23,32 @@ type Token struct {
const oauthurl string = "https://discord.com/oauth2/authorize"
const apiurl string = "https://discord.com/api/v10"
func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) {
if cfg == nil {
return "", errors.New("cfg cannot be nil")
}
// GetOAuthLink generates a new Discord OAuth2 link for user authentication
func (api *APIClient) GetOAuthLink(state string) (string, error) {
if state == "" {
return "", errors.New("state cannot be empty")
}
if trustedHost == "" {
return "", errors.New("trustedHost cannot be empty")
}
values := url.Values{}
values.Add("response_type", "code")
values.Add("client_id", cfg.ClientID)
values.Add("scope", cfg.OAuthScopes)
values.Add("client_id", api.cfg.ClientID)
values.Add("scope", api.cfg.OAuthScopes)
values.Add("state", state)
values.Add("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath))
values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
values.Add("prompt", "none")
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
}
func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) {
// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for
// making requests to the API on behalf of the user
func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) {
if code == "" {
return nil, errors.New("code cannot be empty")
}
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if trustedHost == "" {
return nil, errors.New("trustedHost cannot be empty")
}
if apiClient == nil {
return nil, errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath))
// Create request
data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token",
@@ -70,27 +57,21 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie
if err != nil {
return nil, errors.Wrap(err, "failed to create request")
}
// Set headers
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
// Check status code
if resp.StatusCode != http.StatusOK {
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse JSON response
var tokenResp Token
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "failed to parse token response")
@@ -98,21 +79,14 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie
return &tokenResp, nil
}
func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) {
// RefreshToken uses the refresh token to generate a new token pair
func (api *APIClient) RefreshToken(token *Token) (*Token, error) {
if token == nil {
return nil, errors.New("token cannot be nil")
}
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if apiClient == nil {
return nil, errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", token.RefreshToken)
// Create request
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token",
@@ -121,27 +95,21 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro
if err != nil {
return nil, errors.Wrap(err, "failed to create request")
}
// Set headers
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
// Check status code
if resp.StatusCode != http.StatusOK {
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse JSON response
var tokenResp Token
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "failed to parse token response")
@@ -149,21 +117,14 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro
return &tokenResp, nil
}
func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
// RevokeToken sends a request to the Discord API to revoke the token pair
func (api *APIClient) RevokeToken(token *Token) error {
if token == nil {
return errors.New("token cannot be nil")
}
if cfg == nil {
return errors.New("config cannot be nil")
}
if apiClient == nil {
return errors.New("apiClient cannot be nil")
}
// Prepare form data
data := url.Values{}
data.Set("token", token.AccessToken)
data.Set("token_type_hint", "access_token")
// Create request
req, err := http.NewRequest(
"POST",
apiurl+"/oauth2/token/revoke",
@@ -172,18 +133,14 @@ func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
if err != nil {
return errors.Wrap(err, "failed to create request")
}
// Set headers
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Set basic auth (client_id and client_secret)
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
// Execute request with rate limit handling
resp, err := apiClient.Do(req)
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
resp, err := api.Do(req)
if err != nil {
return errors.Wrap(err, "failed to execute request")
}
defer resp.Body.Close()
// Check status code
if resp.StatusCode != http.StatusOK {
return errors.Errorf("discord API returned status %d", resp.StatusCode)
}

View File

@@ -4,10 +4,8 @@ import (
"net"
"net/http"
"strconv"
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors"
)
@@ -19,23 +17,6 @@ type RateLimitState struct {
Bucket string // Discord's bucket identifier
}
// APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct {
client *http.Client
logger *hlog.Logger
mu sync.RWMutex
buckets map[string]*RateLimitState
}
// NewRateLimitedClient creates a new Discord API client with rate limit handling
func NewRateLimitedClient(logger *hlog.Logger) *APIClient {
return &APIClient{
client: &http.Client{Timeout: 30 * time.Second},
logger: logger,
buckets: make(map[string]*RateLimitState),
}
}
// Do executes an HTTP request with automatic rate limit handling
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {

View File

@@ -27,12 +27,26 @@ func testLogger(t *testing.T) *hlog.Logger {
return logger
}
// testConfig creates a test config for testing
func testConfig() *Config {
return &Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
OAuthScopes: "identify+email",
RedirectPath: "/oauth/callback",
}
}
func TestNewRateLimitedClient(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
if client == nil {
t.Fatal("NewRateLimitedClient returned nil")
t.Fatal("NewAPIClient returned nil")
}
if client.client == nil {
t.Error("client.client is nil")
@@ -43,11 +57,21 @@ func TestNewRateLimitedClient(t *testing.T) {
if client.buckets == nil {
t.Error("client.buckets map is nil")
}
if client.cfg == nil {
t.Error("client.cfg is nil")
}
if client.trustedHost != "trusted-host.example.com" {
t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost)
}
}
func TestAPIClient_Do_Success(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Mock server that returns success with rate limit headers
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -93,7 +117,11 @@ func TestAPIClient_Do_Success(t *testing.T) {
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -149,7 +177,11 @@ func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -187,7 +219,11 @@ func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
@@ -224,7 +260,11 @@ func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return 429 but NO Retry-After header
@@ -254,7 +294,11 @@ func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
func TestAPIClient_UpdateRateLimit(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
headers := http.Header{}
headers.Set("X-RateLimit-Bucket", "global")
@@ -291,7 +335,11 @@ func TestAPIClient_UpdateRateLimit(t *testing.T) {
func TestAPIClient_WaitIfNeeded(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Set up a bucket with 0 remaining and reset in future
bucket := "test-bucket"
@@ -305,7 +353,7 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) {
client.mu.Unlock()
start := time.Now()
err := client.waitIfNeeded(bucket)
err = client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
@@ -323,7 +371,11 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) {
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
// Set up a bucket with remaining requests
bucket := "test-bucket"
@@ -337,7 +389,7 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
client.mu.Unlock()
start := time.Now()
err := client.waitIfNeeded(bucket)
err = client.waitIfNeeded(bucket)
elapsed := time.Since(start)
if err != nil {
@@ -352,7 +404,11 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
func TestAPIClient_Do_Concurrent(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
requestCount := 0
var mu sync.Mutex
@@ -376,24 +432,22 @@ func TestAPIClient_Do_Concurrent(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 10 {
wg.Go(
func() {
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
errors <- err
return
}
req, err := http.NewRequest("GET", server.URL+"/test", nil)
if err != nil {
errors <- err
return
}
resp, err := client.Do(req)
if err != nil {
errors <- err
return
}
resp.Body.Close()
}()
resp, err := client.Do(req)
if err != nil {
errors <- err
return
}
resp.Body.Close()
})
}
wg.Wait()
@@ -430,7 +484,11 @@ func TestAPIClient_Do_Concurrent(t *testing.T) {
func TestAPIClient_ParseRetryAfter(t *testing.T) {
logger := testLogger(t)
client := NewRateLimitedClient(logger)
cfg := testConfig()
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
if err != nil {
t.Fatalf("NewAPIClient returned error: %v", err)
}
tests := []struct {
name string