added oauth flow to get authorization code

This commit is contained in:
2026-01-22 19:52:43 +11:00
parent 7be15160d5
commit c14c5d43ee
15 changed files with 1313 additions and 32 deletions

View File

@@ -38,33 +38,6 @@
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
--default-font-family: var(--font-sans);
--default-mono-font-family: var(--font-mono);
--color-rosewater: var(--rosewater);
--color-flamingo: var(--flamingo);
--color-pink: var(--pink);
--color-mauve: var(--mauve);
--color-red: var(--red);
--color-dark-red: var(--dark-red);
--color-maroon: var(--maroon);
--color-peach: var(--peach);
--color-yellow: var(--yellow);
--color-green: var(--green);
--color-teal: var(--teal);
--color-sky: var(--sky);
--color-sapphire: var(--sapphire);
--color-blue: var(--blue);
--color-lavender: var(--lavender);
--color-text: var(--text);
--color-subtext1: var(--subtext1);
--color-subtext0: var(--subtext0);
--color-overlay2: var(--overlay2);
--color-overlay1: var(--overlay1);
--color-overlay0: var(--overlay0);
--color-surface2: var(--surface2);
--color-surface1: var(--surface1);
--color-surface0: var(--surface0);
--color-base: var(--base);
--color-mantle: var(--mantle);
--color-crust: var(--crust);
}
}
@layer base {

23
pkg/oauth/config.go Normal file
View File

@@ -0,0 +1,23 @@
package oauth
import (
"git.haelnorr.com/h/golib/env"
"github.com/pkg/errors"
)
type Config struct {
PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""),
}
// Check required fields
if cfg.PrivateKey == "" {
return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY")
}
return cfg, nil
}

45
pkg/oauth/cookies.go Normal file
View File

@@ -0,0 +1,45 @@
package oauth
import (
"encoding/base64"
"net/http"
"github.com/pkg/errors"
)
func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) {
encodedUak := base64.RawURLEncoding.EncodeToString(uak)
http.SetCookie(w, &http.Cookie{
Name: "oauth_uak",
Value: encodedUak,
Path: "/",
MaxAge: 300,
HttpOnly: true,
Secure: ssl,
SameSite: http.SameSiteLaxMode,
})
}
func GetStateCookie(r *http.Request) ([]byte, error) {
if r == nil {
return nil, errors.New("Request cannot be nil")
}
cookie, err := r.Cookie("oauth_uak")
if err != nil {
return nil, err
}
uak, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie")
}
return uak, nil
}
func DeleteStateCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: "oauth_uak",
Value: "",
Path: "/",
MaxAge: -1,
})
}

41
pkg/oauth/ezconf.go Normal file
View File

@@ -0,0 +1,41 @@
package oauth
import (
"runtime"
"strings"
)
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct {
configFunc func() (any, error)
name string
}
// PackagePath returns the path to the config package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return e.configFunc()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return strings.ToLower(e.name)
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return e.name
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv}
}

117
pkg/oauth/state.go Normal file
View File

@@ -0,0 +1,117 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"slices"
"strings"
"github.com/pkg/errors"
)
// STATE FLOW:
// data provided at call time to be retrieved later
// random value generated on the spot
// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client
// privateKey - from config
func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) {
// signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey)
// state = data + "." + random + "." + signature
if cfg == nil {
return "", nil, errors.New("cfg cannot be nil")
}
if cfg.PrivateKey == "" {
return "", nil, errors.New("private key cannot be empty")
}
if data == "" {
return "", nil, errors.New("data cannot be empty")
}
// Generate 32 random bytes for random component
randomBytes := make([]byte, 32)
_, err = rand.Read(randomBytes)
if err != nil {
return "", nil, errors.Wrap(err, "failed to generate random bytes")
}
// Generate 32 random bytes for userAgentKey
userAgentKey = make([]byte, 32)
_, err = rand.Read(userAgentKey)
if err != nil {
return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes")
}
// Encode random and userAgentKey to base64
randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes)
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
// Create payload for signing: data + "." + random + userAgentKey + privateKey
// Note: userAgentKey is concatenated directly with privateKey (no separator)
payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey
// Generate signature
hash := sha256.Sum256([]byte(payload))
signature := base64.RawURLEncoding.EncodeToString(hash[:])
// Construct state: data + "." + random + "." + signature
state = data + "." + randomEncoded + "." + signature
return state, userAgentKey, nil
}
func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) {
// Validate inputs
if cfg == nil {
return "", errors.New("cfg cannot be nil")
}
if cfg.PrivateKey == "" {
return "", errors.New("private key cannot be empty")
}
if state == "" {
return "", errors.New("state cannot be empty")
}
if len(userAgentKey) == 0 {
return "", errors.New("userAgentKey cannot be empty")
}
// Split state into parts
parts := strings.Split(state, ".")
if len(parts) != 3 {
return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts))
}
// Check for empty parts
if slices.Contains(parts, "") {
return "", errors.New("state parts cannot be empty")
}
data = parts[0]
random := parts[1]
receivedSignature := parts[2]
// Encode userAgentKey to base64 for payload reconstruction
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
// Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey
payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey
// Generate expected hash
hash := sha256.Sum256([]byte(payload))
// Decode received signature to bytes
receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature)
if err != nil {
return "", errors.Wrap(err, "failed to decode received signature")
}
// Compare hash bytes directly with decoded signature using constant-time comparison
// This is more efficient than encoding hash and then decoding both for comparison
if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 {
return data, nil
}
return "", errors.New("invalid state signature")
}

817
pkg/oauth/state_test.go Normal file
View File

@@ -0,0 +1,817 @@
package oauth
import (
"crypto/sha256"
"encoding/base64"
"strings"
"testing"
)
// Helper function to create a test config
func testConfig() *Config {
return &Config{
PrivateKey: "test_private_key_for_testing_12345",
}
}
// TestGenerateState_Success tests the happy path of state generation
func TestGenerateState_Success(t *testing.T) {
cfg := testConfig()
data := "test_data_payload"
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if state == "" {
t.Error("Expected non-empty state")
}
if len(userAgentKey) != 32 {
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
}
// Verify state format: data.random.signature
parts := strings.Split(state, ".")
if len(parts) != 3 {
t.Errorf("Expected state to have 3 parts, got %d", len(parts))
}
// Verify data is preserved
if parts[0] != data {
t.Errorf("Expected data to be '%s', got '%s'", data, parts[0])
}
// Verify random part is base64 encoded
randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Errorf("Expected random part to be valid base64: %v", err)
}
if len(randomBytes) != 32 {
t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes))
}
// Verify signature part is base64 encoded
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
t.Errorf("Expected signature part to be valid base64: %v", err)
}
if len(sigBytes) != 32 {
t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes))
}
}
// TestGenerateState_NilConfig tests that nil config returns error
func TestGenerateState_NilConfig(t *testing.T) {
_, _, err := GenerateState(nil, "test_data")
if err == nil {
t.Fatal("Expected error for nil config, got nil")
}
if !strings.Contains(err.Error(), "cfg cannot be nil") {
t.Errorf("Expected error message about nil config, got: %v", err)
}
}
// TestGenerateState_EmptyPrivateKey tests that empty private key returns error
func TestGenerateState_EmptyPrivateKey(t *testing.T) {
cfg := &Config{PrivateKey: ""}
_, _, err := GenerateState(cfg, "test_data")
if err == nil {
t.Fatal("Expected error for empty private key, got nil")
}
if !strings.Contains(err.Error(), "private key cannot be empty") {
t.Errorf("Expected error message about empty private key, got: %v", err)
}
}
// TestGenerateState_EmptyData tests that empty data returns error
func TestGenerateState_EmptyData(t *testing.T) {
cfg := testConfig()
_, _, err := GenerateState(cfg, "")
if err == nil {
t.Fatal("Expected error for empty data, got nil")
}
if !strings.Contains(err.Error(), "data cannot be empty") {
t.Errorf("Expected error message about empty data, got: %v", err)
}
}
// TestGenerateState_Randomness tests that multiple calls generate different states
func TestGenerateState_Randomness(t *testing.T) {
cfg := testConfig()
data := "same_data"
state1, _, err1 := GenerateState(cfg, data)
state2, _, err2 := GenerateState(cfg, data)
if err1 != nil || err2 != nil {
t.Fatalf("Unexpected errors: %v, %v", err1, err2)
}
if state1 == state2 {
t.Error("Expected different states for multiple calls, got identical states")
}
}
// TestGenerateState_DifferentData tests states with different data payloads
func TestGenerateState_DifferentData(t *testing.T) {
cfg := testConfig()
testCases := []string{
"simple",
"with-dashes",
"with_underscores",
"123456789",
"MixedCase123",
}
for _, data := range testCases {
t.Run(data, func(t *testing.T) {
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Unexpected error for data '%s': %v", data, err)
}
if !strings.HasPrefix(state, data+".") {
t.Errorf("Expected state to start with '%s.', got: %s", data, state)
}
if len(userAgentKey) != 32 {
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
}
})
}
}
// TestVerifyState_Success tests the happy path of state verification
func TestVerifyState_Success(t *testing.T) {
cfg := testConfig()
data := "test_data"
// Generate state
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify state
extractedData, err := VerifyState(cfg, state, userAgentKey)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if extractedData != data {
t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData)
}
}
// TestVerifyState_NilConfig tests that nil config returns error
func TestVerifyState_NilConfig(t *testing.T) {
_, err := VerifyState(nil, "state", []byte("key"))
if err == nil {
t.Fatal("Expected error for nil config, got nil")
}
if !strings.Contains(err.Error(), "cfg cannot be nil") {
t.Errorf("Expected error message about nil config, got: %v", err)
}
}
// TestVerifyState_EmptyPrivateKey tests that empty private key returns error
func TestVerifyState_EmptyPrivateKey(t *testing.T) {
cfg := &Config{PrivateKey: ""}
_, err := VerifyState(cfg, "state", []byte("key"))
if err == nil {
t.Fatal("Expected error for empty private key, got nil")
}
if !strings.Contains(err.Error(), "private key cannot be empty") {
t.Errorf("Expected error message about empty private key, got: %v", err)
}
}
// TestVerifyState_EmptyState tests that empty state returns error
func TestVerifyState_EmptyState(t *testing.T) {
cfg := testConfig()
_, err := VerifyState(cfg, "", []byte("key"))
if err == nil {
t.Fatal("Expected error for empty state, got nil")
}
if !strings.Contains(err.Error(), "state cannot be empty") {
t.Errorf("Expected error message about empty state, got: %v", err)
}
}
// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error
func TestVerifyState_EmptyUserAgentKey(t *testing.T) {
cfg := testConfig()
_, err := VerifyState(cfg, "data.random.signature", []byte{})
if err == nil {
t.Fatal("Expected error for empty userAgentKey, got nil")
}
if !strings.Contains(err.Error(), "userAgentKey cannot be empty") {
t.Errorf("Expected error message about empty userAgentKey, got: %v", err)
}
}
// TestVerifyState_WrongUserAgentKey tests MITM protection
func TestVerifyState_WrongUserAgentKey(t *testing.T) {
cfg := testConfig()
// Generate first state
state, _, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Generate a different userAgentKey
_, wrongKey, err := GenerateState(cfg, "other_data")
if err != nil {
t.Fatalf("Failed to generate second state: %v", err)
}
// Try to verify with wrong key
_, err = VerifyState(cfg, state, wrongKey)
if err == nil {
t.Error("Expected error for invalid signature")
}
if !strings.Contains(err.Error(), "invalid state signature") {
t.Errorf("Expected error about invalid signature, got: %v", err)
}
}
// TestVerifyState_TamperedData tests tampering detection
func TestVerifyState_TamperedData(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "original_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the data portion
parts := strings.Split(state, ".")
parts[0] = "tampered_data"
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered state")
}
}
// TestVerifyState_TamperedRandom tests tampering with random portion
func TestVerifyState_TamperedRandom(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the random portion
parts := strings.Split(state, ".")
parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12"))
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered state")
}
}
// TestVerifyState_TamperedSignature tests tampering with signature
func TestVerifyState_TamperedSignature(t *testing.T) {
cfg := testConfig()
// Generate state
state, userAgentKey, err := GenerateState(cfg, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Tamper with the signature portion
parts := strings.Split(state, ".")
// Create a different valid base64 string
parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered")))
tamperedState := strings.Join(parts, ".")
// Try to verify tampered state
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("Expected error for tampered signature")
}
}
// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts
func TestVerifyState_MalformedState_TwoParts(t *testing.T) {
cfg := testConfig()
malformedState := "data.random"
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for malformed state")
}
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
}
}
// TestVerifyState_MalformedState_FourParts tests state with 4 parts
func TestVerifyState_MalformedState_FourParts(t *testing.T) {
cfg := testConfig()
malformedState := "data.random.signature.extra"
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for malformed state")
}
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
}
}
// TestVerifyState_EmptyStateParts tests state with empty parts
func TestVerifyState_EmptyStateParts(t *testing.T) {
cfg := testConfig()
testCases := []struct {
name string
state string
}{
{"empty data", ".random.signature"},
{"empty random", "data..signature"},
{"empty signature", "data.random."},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for state with empty parts")
}
if !strings.Contains(err.Error(), "state parts cannot be empty") {
t.Errorf("Expected error about empty parts, got: %v", err)
}
})
}
}
// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature
func TestVerifyState_InvalidBase64Signature(t *testing.T) {
cfg := testConfig()
invalidState := "data.random.invalid@base64!"
_, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890"))
if err == nil {
t.Fatal("Expected error for invalid base64 signature")
}
if !strings.Contains(err.Error(), "failed to decode") {
t.Errorf("Expected error about decoding signature, got: %v", err)
}
}
// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification
func TestVerifyState_DifferentPrivateKey(t *testing.T) {
cfg1 := &Config{PrivateKey: "private_key_1"}
cfg2 := &Config{PrivateKey: "private_key_2"}
// Generate with first config
state, userAgentKey, err := GenerateState(cfg1, "test_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Try to verify with second config
_, err = VerifyState(cfg2, state, userAgentKey)
if err == nil {
t.Error("Expected error for mismatched private key")
}
}
// TestRoundTrip tests complete round trip with various data payloads
func TestRoundTrip(t *testing.T) {
cfg := testConfig()
testCases := []string{
"simple",
"with-dashes-and-numbers-123",
"MixedCaseData",
"user_token_abc123",
"link_resource_xyz789",
}
for _, data := range testCases {
t.Run(data, func(t *testing.T) {
// Generate
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify
extractedData, err := VerifyState(cfg, state, userAgentKey)
if err != nil {
t.Fatalf("Failed to verify state: %v", err)
}
if extractedData != data {
t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData)
}
})
}
}
// TestConcurrentGeneration tests that concurrent state generation works correctly
func TestConcurrentGeneration(t *testing.T) {
cfg := testConfig()
data := "concurrent_test"
const numGoroutines = 10
results := make(chan string, numGoroutines)
errors := make(chan error, numGoroutines)
// Generate states concurrently
for i := 0; i < numGoroutines; i++ {
go func() {
state, userAgentKey, err := GenerateState(cfg, data)
if err != nil {
errors <- err
return
}
// Verify immediately
_, verifyErr := VerifyState(cfg, state, userAgentKey)
if verifyErr != nil {
errors <- verifyErr
return
}
results <- state
}()
}
// Collect results
states := make(map[string]bool)
for i := 0; i < numGoroutines; i++ {
select {
case state := <-results:
if states[state] {
t.Errorf("Duplicate state generated: %s", state)
}
states[state] = true
case err := <-errors:
t.Errorf("Concurrent generation error: %v", err)
}
}
if len(states) != numGoroutines {
t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states))
}
}
// TestStateFormatCompatibility ensures state is URL-safe
func TestStateFormatCompatibility(t *testing.T) {
cfg := testConfig()
data := "url_safe_test"
state, _, err := GenerateState(cfg, data)
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Check that state doesn't contain characters that need URL encoding
unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"}
for _, char := range unsafeChars {
if strings.Contains(state, char) {
t.Errorf("State contains URL-unsafe character '%s': %s", char, state)
}
}
}
// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works
// An attacker obtains their own valid state but tries to use it with victim's session
func TestMITM_AttackerCannotSubstituteState(t *testing.T) {
cfg := testConfig()
// Victim generates a state for their login
victimState, victimKey, err := GenerateState(cfg, "victim_data")
if err != nil {
t.Fatalf("Failed to generate victim state: %v", err)
}
// Attacker generates their own valid state (they can request this from the server)
attackerState, attackerKey, err := GenerateState(cfg, "attacker_data")
if err != nil {
t.Fatalf("Failed to generate attacker state: %v", err)
}
// Both states should be valid on their own
_, err = VerifyState(cfg, victimState, victimKey)
if err != nil {
t.Fatalf("Victim state should be valid: err=%v", err)
}
_, err = VerifyState(cfg, attackerState, attackerKey)
if err != nil {
t.Fatalf("Attacker state should be valid: err=%v", err)
}
// MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie
// This should FAIL because attackerState was signed with attackerKey, not victimKey
_, err = VerifyState(cfg, attackerState, victimKey)
if err == nil {
t.Error("Expected error when attacker substitutes state")
}
// MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie
// This should also FAIL
_, err = VerifyState(cfg, victimState, attackerKey)
if err == nil {
t.Error("Expected error when attacker uses victim's state")
}
// The key insight: even though both states are "valid", they are bound to their respective cookies
// An attacker cannot mix and match states and cookies
t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies")
}
// TestCSRF_AttackerCannotForgeState verifies CSRF protection
// An attacker tries to forge a state parameter without knowing the private key
func TestCSRF_AttackerCannotForgeState(t *testing.T) {
cfg := testConfig()
// Attacker doesn't know the private key, but tries to forge a state
// They might try to construct: "malicious_data.random.signature"
// Attempt 1: Use a random signature
randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678"))
forgedState1 := "malicious_data.somefakerandom." + randomSig
// Generate a real userAgentKey (attacker might try to get this)
_, realKey, err := GenerateState(cfg, "legitimate_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Try to verify forged state
_, err = VerifyState(cfg, forgedState1, realKey)
if err == nil {
t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!")
}
// Attempt 2: Attacker tries to compute signature without private key
// They use: SHA256(data + "." + random + userAgentKey) - missing privateKey
attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey)
hash := sha256.Sum256([]byte(attackerPayload))
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
forgedState2 := "malicious_data.fakerandom." + attackerSig
_, err = VerifyState(cfg, forgedState2, realKey)
if err == nil {
t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!")
}
t.Log("✓ CSRF protection verified: Cannot forge state without private key")
}
// TestTampering_SignatureDetectsAllModifications verifies tamper detection
func TestTampering_SignatureDetectsAllModifications(t *testing.T) {
cfg := testConfig()
// Generate a valid state
originalState, userAgentKey, err := GenerateState(cfg, "original_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Verify original is valid
data, err := VerifyState(cfg, originalState, userAgentKey)
if err != nil || data != "original_data" {
t.Fatalf("Original state should be valid")
}
parts := strings.Split(originalState, ".")
// Test 1: Attacker modifies data but keeps signature
tamperedState := "modified_data." + parts[1] + "." + parts[2]
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Modified data not detected!")
}
// Test 2: Attacker modifies random but keeps signature
newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!"))
tamperedState = parts[0] + "." + newRandom + "." + parts[2]
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Modified random not detected!")
}
// Test 3: Attacker tries to recompute signature but doesn't have privateKey
// They compute: SHA256(modified_data + "." + random + userAgentKey)
attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey)
hash := sha256.Sum256([]byte(attackerPayload))
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
tamperedState = "modified_data." + parts[1] + "." + attackerSig
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!")
}
// Test 4: Single bit flip in signature
sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2])
sigBytes[0] ^= 0x01 // Flip one bit
flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes)
tamperedState = parts[0] + "." + parts[1] + "." + flippedSig
_, err = VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!")
}
t.Log("✓ Tamper detection verified: All modifications to state are detected")
}
// TestReplay_DifferentSessionsCannotReuseState verifies replay protection
func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) {
cfg := testConfig()
// Session 1: User initiates OAuth flow
state1, key1, err := GenerateState(cfg, "session1_data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// State is valid for session 1
_, err = VerifyState(cfg, state1, key1)
if err != nil {
t.Fatalf("State should be valid for session 1")
}
// Session 2: Same user (or attacker) initiates a new OAuth flow
state2, key2, err := GenerateState(cfg, "session1_data") // same data
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Replay Attack: Try to use state1 with key2
_, err = VerifyState(cfg, state1, key2)
if err == nil {
t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!")
}
// Even with same data, each session should have unique state+key binding
if state1 == state2 {
t.Error("REPLAY VULNERABILITY: Same data produces identical states!")
}
t.Log("✓ Replay protection verified: States are bound to specific session cookies")
}
// TestConstantTimeComparison verifies that signature comparison is timing-safe
// This is a behavioral test - we can't easily test timing, but we can verify the function is used
func TestConstantTimeComparison_IsUsed(t *testing.T) {
cfg := testConfig()
// Generate valid state
state, userAgentKey, err := GenerateState(cfg, "test")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Create states with signatures that differ at different positions
parts := strings.Split(state, ".")
originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2])
testCases := []struct {
name string
position int
}{
{"first byte differs", 0},
{"middle byte differs", 16},
{"last byte differs", 31},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create signature that differs at specific position
tamperedSig := make([]byte, len(originalSig))
copy(tamperedSig, originalSig)
tamperedSig[tc.position] ^= 0xFF // Flip all bits
tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig)
tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr
// All should fail verification
_, err := VerifyState(cfg, tamperedState, userAgentKey)
if err == nil {
t.Errorf("Tampered signature at position %d should be invalid", tc.position)
}
// If constant-time comparison is NOT used, early differences would return faster
// While we can't easily test timing here, we verify all positions fail equally
})
}
t.Log("✓ Constant-time comparison: All signature positions validated equally")
t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation")
}
// TestPrivateKey_IsCriticalToSecurity verifies private key is essential
func TestPrivateKey_IsCriticalToSecurity(t *testing.T) {
cfg1 := &Config{PrivateKey: "secret_key_1"}
cfg2 := &Config{PrivateKey: "secret_key_2"}
// Generate state with key1
state, userAgentKey, err := GenerateState(cfg1, "data")
if err != nil {
t.Fatalf("Failed to generate state: %v", err)
}
// Should verify with key1
_, err = VerifyState(cfg1, state, userAgentKey)
if err != nil {
t.Fatalf("State should be valid with correct private key")
}
// Should NOT verify with key2 (different private key)
_, err = VerifyState(cfg2, state, userAgentKey)
if err == nil {
t.Error("SECURITY VULNERABILITY: State verified with different private key!")
}
// This proves that the private key is cryptographically involved in the signature
t.Log("✓ Private key security verified: Different keys produce incompatible signatures")
}
// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload
func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) {
cfg := testConfig()
// Generate two states with same data but different userAgentKeys (implicit)
state1, key1, err := GenerateState(cfg, "same_data")
if err != nil {
t.Fatalf("Failed to generate state1: %v", err)
}
state2, key2, err := GenerateState(cfg, "same_data")
if err != nil {
t.Fatalf("Failed to generate state2: %v", err)
}
// The states should be different even with same data (different random and keys)
if state1 == state2 {
t.Error("States should differ due to different random values")
}
// Each state should only verify with its own key
_, err1 := VerifyState(cfg, state1, key1)
_, err2 := VerifyState(cfg, state2, key2)
if err1 != nil || err2 != nil {
t.Fatal("States should be valid with their own keys")
}
// Cross-verification should fail
_, err1 = VerifyState(cfg, state1, key2)
_, err2 = VerifyState(cfg, state2, key1)
if err1 == nil || err2 == nil {
t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!")
}
t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key")
}