From 9497c17c30fa0b08b96556eb9ca9dce9e23bcc8c Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Sat, 24 Jan 2026 15:23:28 +1100 Subject: [PATCH] added logout --- cmd/oslstats/routes.go | 8 +- cmd/oslstats/run.go | 13 +-- go.mod | 4 +- go.sum | 10 ++- internal/db/discord_tokens.go | 54 +++++++++++- internal/discord/api.go | 35 ++++++++ internal/discord/oauth.go | 83 +++++------------- internal/discord/ratelimit.go | 19 ---- internal/discord/ratelimit_test.go | 120 +++++++++++++++++++------- internal/handlers/callback.go | 25 ++---- internal/handlers/isusernameunique.go | 45 ++++++++++ internal/handlers/login.go | 9 +- internal/handlers/logout.go | 59 +++++++++++++ internal/handlers/register.go | 34 +------- 14 files changed, 327 insertions(+), 191 deletions(-) create mode 100644 internal/handlers/isusernameunique.go create mode 100644 internal/handlers/logout.go diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index e2942c5..c19f6d0 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -48,13 +48,13 @@ func addRoutes( }, { Path: "/register", - Method: hws.MethodGET, + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), }, { - Path: "/register", - Method: hws.MethodPOST, - Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), + Path: "/logout", + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LoginReq(handlers.Logout(server, auth, conn, discordAPI)), }, } diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 32395e8..2d6f34f 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -18,12 +18,12 @@ import ( ) // Initializes and runs the server -func run(ctx context.Context, w io.Writer, config *config.Config) error { +func run(ctx context.Context, w io.Writer, cfg *config.Config) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() // Setup the logger - logger, err := hlog.NewLogger(config.HLOG, w) + logger, err := hlog.NewLogger(cfg.HLOG, w) if err != nil { return errors.Wrap(err, "hlog.NewLogger") } @@ -31,7 +31,7 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - bun, closedb, err := setupBun(ctx, config) + bun, closedb, err := setupBun(ctx, cfg) if err != nil { return errors.Wrap(err, "setupDBConn") } @@ -50,10 +50,13 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { // Setup Discord API client logger.Debug().Msg("Setting up Discord API client") - discordAPI := discord.NewRateLimitedClient(logger) + discordAPI, err := discord.NewAPIClient(cfg.Discord, logger, cfg.HWSAuth.TrustedHost) + if err != nil { + return errors.Wrap(err, "discord.NewAPIClient") + } logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store, discordAPI) + httpServer, err := setupHttpServer(&staticFS, cfg, logger, bun, store, discordAPI) if err != nil { return errors.Wrap(err, "setupHttpServer") } diff --git a/go.mod b/go.mod index 385f76a..e909fb8 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.3.0 - git.haelnorr.com/h/golib/hwsauth v0.5.0 + git.haelnorr.com/h/golib/hws v0.3.1 + git.haelnorr.com/h/golib/hwsauth v0.5.2 github.com/a-h/templ v0.3.977 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 795badc..58d5675 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,12 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= -git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/hwsauth v0.5.0 h1:RAr7cdMe2aden50n7d9m5R4josZZ8ikNfWGMAEGnJbo= -git.haelnorr.com/h/golib/hwsauth v0.5.0/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= +git.haelnorr.com/h/golib/hws v0.3.1 h1:uFXAT8SuKs4VACBdrkmZ+dJjeBlSPgCKUPt8zGCcwrI= +git.haelnorr.com/h/golib/hws v0.3.1/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hwsauth v0.5.1 h1:U7rLPWLjPvggcY0Ez8VVIgcueLLHLLUV69OjYv/QepQ= +git.haelnorr.com/h/golib/hwsauth v0.5.1/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= +git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag= +git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= diff --git a/internal/db/discord_tokens.go b/internal/db/discord_tokens.go index ae3f6fc..7cc09d9 100644 --- a/internal/db/discord_tokens.go +++ b/internal/db/discord_tokens.go @@ -16,12 +16,13 @@ type DiscordToken struct { AccessToken string `bun:"access_token,notnull"` RefreshToken string `bun:"refresh_token,notnull"` ExpiresAt int64 `bun:"expires_at,notnull"` + Scope string `bun:"scope,notnull"` + TokenType string `bun:"token_type,notnull"` } -func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *discord.Token) error { - if user == nil { - return errors.New("user cannot be nil") - } +// UpdateDiscordToken adds the provided discord token to the database. +// If the user already has a token stored, it will replace that token instead. +func (user *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error { if token == nil { return errors.New("token cannot be nil") } @@ -32,6 +33,8 @@ func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *disco AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresAt: expiresAt, + Scope: token.Scope, + TokenType: token.TokenType, } _, err := tx.NewInsert(). @@ -47,3 +50,46 @@ func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *disco } return nil } + +// DeleteDiscordTokens deletes a users discord OAuth tokens from the database. +// It returns the DiscordToken so that it can be revoked via the discord API +func (user *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token, err := user.GetDiscordToken(ctx, tx) + if err != nil { + return nil, errors.Wrap(err, "user.GetDiscordToken") + } + _, err = tx.NewDelete(). + Model((*DiscordToken)(nil)). + Where("discord_id = ?", user.DiscordID). + Exec(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewDelete") + } + return token, nil +} + +// GetDiscordToken retrieves the users discord token from the database +func (user *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { + token := new(DiscordToken) + err := tx.NewSelect(). + Model(token). + Where("discord_id = ?", user.DiscordID). + Limit(1). + Scan(ctx) + if err != nil { + return nil, errors.Wrap(err, "tx.NewSelect") + } + return token, nil +} + +// Convert reverts the token back into a *discord.Token +func (t *DiscordToken) Convert() *discord.Token { + token := &discord.Token{ + AccessToken: t.AccessToken, + RefreshToken: t.RefreshToken, + ExpiresIn: int(t.ExpiresAt - time.Now().Unix()), + Scope: t.Scope, + TokenType: t.TokenType, + } + return token +} diff --git a/internal/discord/api.go b/internal/discord/api.go index 7c41257..aa490b5 100644 --- a/internal/discord/api.go +++ b/internal/discord/api.go @@ -1,6 +1,11 @@ package discord import ( + "net/http" + "sync" + "time" + + "git.haelnorr.com/h/golib/hlog" "github.com/bwmarrin/discordgo" "github.com/pkg/errors" ) @@ -24,3 +29,33 @@ func (s *OAuthSession) GetUser() (*discordgo.User, error) { } return user, nil } + +// APIClient is an HTTP client wrapper that handles Discord API rate limits +type APIClient struct { + cfg *Config + client *http.Client + logger *hlog.Logger + mu sync.RWMutex + buckets map[string]*RateLimitState + trustedHost string +} + +// NewAPIClient creates a new Discord API client with rate limit handling +func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APIClient, error) { + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + if logger == nil { + return nil, errors.New("logger cannot be nil") + } + if trustedhost == "" { + return nil, errors.New("trustedhost cannot be empty") + } + return &APIClient{ + client: &http.Client{Timeout: 30 * time.Second}, + logger: logger, + buckets: make(map[string]*RateLimitState), + cfg: cfg, + trustedHost: trustedhost, + }, nil +} diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go index 780a71b..faff919 100644 --- a/internal/discord/oauth.go +++ b/internal/discord/oauth.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" ) +// Token represents a response from the Discord OAuth API after a successful authorization request type Token struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -22,46 +23,32 @@ type Token struct { const oauthurl string = "https://discord.com/oauth2/authorize" const apiurl string = "https://discord.com/api/v10" -func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) { - if cfg == nil { - return "", errors.New("cfg cannot be nil") - } +// GetOAuthLink generates a new Discord OAuth2 link for user authentication +func (api *APIClient) GetOAuthLink(state string) (string, error) { if state == "" { return "", errors.New("state cannot be empty") } - if trustedHost == "" { - return "", errors.New("trustedHost cannot be empty") - } values := url.Values{} values.Add("response_type", "code") - values.Add("client_id", cfg.ClientID) - values.Add("scope", cfg.OAuthScopes) + values.Add("client_id", api.cfg.ClientID) + values.Add("scope", api.cfg.OAuthScopes) values.Add("state", state) - values.Add("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath)) + values.Add("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) values.Add("prompt", "none") return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil } -func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClient) (*Token, error) { +// AuthorizeWithCode uses a users authorization token generated by OAuth2 to get a token for +// making requests to the API on behalf of the user +func (api *APIClient) AuthorizeWithCode(code string) (*Token, error) { if code == "" { return nil, errors.New("code cannot be empty") } - if cfg == nil { - return nil, errors.New("config cannot be nil") - } - if trustedHost == "" { - return nil, errors.New("trustedHost cannot be empty") - } - if apiClient == nil { - return nil, errors.New("apiClient cannot be nil") - } - // Prepare form data data := url.Values{} data.Set("grant_type", "authorization_code") data.Set("code", code) - data.Set("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath)) - // Create request + data.Set("redirect_uri", fmt.Sprintf("%s/%s", api.trustedHost, api.cfg.RedirectPath)) req, err := http.NewRequest( "POST", apiurl+"/oauth2/token", @@ -70,27 +57,21 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie if err != nil { return nil, errors.Wrap(err, "failed to create request") } - // Set headers req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - // Set basic auth (client_id and client_secret) - req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request with rate limit handling - resp, err := apiClient.Do(req) + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) if err != nil { return nil, errors.Wrap(err, "failed to execute request") } defer resp.Body.Close() - // Read response body body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.Wrap(err, "failed to read response body") } - // Check status code if resp.StatusCode != http.StatusOK { return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) } - // Parse JSON response var tokenResp Token if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, errors.Wrap(err, "failed to parse token response") @@ -98,21 +79,14 @@ func AuthorizeWithCode(cfg *Config, code, trustedHost string, apiClient *APIClie return &tokenResp, nil } -func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, error) { +// RefreshToken uses the refresh token to generate a new token pair +func (api *APIClient) RefreshToken(token *Token) (*Token, error) { if token == nil { return nil, errors.New("token cannot be nil") } - if cfg == nil { - return nil, errors.New("config cannot be nil") - } - if apiClient == nil { - return nil, errors.New("apiClient cannot be nil") - } - // Prepare form data data := url.Values{} data.Set("grant_type", "refresh_token") data.Set("refresh_token", token.RefreshToken) - // Create request req, err := http.NewRequest( "POST", apiurl+"/oauth2/token", @@ -121,27 +95,21 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro if err != nil { return nil, errors.Wrap(err, "failed to create request") } - // Set headers req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - // Set basic auth (client_id and client_secret) - req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request with rate limit handling - resp, err := apiClient.Do(req) + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) if err != nil { return nil, errors.Wrap(err, "failed to execute request") } defer resp.Body.Close() - // Read response body body, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.Wrap(err, "failed to read response body") } - // Check status code if resp.StatusCode != http.StatusOK { return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) } - // Parse JSON response var tokenResp Token if err := json.Unmarshal(body, &tokenResp); err != nil { return nil, errors.Wrap(err, "failed to parse token response") @@ -149,21 +117,14 @@ func RefreshToken(cfg *Config, token *Token, apiClient *APIClient) (*Token, erro return &tokenResp, nil } -func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error { +// RevokeToken sends a request to the Discord API to revoke the token pair +func (api *APIClient) RevokeToken(token *Token) error { if token == nil { return errors.New("token cannot be nil") } - if cfg == nil { - return errors.New("config cannot be nil") - } - if apiClient == nil { - return errors.New("apiClient cannot be nil") - } - // Prepare form data data := url.Values{} data.Set("token", token.AccessToken) data.Set("token_type_hint", "access_token") - // Create request req, err := http.NewRequest( "POST", apiurl+"/oauth2/token/revoke", @@ -172,18 +133,14 @@ func RevokeToken(cfg *Config, token *Token, apiClient *APIClient) error { if err != nil { return errors.Wrap(err, "failed to create request") } - // Set headers req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - // Set basic auth (client_id and client_secret) - req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) - // Execute request with rate limit handling - resp, err := apiClient.Do(req) + req.SetBasicAuth(api.cfg.ClientID, api.cfg.ClientSecret) + resp, err := api.Do(req) if err != nil { return errors.Wrap(err, "failed to execute request") } defer resp.Body.Close() - // Check status code if resp.StatusCode != http.StatusOK { return errors.Errorf("discord API returned status %d", resp.StatusCode) } diff --git a/internal/discord/ratelimit.go b/internal/discord/ratelimit.go index 4cde26d..21f3888 100644 --- a/internal/discord/ratelimit.go +++ b/internal/discord/ratelimit.go @@ -4,10 +4,8 @@ import ( "net" "net/http" "strconv" - "sync" "time" - "git.haelnorr.com/h/golib/hlog" "github.com/pkg/errors" ) @@ -19,23 +17,6 @@ type RateLimitState struct { Bucket string // Discord's bucket identifier } -// APIClient is an HTTP client wrapper that handles Discord API rate limits -type APIClient struct { - client *http.Client - logger *hlog.Logger - mu sync.RWMutex - buckets map[string]*RateLimitState -} - -// NewRateLimitedClient creates a new Discord API client with rate limit handling -func NewRateLimitedClient(logger *hlog.Logger) *APIClient { - return &APIClient{ - client: &http.Client{Timeout: 30 * time.Second}, - logger: logger, - buckets: make(map[string]*RateLimitState), - } -} - // Do executes an HTTP request with automatic rate limit handling // It will wait if rate limits are about to be exceeded and retry once if a 429 is received func (c *APIClient) Do(req *http.Request) (*http.Response, error) { diff --git a/internal/discord/ratelimit_test.go b/internal/discord/ratelimit_test.go index 33f6b57..38dfb04 100644 --- a/internal/discord/ratelimit_test.go +++ b/internal/discord/ratelimit_test.go @@ -27,12 +27,26 @@ func testLogger(t *testing.T) *hlog.Logger { return logger } +// testConfig creates a test config for testing +func testConfig() *Config { + return &Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + OAuthScopes: "identify+email", + RedirectPath: "/oauth/callback", + } +} + func TestNewRateLimitedClient(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } if client == nil { - t.Fatal("NewRateLimitedClient returned nil") + t.Fatal("NewAPIClient returned nil") } if client.client == nil { t.Error("client.client is nil") @@ -43,11 +57,21 @@ func TestNewRateLimitedClient(t *testing.T) { if client.buckets == nil { t.Error("client.buckets map is nil") } + if client.cfg == nil { + t.Error("client.cfg is nil") + } + if client.trustedHost != "trusted-host.example.com" { + t.Errorf("expected trustedHost='trusted-host.example.com', got '%s'", client.trustedHost) + } } func TestAPIClient_Do_Success(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } // Mock server that returns success with rate limit headers server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -93,7 +117,11 @@ func TestAPIClient_Do_Success(t *testing.T) { func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } attemptCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -149,7 +177,11 @@ func TestAPIClient_Do_RateLimitRetrySuccess(t *testing.T) { func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } attemptCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -187,7 +219,11 @@ func TestAPIClient_Do_RateLimitRetryFails(t *testing.T) { func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "60") // 60 seconds > 30s cap @@ -224,7 +260,11 @@ func TestAPIClient_Do_RateLimitTooLong(t *testing.T) { func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Return 429 but NO Retry-After header @@ -254,7 +294,11 @@ func TestAPIClient_Do_NoRetryAfterHeader(t *testing.T) { func TestAPIClient_UpdateRateLimit(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } headers := http.Header{} headers.Set("X-RateLimit-Bucket", "global") @@ -291,7 +335,11 @@ func TestAPIClient_UpdateRateLimit(t *testing.T) { func TestAPIClient_WaitIfNeeded(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } // Set up a bucket with 0 remaining and reset in future bucket := "test-bucket" @@ -305,7 +353,7 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) { client.mu.Unlock() start := time.Now() - err := client.waitIfNeeded(bucket) + err = client.waitIfNeeded(bucket) elapsed := time.Since(start) if err != nil { @@ -323,7 +371,11 @@ func TestAPIClient_WaitIfNeeded(t *testing.T) { func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } // Set up a bucket with remaining requests bucket := "test-bucket" @@ -337,7 +389,7 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { client.mu.Unlock() start := time.Now() - err := client.waitIfNeeded(bucket) + err = client.waitIfNeeded(bucket) elapsed := time.Since(start) if err != nil { @@ -352,7 +404,11 @@ func TestAPIClient_WaitIfNeeded_NoWait(t *testing.T) { func TestAPIClient_Do_Concurrent(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } requestCount := 0 var mu sync.Mutex @@ -376,24 +432,22 @@ func TestAPIClient_Do_Concurrent(t *testing.T) { var wg sync.WaitGroup errors := make(chan error, 10) - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() + for range 10 { + wg.Go( + func() { + req, err := http.NewRequest("GET", server.URL+"/test", nil) + if err != nil { + errors <- err + return + } - req, err := http.NewRequest("GET", server.URL+"/test", nil) - if err != nil { - errors <- err - return - } - - resp, err := client.Do(req) - if err != nil { - errors <- err - return - } - resp.Body.Close() - }() + resp, err := client.Do(req) + if err != nil { + errors <- err + return + } + resp.Body.Close() + }) } wg.Wait() @@ -430,7 +484,11 @@ func TestAPIClient_Do_Concurrent(t *testing.T) { func TestAPIClient_ParseRetryAfter(t *testing.T) { logger := testLogger(t) - client := NewRateLimitedClient(logger) + cfg := testConfig() + client, err := NewAPIClient(cfg, logger, "trusted-host.example.com") + if err != nil { + t.Fatalf("NewAPIClient returned error: %v", err) + } tests := []struct { name string diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index 9093352..12082a6 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -28,11 +28,9 @@ func Callback( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - // Track callback redirect attempts attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) if exceeded { - // Build detailed error for logging err := errors.Errorf( "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", attempts, @@ -42,10 +40,8 @@ func Callback( track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), ) - // Clear the tracking entry store.ClearRedirectTrack(r, "/callback") - // Show error page throwError( server, w, @@ -66,23 +62,17 @@ func Callback( } data, err := verifyState(cfg.OAuth, w, r, state) if err != nil { - // Check if this is a cookie error (401) or signature error (403) if vsErr, ok := err.(*verifyStateError); ok { if vsErr.IsCookieError() { - // Cookie missing/expired - normal failed/expired session (DEBUG) throwUnauthorized(server, w, r, "OAuth session not found or expired", err) } else { - // Signature verification failed - security violation (WARN) throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) } } else { - // Unknown error type - treat as security issue throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) } return } - // SUCCESS POINT: State verified successfully - // Clear redirect tracking - OAuth callback completed successfully store.ClearRedirectTrack(r, "/callback") switch data { @@ -108,10 +98,9 @@ func Callback( ) } -// verifyStateError wraps an error with context about what went wrong type verifyStateError struct { err error - cookieError bool // true if cookie missing/invalid, false if signature invalid + cookieError bool } func (e *verifyStateError) Error() string { @@ -135,20 +124,16 @@ func verifyState( return "", errors.New("state param field is empty") } - // Try to get the cookie uak, err := oauth.GetStateCookie(r) if err != nil { - // Cookie missing or invalid - this is a 401 (not authenticated) return "", &verifyStateError{ err: errors.Wrap(err, "oauth.GetStateCookie"), cookieError: true, } } - // Verify the state signature data, err := oauth.VerifyState(cfg, state, uak) if err != nil { - // Signature verification failed - this is a 403 (security violation) return "", &verifyStateError{ err: errors.Wrap(err, "oauth.VerifyState"), cookieError: false, @@ -170,9 +155,9 @@ func login( store *store.Store, discordAPI *discord.APIClient, ) (func(), error) { - token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost, discordAPI) + token, err := discordAPI.AuthorizeWithCode(code) if err != nil { - return nil, errors.Wrap(err, "discord.AuthorizeWithCode") + return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode") } session, err := discord.NewOAuthSession(token) if err != nil { @@ -204,6 +189,10 @@ func login( }) redirect = "/register" } else { + err = user.UpdateDiscordToken(ctx, tx, token) + if err != nil { + return nil, errors.Wrap(err, "user.UpdateDiscordToken") + } err := auth.Login(w, r, user, true) if err != nil { return nil, errors.Wrap(err, "auth.Login") diff --git a/internal/handlers/isusernameunique.go b/internal/handlers/isusernameunique.go new file mode 100644 index 0000000..f441784 --- /dev/null +++ b/internal/handlers/isusernameunique.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/store" + "github.com/uptrace/bun" +) + +func IsUsernameUnique( + server *hws.Server, + conn *bun.DB, + cfg *config.Config, + store *store.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + username := r.FormValue("username") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + throwInternalServiceError(server, w, r, "Database query failed", err) + return + } + tx.Commit() + if !unique { + w.WriteHeader(http.StatusConflict) + } else { + w.WriteHeader(http.StatusOK) + } + }, + ) +} diff --git a/internal/handlers/login.go b/internal/handlers/login.go index db44584..93bd6c8 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -17,11 +17,9 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI * return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) - // Track login redirect attempts attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) if exceeded { - // Build detailed error for logging err := errors.Errorf( "login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", attempts, @@ -31,10 +29,8 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI * track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), ) - // Clear the tracking entry st.ClearRedirectTrack(r, "/login") - // Show error page throwError( server, w, @@ -54,14 +50,11 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI * } oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) - link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost) + link, err := discordAPI.GetOAuthLink(state) if err != nil { throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) return } - - // SUCCESS POINT: OAuth link generated, redirecting to Discord - // Clear redirect tracking - user successfully initiated OAuth st.ClearRedirectTrack(r, "/login") http.Redirect(w, r, link, http.StatusSeeOther) diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go new file mode 100644 index 0000000..02b9a9d --- /dev/null +++ b/internal/handlers/logout.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func Logout( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + discordAPI *discord.APIClient, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + return + } + defer tx.Rollback() + + user := db.CurrentUser(r.Context()) + if user == nil { + // JIC - should be impossible to get here if route is protected by LoginReq + w.Header().Set("HX-Redirect", "/") + return + } + token, err := user.DeleteDiscordTokens(ctx, tx) + if err != nil { + throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) + return + } + err = discordAPI.RevokeToken(token.Convert()) + if err != nil { + throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) + return + } + err = auth.Logout(tx, w, r) + if err != nil { + throwInternalServiceError(server, w, r, "Logout failed", err) + return + } + tx.Commit() + w.Header().Set("HX-Redirect", "/") + }, + ) +} diff --git a/internal/handlers/register.go b/internal/handlers/register.go index bfeeccd..c8aa694 100644 --- a/internal/handlers/register.go +++ b/internal/handlers/register.go @@ -102,38 +102,6 @@ func Register( ) } -func IsUsernameUnique( - server *hws.Server, - conn *bun.DB, - cfg *config.Config, - store *store.Store, -) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - username := r.FormValue("username") - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(server, w, r, "Database transaction failed", err) - return - } - defer tx.Rollback() - unique, err := db.IsUsernameUnique(ctx, tx, username) - if err != nil { - throwInternalServiceError(server, w, r, "Database query failed", err) - return - } - tx.Commit() - if !unique { - w.WriteHeader(http.StatusConflict) - } else { - w.WriteHeader(http.StatusOK) - } - }, - ) -} - func registerUser( ctx context.Context, tx bun.Tx, @@ -151,7 +119,7 @@ func registerUser( if err != nil { return nil, errors.Wrap(err, "db.CreateUser") } - err = db.UpdateDiscordToken(ctx, tx, user, details.Token) + err = user.UpdateDiscordToken(ctx, tx, details.Token) if err != nil { return nil, errors.Wrap(err, "db.UpdateDiscordToken") }