From c14c5d43eeab954ca71014c488f494dbebaa165b Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 22 Jan 2026 19:52:43 +1100 Subject: [PATCH] added oauth flow to get authorization code --- cmd/oslstats/httpserver.go | 2 +- cmd/oslstats/routes.go | 14 +- internal/config/config.go | 18 + internal/db/ezconf.go | 2 +- internal/discord/config.go | 50 ++ internal/discord/ezconf.go | 41 ++ internal/discord/oauth.go | 39 ++ internal/handlers/callback.go | 61 +++ internal/handlers/login.go | 48 ++ pkg/embedfs/files/css/output.css | 27 - pkg/oauth/config.go | 23 + pkg/oauth/cookies.go | 45 ++ pkg/oauth/ezconf.go | 41 ++ pkg/oauth/state.go | 117 +++++ pkg/oauth/state_test.go | 817 +++++++++++++++++++++++++++++++ 15 files changed, 1313 insertions(+), 32 deletions(-) create mode 100644 internal/discord/config.go create mode 100644 internal/discord/ezconf.go create mode 100644 internal/discord/oauth.go create mode 100644 internal/handlers/callback.go create mode 100644 internal/handlers/login.go create mode 100644 pkg/oauth/config.go create mode 100644 pkg/oauth/cookies.go create mode 100644 pkg/oauth/ezconf.go create mode 100644 pkg/oauth/state.go create mode 100644 pkg/oauth/state_test.go diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index fa0461d..9a50c78 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -53,7 +53,7 @@ func setupHttpServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, config, logger, bun, auth) + err = addRoutes(httpServer, &fs, config, bun, auth) if err != nil { return nil, errors.Wrap(err, "addRoutes") } diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index efea343..993e135 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -9,7 +9,6 @@ import ( "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/golib/hlog" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -17,8 +16,7 @@ import ( func addRoutes( server *hws.Server, staticFS *http.FileSystem, - config *config.Config, - logger *hlog.Logger, + cfg *config.Config, conn *bun.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], ) error { @@ -34,6 +32,16 @@ func addRoutes( Method: hws.MethodGET, Handler: handlers.Index(server), }, + { + Path: "/login", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Login(server, cfg)), + }, + { + Path: "/auth/callback", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Callback(server, cfg)), + }, } // Register the routes with the server diff --git a/internal/config/config.go b/internal/config/config.go index 87b9f15..112752a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,8 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/pkg/oauth" "github.com/joho/godotenv" "github.com/pkg/errors" ) @@ -15,6 +17,8 @@ type Config struct { HWS *hws.Config HWSAuth *hwsauth.Config HLOG *hlog.Config + Discord *discord.Config + OAuth *oauth.Config Flags *Flags } @@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { hws.NewEZConfIntegration(), hwsauth.NewEZConfIntegration(), db.NewEZConfIntegration(), + discord.NewEZConfIntegration(), + oauth.NewEZConfIntegration(), ) if err := loader.ParseEnvVars(); err != nil { return nil, nil, errors.Wrap(err, "loader.ParseEnvVars") @@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) { return nil, nil, errors.New("DB Config not loaded") } + discordcfg, ok := loader.GetConfig("discord") + if !ok { + return nil, nil, errors.New("Dicord Config not loaded") + } + + oauthcfg, ok := loader.GetConfig("oauth") + if !ok { + return nil, nil, errors.New("OAuth Config not loaded") + } + config := &Config{ DB: dbcfg.(*db.Config), HWS: hwscfg.(*hws.Config), HWSAuth: hwsauthcfg.(*hwsauth.Config), HLOG: hlogcfg.(*hlog.Config), + Discord: discordcfg.(*discord.Config), + OAuth: oauthcfg.(*oauth.Config), Flags: flags, } diff --git a/internal/db/ezconf.go b/internal/db/ezconf.go index db4eac9..1cb08c5 100644 --- a/internal/db/ezconf.go +++ b/internal/db/ezconf.go @@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string { // NewEZConfIntegration creates a new EZConf integration helper func NewEZConfIntegration() EZConfIntegration { - return EZConfIntegration{name: "db", configFunc: ConfigFromEnv} + return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv} } diff --git a/internal/discord/config.go b/internal/discord/config.go new file mode 100644 index 0000000..cc086ff --- /dev/null +++ b/internal/discord/config.go @@ -0,0 +1,50 @@ +package discord + +import ( + "strings" + + "git.haelnorr.com/h/golib/env" + "github.com/pkg/errors" +) + +type Config struct { + ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required) + ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required) + OAuthScopes string // Authorisation scopes for OAuth + RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + ClientID: env.String("DISCORD_CLIENT_ID", ""), + ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""), + OAuthScopes: getOAuthScopes(), + RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""), + } + + // Check required fields + if cfg.ClientID == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_ID") + } + if cfg.ClientSecret == "" { + return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET") + } + if cfg.RedirectPath == "" { + return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH") + } + + return cfg, nil +} + +func getOAuthScopes() string { + list := []string{ + "connections", + "email", + "guilds", + "gdm.join", + "guilds.members.read", + "identify", + } + scopes := strings.Join(list, "+") + return scopes +} diff --git a/internal/discord/ezconf.go b/internal/discord/ezconf.go new file mode 100644 index 0000000..8442714 --- /dev/null +++ b/internal/discord/ezconf.go @@ -0,0 +1,41 @@ +package discord + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv} +} diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go new file mode 100644 index 0000000..591d0f9 --- /dev/null +++ b/internal/discord/oauth.go @@ -0,0 +1,39 @@ +package discord + +import ( + "fmt" + "net/url" + + "github.com/pkg/errors" +) + +type Token struct { + AccessToken string + TokenType string + ExpiresIn int + RefreshToken string + Scope string +} + +const oauthurl string = "https://discord.com/oauth2/authorize" + +func GetOAuthLink(cfg *Config, state string, trustedHost string) (string, error) { + if cfg == nil { + return "", errors.New("cfg cannot be nil") + } + if state == "" { + return "", errors.New("state cannot be empty") + } + if trustedHost == "" { + return "", errors.New("trustedHost cannot be empty") + } + values := url.Values{} + values.Add("response_type", "code") + values.Add("client_id", cfg.ClientID) + values.Add("scope", cfg.OAuthScopes) + values.Add("state", state) + values.Add("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath)) + values.Add("prompt", "none") + + return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go new file mode 100644 index 0000000..9d985c7 --- /dev/null +++ b/internal/handlers/callback.go @@ -0,0 +1,61 @@ +package handlers + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/pkg/oauth" + "github.com/pkg/errors" +) + +func Callback(server *hws.Server, cfg *config.Config) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + if state == "" && code == "" { + http.Redirect(w, r, "/", http.StatusBadRequest) + return + } + data, err := verifyState(cfg.OAuth, w, r, state) + if err != nil { + err = server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusForbidden, + Message: "OAuth state verification failed", + Error: err, + Level: hws.ErrorLevel("debug"), + RenderErrorPage: true, + }) + if err != nil { + server.ThrowFatal(w, err) + } + return + } + switch data { + case "login": + w.Write([]byte(code)) + return + } + }, + ) +} + +func verifyState(cfg *oauth.Config, w http.ResponseWriter, r *http.Request, state string) (string, error) { + if r == nil { + return "", errors.New("request cannot be nil") + } + if state == "" { + return "", errors.New("state param field is empty") + } + uak, err := oauth.GetStateCookie(r) + if err != nil { + return "", errors.Wrap(err, "oauth.GetStateCookie") + } + data, err := oauth.VerifyState(cfg, state, uak) + if err != nil { + return "", errors.Wrap(err, "oauth.VerifyState") + } + oauth.DeleteStateCookie(w) + return data, nil +} diff --git a/internal/handlers/login.go b/internal/handlers/login.go new file mode 100644 index 0000000..a64e9d3 --- /dev/null +++ b/internal/handlers/login.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/pkg/oauth" +) + +func Login(server *hws.Server, cfg *config.Config) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + state, uak, err := oauth.GenerateState(cfg.OAuth, "login") + if err != nil { + err = server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusInternalServerError, + Message: "Failed to generate state token", + Error: err, + Level: hws.ErrorLevel("error"), + RenderErrorPage: true, + }) + if err != nil { + server.ThrowFatal(w, err) + } + return + } + oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) + + link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost) + if err != nil { + err = server.ThrowError(w, r, hws.HWSError{ + StatusCode: http.StatusInternalServerError, + Message: "An error occured trying to generate the login link", + Error: err, + Level: hws.ErrorLevel("error"), + RenderErrorPage: true, + }) + if err != nil { + server.ThrowFatal(w, err) + } + return + } + http.Redirect(w, r, link, http.StatusSeeOther) + }, + ) +} diff --git a/pkg/embedfs/files/css/output.css b/pkg/embedfs/files/css/output.css index db10ff8..8876f5c 100644 --- a/pkg/embedfs/files/css/output.css +++ b/pkg/embedfs/files/css/output.css @@ -38,33 +38,6 @@ --default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); --default-font-family: var(--font-sans); --default-mono-font-family: var(--font-mono); - --color-rosewater: var(--rosewater); - --color-flamingo: var(--flamingo); - --color-pink: var(--pink); - --color-mauve: var(--mauve); - --color-red: var(--red); - --color-dark-red: var(--dark-red); - --color-maroon: var(--maroon); - --color-peach: var(--peach); - --color-yellow: var(--yellow); - --color-green: var(--green); - --color-teal: var(--teal); - --color-sky: var(--sky); - --color-sapphire: var(--sapphire); - --color-blue: var(--blue); - --color-lavender: var(--lavender); - --color-text: var(--text); - --color-subtext1: var(--subtext1); - --color-subtext0: var(--subtext0); - --color-overlay2: var(--overlay2); - --color-overlay1: var(--overlay1); - --color-overlay0: var(--overlay0); - --color-surface2: var(--surface2); - --color-surface1: var(--surface1); - --color-surface0: var(--surface0); - --color-base: var(--base); - --color-mantle: var(--mantle); - --color-crust: var(--crust); } } @layer base { diff --git a/pkg/oauth/config.go b/pkg/oauth/config.go new file mode 100644 index 0000000..c37ef21 --- /dev/null +++ b/pkg/oauth/config.go @@ -0,0 +1,23 @@ +package oauth + +import ( + "git.haelnorr.com/h/golib/env" + "github.com/pkg/errors" +) + +type Config struct { + PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required) +} + +func ConfigFromEnv() (any, error) { + cfg := &Config{ + PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""), + } + + // Check required fields + if cfg.PrivateKey == "" { + return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY") + } + + return cfg, nil +} diff --git a/pkg/oauth/cookies.go b/pkg/oauth/cookies.go new file mode 100644 index 0000000..adf0f54 --- /dev/null +++ b/pkg/oauth/cookies.go @@ -0,0 +1,45 @@ +package oauth + +import ( + "encoding/base64" + "net/http" + + "github.com/pkg/errors" +) + +func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) { + encodedUak := base64.RawURLEncoding.EncodeToString(uak) + http.SetCookie(w, &http.Cookie{ + Name: "oauth_uak", + Value: encodedUak, + Path: "/", + MaxAge: 300, + HttpOnly: true, + Secure: ssl, + SameSite: http.SameSiteLaxMode, + }) +} + +func GetStateCookie(r *http.Request) ([]byte, error) { + if r == nil { + return nil, errors.New("Request cannot be nil") + } + cookie, err := r.Cookie("oauth_uak") + if err != nil { + return nil, err + } + uak, err := base64.RawURLEncoding.DecodeString(cookie.Value) + if err != nil { + return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie") + } + return uak, nil +} + +func DeleteStateCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: "oauth_uak", + Value: "", + Path: "/", + MaxAge: -1, + }) +} diff --git a/pkg/oauth/ezconf.go b/pkg/oauth/ezconf.go new file mode 100644 index 0000000..e8a87ca --- /dev/null +++ b/pkg/oauth/ezconf.go @@ -0,0 +1,41 @@ +package oauth + +import ( + "runtime" + "strings" +) + +// EZConfIntegration provides integration with ezconf for automatic configuration +type EZConfIntegration struct { + configFunc func() (any, error) + name string +} + +// PackagePath returns the path to the config package for source parsing +func (e EZConfIntegration) PackagePath() string { + _, filename, _, _ := runtime.Caller(0) + // Return directory of this file + return filename[:len(filename)-len("/ezconf.go")] +} + +// ConfigFunc returns the ConfigFromEnv function for ezconf +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { + return e.configFunc() + } +} + +// Name returns the name to use when registering with ezconf +func (e EZConfIntegration) Name() string { + return strings.ToLower(e.name) +} + +// GroupName returns the display name for grouping environment variables +func (e EZConfIntegration) GroupName() string { + return e.name +} + +// NewEZConfIntegration creates a new EZConf integration helper +func NewEZConfIntegration() EZConfIntegration { + return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv} +} diff --git a/pkg/oauth/state.go b/pkg/oauth/state.go new file mode 100644 index 0000000..5c97ab5 --- /dev/null +++ b/pkg/oauth/state.go @@ -0,0 +1,117 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "slices" + "strings" + + "github.com/pkg/errors" +) + +// STATE FLOW: +// data provided at call time to be retrieved later +// random value generated on the spot +// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client +// privateKey - from config + +func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) { + // signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey) + // state = data + "." + random + "." + signature + if cfg == nil { + return "", nil, errors.New("cfg cannot be nil") + } + if cfg.PrivateKey == "" { + return "", nil, errors.New("private key cannot be empty") + } + if data == "" { + return "", nil, errors.New("data cannot be empty") + } + + // Generate 32 random bytes for random component + randomBytes := make([]byte, 32) + _, err = rand.Read(randomBytes) + if err != nil { + return "", nil, errors.Wrap(err, "failed to generate random bytes") + } + + // Generate 32 random bytes for userAgentKey + userAgentKey = make([]byte, 32) + _, err = rand.Read(userAgentKey) + if err != nil { + return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes") + } + + // Encode random and userAgentKey to base64 + randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes) + userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey) + + // Create payload for signing: data + "." + random + userAgentKey + privateKey + // Note: userAgentKey is concatenated directly with privateKey (no separator) + payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey + + // Generate signature + hash := sha256.Sum256([]byte(payload)) + signature := base64.RawURLEncoding.EncodeToString(hash[:]) + + // Construct state: data + "." + random + "." + signature + state = data + "." + randomEncoded + "." + signature + + return state, userAgentKey, nil +} + +func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) { + // Validate inputs + if cfg == nil { + return "", errors.New("cfg cannot be nil") + } + if cfg.PrivateKey == "" { + return "", errors.New("private key cannot be empty") + } + if state == "" { + return "", errors.New("state cannot be empty") + } + if len(userAgentKey) == 0 { + return "", errors.New("userAgentKey cannot be empty") + } + + // Split state into parts + parts := strings.Split(state, ".") + if len(parts) != 3 { + return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts)) + } + + // Check for empty parts + if slices.Contains(parts, "") { + return "", errors.New("state parts cannot be empty") + } + + data = parts[0] + random := parts[1] + receivedSignature := parts[2] + + // Encode userAgentKey to base64 for payload reconstruction + userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey) + + // Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey + payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey + + // Generate expected hash + hash := sha256.Sum256([]byte(payload)) + + // Decode received signature to bytes + receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature) + if err != nil { + return "", errors.Wrap(err, "failed to decode received signature") + } + + // Compare hash bytes directly with decoded signature using constant-time comparison + // This is more efficient than encoding hash and then decoding both for comparison + if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 { + return data, nil + } + + return "", errors.New("invalid state signature") +} diff --git a/pkg/oauth/state_test.go b/pkg/oauth/state_test.go new file mode 100644 index 0000000..fb0672d --- /dev/null +++ b/pkg/oauth/state_test.go @@ -0,0 +1,817 @@ +package oauth + +import ( + "crypto/sha256" + "encoding/base64" + "strings" + "testing" +) + +// Helper function to create a test config +func testConfig() *Config { + return &Config{ + PrivateKey: "test_private_key_for_testing_12345", + } +} + +// TestGenerateState_Success tests the happy path of state generation +func TestGenerateState_Success(t *testing.T) { + cfg := testConfig() + data := "test_data_payload" + + state, userAgentKey, err := GenerateState(cfg, data) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if state == "" { + t.Error("Expected non-empty state") + } + + if len(userAgentKey) != 32 { + t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) + } + + // Verify state format: data.random.signature + parts := strings.Split(state, ".") + if len(parts) != 3 { + t.Errorf("Expected state to have 3 parts, got %d", len(parts)) + } + + // Verify data is preserved + if parts[0] != data { + t.Errorf("Expected data to be '%s', got '%s'", data, parts[0]) + } + + // Verify random part is base64 encoded + randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Errorf("Expected random part to be valid base64: %v", err) + } + if len(randomBytes) != 32 { + t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes)) + } + + // Verify signature part is base64 encoded + sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + t.Errorf("Expected signature part to be valid base64: %v", err) + } + if len(sigBytes) != 32 { + t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes)) + } +} + +// TestGenerateState_NilConfig tests that nil config returns error +func TestGenerateState_NilConfig(t *testing.T) { + _, _, err := GenerateState(nil, "test_data") + + if err == nil { + t.Fatal("Expected error for nil config, got nil") + } + + if !strings.Contains(err.Error(), "cfg cannot be nil") { + t.Errorf("Expected error message about nil config, got: %v", err) + } +} + +// TestGenerateState_EmptyPrivateKey tests that empty private key returns error +func TestGenerateState_EmptyPrivateKey(t *testing.T) { + cfg := &Config{PrivateKey: ""} + _, _, err := GenerateState(cfg, "test_data") + + if err == nil { + t.Fatal("Expected error for empty private key, got nil") + } + + if !strings.Contains(err.Error(), "private key cannot be empty") { + t.Errorf("Expected error message about empty private key, got: %v", err) + } +} + +// TestGenerateState_EmptyData tests that empty data returns error +func TestGenerateState_EmptyData(t *testing.T) { + cfg := testConfig() + _, _, err := GenerateState(cfg, "") + + if err == nil { + t.Fatal("Expected error for empty data, got nil") + } + + if !strings.Contains(err.Error(), "data cannot be empty") { + t.Errorf("Expected error message about empty data, got: %v", err) + } +} + +// TestGenerateState_Randomness tests that multiple calls generate different states +func TestGenerateState_Randomness(t *testing.T) { + cfg := testConfig() + data := "same_data" + + state1, _, err1 := GenerateState(cfg, data) + state2, _, err2 := GenerateState(cfg, data) + + if err1 != nil || err2 != nil { + t.Fatalf("Unexpected errors: %v, %v", err1, err2) + } + + if state1 == state2 { + t.Error("Expected different states for multiple calls, got identical states") + } +} + +// TestGenerateState_DifferentData tests states with different data payloads +func TestGenerateState_DifferentData(t *testing.T) { + cfg := testConfig() + + testCases := []string{ + "simple", + "with-dashes", + "with_underscores", + "123456789", + "MixedCase123", + } + + for _, data := range testCases { + t.Run(data, func(t *testing.T) { + state, userAgentKey, err := GenerateState(cfg, data) + + if err != nil { + t.Fatalf("Unexpected error for data '%s': %v", data, err) + } + + if !strings.HasPrefix(state, data+".") { + t.Errorf("Expected state to start with '%s.', got: %s", data, state) + } + + if len(userAgentKey) != 32 { + t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) + } + }) + } +} + +// TestVerifyState_Success tests the happy path of state verification +func TestVerifyState_Success(t *testing.T) { + cfg := testConfig() + data := "test_data" + + // Generate state + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify state + extractedData, err := VerifyState(cfg, state, userAgentKey) + + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if extractedData != data { + t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData) + } +} + +// TestVerifyState_NilConfig tests that nil config returns error +func TestVerifyState_NilConfig(t *testing.T) { + _, err := VerifyState(nil, "state", []byte("key")) + + if err == nil { + t.Fatal("Expected error for nil config, got nil") + } + + if !strings.Contains(err.Error(), "cfg cannot be nil") { + t.Errorf("Expected error message about nil config, got: %v", err) + } +} + +// TestVerifyState_EmptyPrivateKey tests that empty private key returns error +func TestVerifyState_EmptyPrivateKey(t *testing.T) { + cfg := &Config{PrivateKey: ""} + _, err := VerifyState(cfg, "state", []byte("key")) + + if err == nil { + t.Fatal("Expected error for empty private key, got nil") + } + + if !strings.Contains(err.Error(), "private key cannot be empty") { + t.Errorf("Expected error message about empty private key, got: %v", err) + } +} + +// TestVerifyState_EmptyState tests that empty state returns error +func TestVerifyState_EmptyState(t *testing.T) { + cfg := testConfig() + _, err := VerifyState(cfg, "", []byte("key")) + + if err == nil { + t.Fatal("Expected error for empty state, got nil") + } + + if !strings.Contains(err.Error(), "state cannot be empty") { + t.Errorf("Expected error message about empty state, got: %v", err) + } +} + +// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error +func TestVerifyState_EmptyUserAgentKey(t *testing.T) { + cfg := testConfig() + _, err := VerifyState(cfg, "data.random.signature", []byte{}) + + if err == nil { + t.Fatal("Expected error for empty userAgentKey, got nil") + } + + if !strings.Contains(err.Error(), "userAgentKey cannot be empty") { + t.Errorf("Expected error message about empty userAgentKey, got: %v", err) + } +} + +// TestVerifyState_WrongUserAgentKey tests MITM protection +func TestVerifyState_WrongUserAgentKey(t *testing.T) { + cfg := testConfig() + + // Generate first state + state, _, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Generate a different userAgentKey + _, wrongKey, err := GenerateState(cfg, "other_data") + if err != nil { + t.Fatalf("Failed to generate second state: %v", err) + } + + // Try to verify with wrong key + _, err = VerifyState(cfg, state, wrongKey) + + if err == nil { + t.Error("Expected error for invalid signature") + } + + if !strings.Contains(err.Error(), "invalid state signature") { + t.Errorf("Expected error about invalid signature, got: %v", err) + } +} + +// TestVerifyState_TamperedData tests tampering detection +func TestVerifyState_TamperedData(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "original_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the data portion + parts := strings.Split(state, ".") + parts[0] = "tampered_data" + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered state") + } +} + +// TestVerifyState_TamperedRandom tests tampering with random portion +func TestVerifyState_TamperedRandom(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the random portion + parts := strings.Split(state, ".") + parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12")) + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered state") + } +} + +// TestVerifyState_TamperedSignature tests tampering with signature +func TestVerifyState_TamperedSignature(t *testing.T) { + cfg := testConfig() + + // Generate state + state, userAgentKey, err := GenerateState(cfg, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Tamper with the signature portion + parts := strings.Split(state, ".") + // Create a different valid base64 string + parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered"))) + tamperedState := strings.Join(parts, ".") + + // Try to verify tampered state + _, err = VerifyState(cfg, tamperedState, userAgentKey) + + if err == nil { + t.Error("Expected error for tampered signature") + } +} + +// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts +func TestVerifyState_MalformedState_TwoParts(t *testing.T) { + cfg := testConfig() + malformedState := "data.random" + + _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for malformed state") + } + + if !strings.Contains(err.Error(), "must have exactly 3 parts") { + t.Errorf("Expected error about incorrect number of parts, got: %v", err) + } +} + +// TestVerifyState_MalformedState_FourParts tests state with 4 parts +func TestVerifyState_MalformedState_FourParts(t *testing.T) { + cfg := testConfig() + malformedState := "data.random.signature.extra" + + _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for malformed state") + } + + if !strings.Contains(err.Error(), "must have exactly 3 parts") { + t.Errorf("Expected error about incorrect number of parts, got: %v", err) + } +} + +// TestVerifyState_EmptyStateParts tests state with empty parts +func TestVerifyState_EmptyStateParts(t *testing.T) { + cfg := testConfig() + testCases := []struct { + name string + state string + }{ + {"empty data", ".random.signature"}, + {"empty random", "data..signature"}, + {"empty signature", "data.random."}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for state with empty parts") + } + + if !strings.Contains(err.Error(), "state parts cannot be empty") { + t.Errorf("Expected error about empty parts, got: %v", err) + } + }) + } +} + +// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature +func TestVerifyState_InvalidBase64Signature(t *testing.T) { + cfg := testConfig() + invalidState := "data.random.invalid@base64!" + + _, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890")) + + if err == nil { + t.Fatal("Expected error for invalid base64 signature") + } + + if !strings.Contains(err.Error(), "failed to decode") { + t.Errorf("Expected error about decoding signature, got: %v", err) + } +} + +// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification +func TestVerifyState_DifferentPrivateKey(t *testing.T) { + cfg1 := &Config{PrivateKey: "private_key_1"} + cfg2 := &Config{PrivateKey: "private_key_2"} + + // Generate with first config + state, userAgentKey, err := GenerateState(cfg1, "test_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Try to verify with second config + _, err = VerifyState(cfg2, state, userAgentKey) + + if err == nil { + t.Error("Expected error for mismatched private key") + } +} + +// TestRoundTrip tests complete round trip with various data payloads +func TestRoundTrip(t *testing.T) { + cfg := testConfig() + + testCases := []string{ + "simple", + "with-dashes-and-numbers-123", + "MixedCaseData", + "user_token_abc123", + "link_resource_xyz789", + } + + for _, data := range testCases { + t.Run(data, func(t *testing.T) { + // Generate + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify + extractedData, err := VerifyState(cfg, state, userAgentKey) + if err != nil { + t.Fatalf("Failed to verify state: %v", err) + } + + if extractedData != data { + t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData) + } + }) + } +} + +// TestConcurrentGeneration tests that concurrent state generation works correctly +func TestConcurrentGeneration(t *testing.T) { + cfg := testConfig() + data := "concurrent_test" + + const numGoroutines = 10 + results := make(chan string, numGoroutines) + errors := make(chan error, numGoroutines) + + // Generate states concurrently + for i := 0; i < numGoroutines; i++ { + go func() { + state, userAgentKey, err := GenerateState(cfg, data) + if err != nil { + errors <- err + return + } + + // Verify immediately + _, verifyErr := VerifyState(cfg, state, userAgentKey) + if verifyErr != nil { + errors <- verifyErr + return + } + + results <- state + }() + } + + // Collect results + states := make(map[string]bool) + for i := 0; i < numGoroutines; i++ { + select { + case state := <-results: + if states[state] { + t.Errorf("Duplicate state generated: %s", state) + } + states[state] = true + case err := <-errors: + t.Errorf("Concurrent generation error: %v", err) + } + } + + if len(states) != numGoroutines { + t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states)) + } +} + +// TestStateFormatCompatibility ensures state is URL-safe +func TestStateFormatCompatibility(t *testing.T) { + cfg := testConfig() + data := "url_safe_test" + + state, _, err := GenerateState(cfg, data) + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Check that state doesn't contain characters that need URL encoding + unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"} + for _, char := range unsafeChars { + if strings.Contains(state, char) { + t.Errorf("State contains URL-unsafe character '%s': %s", char, state) + } + } +} + +// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works +// An attacker obtains their own valid state but tries to use it with victim's session +func TestMITM_AttackerCannotSubstituteState(t *testing.T) { + cfg := testConfig() + + // Victim generates a state for their login + victimState, victimKey, err := GenerateState(cfg, "victim_data") + if err != nil { + t.Fatalf("Failed to generate victim state: %v", err) + } + + // Attacker generates their own valid state (they can request this from the server) + attackerState, attackerKey, err := GenerateState(cfg, "attacker_data") + if err != nil { + t.Fatalf("Failed to generate attacker state: %v", err) + } + + // Both states should be valid on their own + _, err = VerifyState(cfg, victimState, victimKey) + if err != nil { + t.Fatalf("Victim state should be valid: err=%v", err) + } + + _, err = VerifyState(cfg, attackerState, attackerKey) + if err != nil { + t.Fatalf("Attacker state should be valid: err=%v", err) + } + + // MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie + // This should FAIL because attackerState was signed with attackerKey, not victimKey + _, err = VerifyState(cfg, attackerState, victimKey) + if err == nil { + t.Error("Expected error when attacker substitutes state") + } + + // MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie + // This should also FAIL + _, err = VerifyState(cfg, victimState, attackerKey) + if err == nil { + t.Error("Expected error when attacker uses victim's state") + } + + // The key insight: even though both states are "valid", they are bound to their respective cookies + // An attacker cannot mix and match states and cookies + t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies") +} + +// TestCSRF_AttackerCannotForgeState verifies CSRF protection +// An attacker tries to forge a state parameter without knowing the private key +func TestCSRF_AttackerCannotForgeState(t *testing.T) { + cfg := testConfig() + + // Attacker doesn't know the private key, but tries to forge a state + // They might try to construct: "malicious_data.random.signature" + + // Attempt 1: Use a random signature + randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678")) + forgedState1 := "malicious_data.somefakerandom." + randomSig + + // Generate a real userAgentKey (attacker might try to get this) + _, realKey, err := GenerateState(cfg, "legitimate_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Try to verify forged state + _, err = VerifyState(cfg, forgedState1, realKey) + if err == nil { + t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!") + } + + // Attempt 2: Attacker tries to compute signature without private key + // They use: SHA256(data + "." + random + userAgentKey) - missing privateKey + attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey) + hash := sha256.Sum256([]byte(attackerPayload)) + attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) + forgedState2 := "malicious_data.fakerandom." + attackerSig + + _, err = VerifyState(cfg, forgedState2, realKey) + if err == nil { + t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!") + } + + t.Log("✓ CSRF protection verified: Cannot forge state without private key") +} + +// TestTampering_SignatureDetectsAllModifications verifies tamper detection +func TestTampering_SignatureDetectsAllModifications(t *testing.T) { + cfg := testConfig() + + // Generate a valid state + originalState, userAgentKey, err := GenerateState(cfg, "original_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify original is valid + data, err := VerifyState(cfg, originalState, userAgentKey) + if err != nil || data != "original_data" { + t.Fatalf("Original state should be valid") + } + + parts := strings.Split(originalState, ".") + + // Test 1: Attacker modifies data but keeps signature + tamperedState := "modified_data." + parts[1] + "." + parts[2] + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Modified data not detected!") + } + + // Test 2: Attacker modifies random but keeps signature + newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!")) + tamperedState = parts[0] + "." + newRandom + "." + parts[2] + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Modified random not detected!") + } + + // Test 3: Attacker tries to recompute signature but doesn't have privateKey + // They compute: SHA256(modified_data + "." + random + userAgentKey) + attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey) + hash := sha256.Sum256([]byte(attackerPayload)) + attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) + tamperedState = "modified_data." + parts[1] + "." + attackerSig + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!") + } + + // Test 4: Single bit flip in signature + sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2]) + sigBytes[0] ^= 0x01 // Flip one bit + flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes) + tamperedState = parts[0] + "." + parts[1] + "." + flippedSig + _, err = VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!") + } + + t.Log("✓ Tamper detection verified: All modifications to state are detected") +} + +// TestReplay_DifferentSessionsCannotReuseState verifies replay protection +func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) { + cfg := testConfig() + + // Session 1: User initiates OAuth flow + state1, key1, err := GenerateState(cfg, "session1_data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // State is valid for session 1 + _, err = VerifyState(cfg, state1, key1) + if err != nil { + t.Fatalf("State should be valid for session 1") + } + + // Session 2: Same user (or attacker) initiates a new OAuth flow + state2, key2, err := GenerateState(cfg, "session1_data") // same data + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Replay Attack: Try to use state1 with key2 + _, err = VerifyState(cfg, state1, key2) + if err == nil { + t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!") + } + + // Even with same data, each session should have unique state+key binding + if state1 == state2 { + t.Error("REPLAY VULNERABILITY: Same data produces identical states!") + } + + t.Log("✓ Replay protection verified: States are bound to specific session cookies") +} + +// TestConstantTimeComparison verifies that signature comparison is timing-safe +// This is a behavioral test - we can't easily test timing, but we can verify the function is used +func TestConstantTimeComparison_IsUsed(t *testing.T) { + cfg := testConfig() + + // Generate valid state + state, userAgentKey, err := GenerateState(cfg, "test") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Create states with signatures that differ at different positions + parts := strings.Split(state, ".") + originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2]) + + testCases := []struct { + name string + position int + }{ + {"first byte differs", 0}, + {"middle byte differs", 16}, + {"last byte differs", 31}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create signature that differs at specific position + tamperedSig := make([]byte, len(originalSig)) + copy(tamperedSig, originalSig) + tamperedSig[tc.position] ^= 0xFF // Flip all bits + + tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig) + tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr + + // All should fail verification + _, err := VerifyState(cfg, tamperedState, userAgentKey) + if err == nil { + t.Errorf("Tampered signature at position %d should be invalid", tc.position) + } + + // If constant-time comparison is NOT used, early differences would return faster + // While we can't easily test timing here, we verify all positions fail equally + }) + } + + t.Log("✓ Constant-time comparison: All signature positions validated equally") + t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation") +} + +// TestPrivateKey_IsCriticalToSecurity verifies private key is essential +func TestPrivateKey_IsCriticalToSecurity(t *testing.T) { + cfg1 := &Config{PrivateKey: "secret_key_1"} + cfg2 := &Config{PrivateKey: "secret_key_2"} + + // Generate state with key1 + state, userAgentKey, err := GenerateState(cfg1, "data") + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Should verify with key1 + _, err = VerifyState(cfg1, state, userAgentKey) + if err != nil { + t.Fatalf("State should be valid with correct private key") + } + + // Should NOT verify with key2 (different private key) + _, err = VerifyState(cfg2, state, userAgentKey) + if err == nil { + t.Error("SECURITY VULNERABILITY: State verified with different private key!") + } + + // This proves that the private key is cryptographically involved in the signature + t.Log("✓ Private key security verified: Different keys produce incompatible signatures") +} + +// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload +func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) { + cfg := testConfig() + + // Generate two states with same data but different userAgentKeys (implicit) + state1, key1, err := GenerateState(cfg, "same_data") + if err != nil { + t.Fatalf("Failed to generate state1: %v", err) + } + + state2, key2, err := GenerateState(cfg, "same_data") + if err != nil { + t.Fatalf("Failed to generate state2: %v", err) + } + + // The states should be different even with same data (different random and keys) + if state1 == state2 { + t.Error("States should differ due to different random values") + } + + // Each state should only verify with its own key + _, err1 := VerifyState(cfg, state1, key1) + _, err2 := VerifyState(cfg, state2, key2) + + if err1 != nil || err2 != nil { + t.Fatal("States should be valid with their own keys") + } + + // Cross-verification should fail + _, err1 = VerifyState(cfg, state1, key2) + _, err2 = VerifyState(cfg, state2, key1) + + if err1 == nil || err2 == nil { + t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!") + } + + t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key") +}