added logout
This commit is contained in:
@@ -16,12 +16,13 @@ type DiscordToken struct {
|
||||
AccessToken string `bun:"access_token,notnull"`
|
||||
RefreshToken string `bun:"refresh_token,notnull"`
|
||||
ExpiresAt int64 `bun:"expires_at,notnull"`
|
||||
Scope string `bun:"scope,notnull"`
|
||||
TokenType string `bun:"token_type,notnull"`
|
||||
}
|
||||
|
||||
func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *discord.Token) error {
|
||||
if user == nil {
|
||||
return errors.New("user cannot be nil")
|
||||
}
|
||||
// UpdateDiscordToken adds the provided discord token to the database.
|
||||
// If the user already has a token stored, it will replace that token instead.
|
||||
func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
@@ -32,6 +33,8 @@ func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *disco
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: token.Scope,
|
||||
TokenType: token.TokenType,
|
||||
}
|
||||
|
||||
_, err := tx.NewInsert().
|
||||
@@ -47,3 +50,46 @@ func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *disco
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDiscordTokens deletes a users discord OAuth tokens from the database.
|
||||
// It returns the DiscordToken so that it can be revoked via the discord API
|
||||
func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
|
||||
token, err := user.GetDiscordToken(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.GetDiscordToken")
|
||||
}
|
||||
_, err = tx.NewDelete().
|
||||
Model((*DiscordToken)(nil)).
|
||||
Where("discord_id = ?", user.DiscordID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewDelete")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetDiscordToken retrieves the users discord token from the database
|
||||
func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
|
||||
token := new(DiscordToken)
|
||||
err := tx.NewSelect().
|
||||
Model(token).
|
||||
Where("discord_id = ?", user.DiscordID).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Convert reverts the token back into a *discord.Token
|
||||
func (t *DiscordToken) Convert() *discord.Token {
|
||||
token := &discord.Token{
|
||||
AccessToken: t.AccessToken,
|
||||
RefreshToken: t.RefreshToken,
|
||||
ExpiresIn: int(t.ExpiresAt - time.Now().Unix()),
|
||||
Scope: t.Scope,
|
||||
TokenType: t.TokenType,
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -24,3 +29,33 @@ func (s *OAuthSession) GetUser() (*discordgo.User, error) {
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||
type APIClient struct {
|
||||
cfg *Config
|
||||
client *http.Client
|
||||
logger *hlog.Logger
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*RateLimitState
|
||||
trustedHost string
|
||||
}
|
||||
|
||||
// NewAPIClient creates a new Discord API client with rate limit handling
|
||||
func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
if logger == nil {
|
||||
return nil, errors.New("logger cannot be nil")
|
||||
}
|
||||
if trustedhost == "" {
|
||||
return nil, errors.New("trustedhost cannot be empty")
|
||||
}
|
||||
return &APIClient{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
logger: logger,
|
||||
buckets: make(map[string]*RateLimitState),
|
||||
cfg: cfg,
|
||||
trustedHost: trustedhost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Token represents a response from the Discord OAuth API after a successful authorization request
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
@@ -22,46 +23,32 @@ type Token struct {
|
||||
const oauthurl string = "https://discord.com/oauth2/authorize"
|
||||
const apiurl string = "https://discord.com/api/v10"
|
||||
|
||||
func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) {
|
||||
if cfg == nil {
|
||||
return "", errors.New("cfg cannot be nil")
|
||||
}
|
||||
// GetOAuthLink generates a new Discord OAuth2 link for user authentication
|
||||
func (api *APIClient) GetOAuthLink(state string) (string, error) {
|
||||
if state == "" {
|
||||
return "", errors.New("state cannot be empty")
|
||||
}
|
||||
if trustedHost == "" {
|
||||
return "", errors.New("trustedHost cannot be empty")
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Add("response_type", "code")
|
||||
values.Add("client_id", cfg.ClientID)
|
||||
values.Add("scope", cfg.OAuthScopes)
|
||||
values.Add("client_id", api.cfg.ClientID)
|
||||
values.Add("scope", api.cfg.OAuthScopes)
|
||||
values.Add("state", state)
|
||||
values.Add("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath))
|
||||
values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
|
||||
values.Add("prompt", "none")
|
||||
|
||||
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
||||
}
|
||||
|
||||
func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) {
|
||||
// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for
|
||||
// making requests to the API on behalf of the user
|
||||
func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("code cannot be empty")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
if trustedHost == "" {
|
||||
return nil, errors.New("trustedHost cannot be empty")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return nil, errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath))
|
||||
// Create request
|
||||
data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath))
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token",
|
||||
@@ -70,27 +57,21 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
// Check status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
// Parse JSON response
|
||||
var tokenResp Token
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse token response")
|
||||
@@ -98,21 +79,14 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) {
|
||||
// RefreshToken uses the refresh token to generate a new token pair
|
||||
func (api *APIClient) RefreshToken(token *Token) (*Token, error) {
|
||||
if token == nil {
|
||||
return nil, errors.New("token cannot be nil")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return nil, errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", token.RefreshToken)
|
||||
// Create request
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token",
|
||||
@@ -121,27 +95,21 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
// Check status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
// Parse JSON response
|
||||
var tokenResp Token
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse token response")
|
||||
@@ -149,21 +117,14 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
|
||||
// RevokeToken sends a request to the Discord API to revoke the token pair
|
||||
func (api *APIClient) RevokeToken(token *Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("config cannot be nil")
|
||||
}
|
||||
if apiClient == nil {
|
||||
return errors.New("apiClient cannot be nil")
|
||||
}
|
||||
// Prepare form data
|
||||
data := url.Values{}
|
||||
data.Set("token", token.AccessToken)
|
||||
data.Set("token_type_hint", "access_token")
|
||||
// Create request
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
apiurl+"/oauth2/token/revoke",
|
||||
@@ -172,18 +133,14 @@ func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error {
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create request")
|
||||
}
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Set basic auth (client_id and client_secret)
|
||||
req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
|
||||
// Execute request with rate limit handling
|
||||
resp, err := apiClient.Do(req)
|
||||
req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret)
|
||||
resp, err := api.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to execute request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// Check status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.Errorf("discord API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -4,10 +4,8 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -19,23 +17,6 @@ type RateLimitState struct {
|
||||
Bucket string // Discord's bucket identifier
|
||||
}
|
||||
|
||||
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||
type APIClient struct {
|
||||
client *http.Client
|
||||
logger *hlog.Logger
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*RateLimitState
|
||||
}
|
||||
|
||||
// NewRateLimitedClient creates a new Discord API client with rate limit handling
|
||||
func NewRateLimitedClient(logger *hlog.Logger) *APIClient {
|
||||
return &APIClient{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
logger: logger,
|
||||
buckets: make(map[string]*RateLimitState),
|
||||
}
|
||||
}
|
||||
|
||||
// Do executes an HTTP request with automatic rate limit handling
|
||||
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received
|
||||
func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -27,12 +27,26 @@ func testLogger(t *testing.T) *hlog.Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
// testConfig creates a test config for testing
|
||||
func testConfig() *Config {
|
||||
return &Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
OAuthScopes: "identify+email",
|
||||
RedirectPath: "/oauth/callback",
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimitedClient(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewRateLimitedClient returned nil")
|
||||
t.Fatal("NewAPIClient returned nil")
|
||||
}
|
||||
if client.client == nil {
|
||||
t.Error("client.client is nil")
|
||||
@@ -43,11 +57,21 @@ func TestNewRateLimitedClient(t *testing.T) {
|
||||
if client.buckets == nil {
|
||||
t.Error("client.buckets map is nil")
|
||||
}
|
||||
if client.cfg == nil {
|
||||
t.Error("client.cfg is nil")
|
||||
}
|
||||
if client.trustedHost != "trusted-host.example.com" {
|
||||
t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIClient_Do_Success(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Mock server that returns success with rate limit headers
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -93,7 +117,11 @@ func TestAPIClient_Do_Success(t *testing.T) {
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -149,7 +177,11 @@ func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) {
|
||||
|
||||
func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -187,7 +219,11 @@ func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) {
|
||||
|
||||
func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap
|
||||
@@ -224,7 +260,11 @@ func TestAPIClient_Do_RateLimitTooLong(t *testing.T) {
|
||||
|
||||
func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return 429 but NO Retry-After header
|
||||
@@ -254,7 +294,11 @@ func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) {
|
||||
|
||||
func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("X-RateLimit-Bucket", "global")
|
||||
@@ -291,7 +335,11 @@ func TestAPIClient_UpdateRateLimit(t *testing.T) {
|
||||
|
||||
func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Set up a bucket with 0 remaining and reset in future
|
||||
bucket := "test-bucket"
|
||||
@@ -305,7 +353,7 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err := client.waitIfNeeded(bucket)
|
||||
err = client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
@@ -323,7 +371,11 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) {
|
||||
|
||||
func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
// Set up a bucket with remaining requests
|
||||
bucket := "test-bucket"
|
||||
@@ -337,7 +389,7 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||
client.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
err := client.waitIfNeeded(bucket)
|
||||
err = client.waitIfNeeded(bucket)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
@@ -352,7 +404,11 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) {
|
||||
|
||||
func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
@@ -376,24 +432,22 @@ func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 10 {
|
||||
wg.Go(
|
||||
func() {
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL+"/test", nil)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
@@ -430,7 +484,11 @@ func TestAPIClient_Do_Concurrent(t *testing.T) {
|
||||
|
||||
func TestAPIClient_ParseRetryAfter(t *testing.T) {
|
||||
logger := testLogger(t)
|
||||
client := NewRateLimitedClient(logger)
|
||||
cfg := testConfig()
|
||||
client, err := NewAPIClient(cfg, logger, "trusted-host.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAPIClient returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -28,11 +28,9 @@ func Callback(
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// Track callback redirect attempts
|
||||
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
||||
|
||||
if exceeded {
|
||||
// Build detailed error for logging
|
||||
err := errors.Errorf(
|
||||
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
@@ -42,10 +40,8 @@ func Callback(
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
// Clear the tracking entry
|
||||
store.ClearRedirectTrack(r, "/callback")
|
||||
|
||||
// Show error page
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
@@ -66,23 +62,17 @@ func Callback(
|
||||
}
|
||||
data, err := verifyState(cfg.OAuth, w, r, state)
|
||||
if err != nil {
|
||||
// Check if this is a cookie error (401) or signature error (403)
|
||||
if vsErr, ok := err.(*verifyStateError); ok {
|
||||
if vsErr.IsCookieError() {
|
||||
// Cookie missing/expired - normal failed/expired session (DEBUG)
|
||||
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
||||
} else {
|
||||
// Signature verification failed - security violation (WARN)
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
} else {
|
||||
// Unknown error type - treat as security issue
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// SUCCESS POINT: State verified successfully
|
||||
// Clear redirect tracking - OAuth callback completed successfully
|
||||
store.ClearRedirectTrack(r, "/callback")
|
||||
|
||||
switch data {
|
||||
@@ -108,10 +98,9 @@ func Callback(
|
||||
)
|
||||
}
|
||||
|
||||
// verifyStateError wraps an error with context about what went wrong
|
||||
type verifyStateError struct {
|
||||
err error
|
||||
cookieError bool // true if cookie missing/invalid, false if signature invalid
|
||||
cookieError bool
|
||||
}
|
||||
|
||||
func (e *verifyStateError) Error() string {
|
||||
@@ -135,20 +124,16 @@ func verifyState(
|
||||
return "", errors.New("state param field is empty")
|
||||
}
|
||||
|
||||
// Try to get the cookie
|
||||
uak, err := oauth.GetStateCookie(r)
|
||||
if err != nil {
|
||||
// Cookie missing or invalid - this is a 401 (not authenticated)
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.GetStateCookie"),
|
||||
cookieError: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the state signature
|
||||
data, err := oauth.VerifyState(cfg, state, uak)
|
||||
if err != nil {
|
||||
// Signature verification failed - this is a 403 (security violation)
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.VerifyState"),
|
||||
cookieError: false,
|
||||
@@ -170,9 +155,9 @@ func login(
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) (func(), error) {
|
||||
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost, discordAPI)
|
||||
token, err := discordAPI.AuthorizeWithCode(code)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discord.AuthorizeWithCode")
|
||||
return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode")
|
||||
}
|
||||
session, err := discord.NewOAuthSession(token)
|
||||
if err != nil {
|
||||
@@ -204,6 +189,10 @@ func login(
|
||||
})
|
||||
redirect = "/register"
|
||||
} else {
|
||||
err = user.UpdateDiscordToken(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.UpdateDiscordToken")
|
||||
}
|
||||
err := auth.Login(w, r, user, true)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.Login")
|
||||
|
||||
45
internal/handlers/isusernameunique.go
Normal file
45
internal/handlers/isusernameunique.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func IsUsernameUnique(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -17,11 +17,9 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
||||
// Track login redirect attempts
|
||||
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||
|
||||
if exceeded {
|
||||
// Build detailed error for logging
|
||||
err := errors.Errorf(
|
||||
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
@@ -31,10 +29,8 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
// Clear the tracking entry
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
// Show error page
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
@@ -54,14 +50,11 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
|
||||
}
|
||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||
|
||||
link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost)
|
||||
link, err := discordAPI.GetOAuthLink(state)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
||||
return
|
||||
}
|
||||
|
||||
// SUCCESS POINT: OAuth link generated, redirecting to Discord
|
||||
// Clear redirect tracking - user successfully initiated OAuth
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
http.Redirect(w, r, link, http.StatusSeeOther)
|
||||
|
||||
59
internal/handlers/logout.go
Normal file
59
internal/handlers/logout.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func Logout(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
conn *bun.DB,
|
||||
discordAPI *discord.APIClient,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
user := db.CurrentUser(r.Context())
|
||||
if user == nil {
|
||||
// JIC - should be impossible to get here if route is protected by LoginReq
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
return
|
||||
}
|
||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
||||
return
|
||||
}
|
||||
err = discordAPI.RevokeToken(token.Convert())
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||
return
|
||||
}
|
||||
err = auth.Logout(tx, w, r)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Logout failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -102,38 +102,6 @@ func Register(
|
||||
)
|
||||
}
|
||||
|
||||
func IsUsernameUnique(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func registerUser(
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
@@ -151,7 +119,7 @@ func registerUser(
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CreateUser")
|
||||
}
|
||||
err = db.UpdateDiscordToken(ctx, tx, user, details.Token)
|
||||
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user