package oauth import ( "crypto/sha256" "encoding/base64" "strings" "testing" ) // Helper function to create a test config func testConfig() *Config { return &Config{ PrivateKey: "test_private_key_for_testing_12345", } } // TestGenerateState_Success tests the happy path of state generation func TestGenerateState_Success(t *testing.T) { cfg := testConfig() data := "test_data_payload" state, userAgentKey, err := GenerateState(cfg, data) if err != nil { t.Fatalf("Expected no error, got: %v", err) } if state == "" { t.Error("Expected non-empty state") } if len(userAgentKey) != 32 { t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) } // Verify state format: data.random.signature parts := strings.Split(state, ".") if len(parts) != 3 { t.Errorf("Expected state to have 3 parts, got %d", len(parts)) } // Verify data is preserved if parts[0] != data { t.Errorf("Expected data to be '%s', got '%s'", data, parts[0]) } // Verify random part is base64 encoded randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { t.Errorf("Expected random part to be valid base64: %v", err) } if len(randomBytes) != 32 { t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes)) } // Verify signature part is base64 encoded sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { t.Errorf("Expected signature part to be valid base64: %v", err) } if len(sigBytes) != 32 { t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes)) } } // TestGenerateState_NilConfig tests that nil config returns error func TestGenerateState_NilConfig(t *testing.T) { _, _, err := GenerateState(nil, "test_data") if err == nil { t.Fatal("Expected error for nil config, got nil") } if !strings.Contains(err.Error(), "cfg cannot be nil") { t.Errorf("Expected error message about nil config, got: %v", err) } } // TestGenerateState_EmptyPrivateKey tests that empty private key returns error func TestGenerateState_EmptyPrivateKey(t *testing.T) { cfg := &Config{PrivateKey: ""} _, _, err := GenerateState(cfg, "test_data") if err == nil { t.Fatal("Expected error for empty private key, got nil") } if !strings.Contains(err.Error(), "private key cannot be empty") { t.Errorf("Expected error message about empty private key, got: %v", err) } } // TestGenerateState_EmptyData tests that empty data returns error func TestGenerateState_EmptyData(t *testing.T) { cfg := testConfig() _, _, err := GenerateState(cfg, "") if err == nil { t.Fatal("Expected error for empty data, got nil") } if !strings.Contains(err.Error(), "data cannot be empty") { t.Errorf("Expected error message about empty data, got: %v", err) } } // TestGenerateState_Randomness tests that multiple calls generate different states func TestGenerateState_Randomness(t *testing.T) { cfg := testConfig() data := "same_data" state1, _, err1 := GenerateState(cfg, data) state2, _, err2 := GenerateState(cfg, data) if err1 != nil || err2 != nil { t.Fatalf("Unexpected errors: %v, %v", err1, err2) } if state1 == state2 { t.Error("Expected different states for multiple calls, got identical states") } } // TestGenerateState_DifferentData tests states with different data payloads func TestGenerateState_DifferentData(t *testing.T) { cfg := testConfig() testCases := []string{ "simple", "with-dashes", "with_underscores", "123456789", "MixedCase123", } for _, data := range testCases { t.Run(data, func(t *testing.T) { state, userAgentKey, err := GenerateState(cfg, data) if err != nil { t.Fatalf("Unexpected error for data '%s': %v", data, err) } if !strings.HasPrefix(state, data+".") { t.Errorf("Expected state to start with '%s.', got: %s", data, state) } if len(userAgentKey) != 32 { t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey)) } }) } } // TestVerifyState_Success tests the happy path of state verification func TestVerifyState_Success(t *testing.T) { cfg := testConfig() data := "test_data" // Generate state state, userAgentKey, err := GenerateState(cfg, data) if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Verify state extractedData, err := VerifyState(cfg, state, userAgentKey) if err != nil { t.Fatalf("Expected no error, got: %v", err) } if extractedData != data { t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData) } } // TestVerifyState_NilConfig tests that nil config returns error func TestVerifyState_NilConfig(t *testing.T) { _, err := VerifyState(nil, "state", []byte("key")) if err == nil { t.Fatal("Expected error for nil config, got nil") } if !strings.Contains(err.Error(), "cfg cannot be nil") { t.Errorf("Expected error message about nil config, got: %v", err) } } // TestVerifyState_EmptyPrivateKey tests that empty private key returns error func TestVerifyState_EmptyPrivateKey(t *testing.T) { cfg := &Config{PrivateKey: ""} _, err := VerifyState(cfg, "state", []byte("key")) if err == nil { t.Fatal("Expected error for empty private key, got nil") } if !strings.Contains(err.Error(), "private key cannot be empty") { t.Errorf("Expected error message about empty private key, got: %v", err) } } // TestVerifyState_EmptyState tests that empty state returns error func TestVerifyState_EmptyState(t *testing.T) { cfg := testConfig() _, err := VerifyState(cfg, "", []byte("key")) if err == nil { t.Fatal("Expected error for empty state, got nil") } if !strings.Contains(err.Error(), "state cannot be empty") { t.Errorf("Expected error message about empty state, got: %v", err) } } // TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error func TestVerifyState_EmptyUserAgentKey(t *testing.T) { cfg := testConfig() _, err := VerifyState(cfg, "data.random.signature", []byte{}) if err == nil { t.Fatal("Expected error for empty userAgentKey, got nil") } if !strings.Contains(err.Error(), "userAgentKey cannot be empty") { t.Errorf("Expected error message about empty userAgentKey, got: %v", err) } } // TestVerifyState_WrongUserAgentKey tests MITM protection func TestVerifyState_WrongUserAgentKey(t *testing.T) { cfg := testConfig() // Generate first state state, _, err := GenerateState(cfg, "test_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Generate a different userAgentKey _, wrongKey, err := GenerateState(cfg, "other_data") if err != nil { t.Fatalf("Failed to generate second state: %v", err) } // Try to verify with wrong key _, err = VerifyState(cfg, state, wrongKey) if err == nil { t.Error("Expected error for invalid signature") } if !strings.Contains(err.Error(), "invalid state signature") { t.Errorf("Expected error about invalid signature, got: %v", err) } } // TestVerifyState_TamperedData tests tampering detection func TestVerifyState_TamperedData(t *testing.T) { cfg := testConfig() // Generate state state, userAgentKey, err := GenerateState(cfg, "original_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Tamper with the data portion parts := strings.Split(state, ".") parts[0] = "tampered_data" tamperedState := strings.Join(parts, ".") // Try to verify tampered state _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("Expected error for tampered state") } } // TestVerifyState_TamperedRandom tests tampering with random portion func TestVerifyState_TamperedRandom(t *testing.T) { cfg := testConfig() // Generate state state, userAgentKey, err := GenerateState(cfg, "test_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Tamper with the random portion parts := strings.Split(state, ".") parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12")) tamperedState := strings.Join(parts, ".") // Try to verify tampered state _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("Expected error for tampered state") } } // TestVerifyState_TamperedSignature tests tampering with signature func TestVerifyState_TamperedSignature(t *testing.T) { cfg := testConfig() // Generate state state, userAgentKey, err := GenerateState(cfg, "test_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Tamper with the signature portion parts := strings.Split(state, ".") // Create a different valid base64 string parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered"))) tamperedState := strings.Join(parts, ".") // Try to verify tampered state _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("Expected error for tampered signature") } } // TestVerifyState_MalformedState_TwoParts tests state with only 2 parts func TestVerifyState_MalformedState_TwoParts(t *testing.T) { cfg := testConfig() malformedState := "data.random" _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) if err == nil { t.Fatal("Expected error for malformed state") } if !strings.Contains(err.Error(), "must have exactly 3 parts") { t.Errorf("Expected error about incorrect number of parts, got: %v", err) } } // TestVerifyState_MalformedState_FourParts tests state with 4 parts func TestVerifyState_MalformedState_FourParts(t *testing.T) { cfg := testConfig() malformedState := "data.random.signature.extra" _, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890")) if err == nil { t.Fatal("Expected error for malformed state") } if !strings.Contains(err.Error(), "must have exactly 3 parts") { t.Errorf("Expected error about incorrect number of parts, got: %v", err) } } // TestVerifyState_EmptyStateParts tests state with empty parts func TestVerifyState_EmptyStateParts(t *testing.T) { cfg := testConfig() testCases := []struct { name string state string }{ {"empty data", ".random.signature"}, {"empty random", "data..signature"}, {"empty signature", "data.random."}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890")) if err == nil { t.Fatal("Expected error for state with empty parts") } if !strings.Contains(err.Error(), "state parts cannot be empty") { t.Errorf("Expected error about empty parts, got: %v", err) } }) } } // TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature func TestVerifyState_InvalidBase64Signature(t *testing.T) { cfg := testConfig() invalidState := "data.random.invalid@base64!" _, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890")) if err == nil { t.Fatal("Expected error for invalid base64 signature") } if !strings.Contains(err.Error(), "failed to decode") { t.Errorf("Expected error about decoding signature, got: %v", err) } } // TestVerifyState_DifferentPrivateKey tests that different private keys fail verification func TestVerifyState_DifferentPrivateKey(t *testing.T) { cfg1 := &Config{PrivateKey: "private_key_1"} cfg2 := &Config{PrivateKey: "private_key_2"} // Generate with first config state, userAgentKey, err := GenerateState(cfg1, "test_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Try to verify with second config _, err = VerifyState(cfg2, state, userAgentKey) if err == nil { t.Error("Expected error for mismatched private key") } } // TestRoundTrip tests complete round trip with various data payloads func TestRoundTrip(t *testing.T) { cfg := testConfig() testCases := []string{ "simple", "with-dashes-and-numbers-123", "MixedCaseData", "user_token_abc123", "link_resource_xyz789", } for _, data := range testCases { t.Run(data, func(t *testing.T) { // Generate state, userAgentKey, err := GenerateState(cfg, data) if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Verify extractedData, err := VerifyState(cfg, state, userAgentKey) if err != nil { t.Fatalf("Failed to verify state: %v", err) } if extractedData != data { t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData) } }) } } // TestConcurrentGeneration tests that concurrent state generation works correctly func TestConcurrentGeneration(t *testing.T) { cfg := testConfig() data := "concurrent_test" const numGoroutines = 10 results := make(chan string, numGoroutines) errors := make(chan error, numGoroutines) // Generate states concurrently for range numGoroutines { go func() { state, userAgentKey, err := GenerateState(cfg, data) if err != nil { errors <- err return } // Verify immediately _, verifyErr := VerifyState(cfg, state, userAgentKey) if verifyErr != nil { errors <- verifyErr return } results <- state }() } // Collect results states := make(map[string]bool) for range numGoroutines { select { case state := <-results: if states[state] { t.Errorf("Duplicate state generated: %s", state) } states[state] = true case err := <-errors: t.Errorf("Concurrent generation error: %v", err) } } if len(states) != numGoroutines { t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states)) } } // TestStateFormatCompatibility ensures state is URL-safe func TestStateFormatCompatibility(t *testing.T) { cfg := testConfig() data := "url_safe_test" state, _, err := GenerateState(cfg, data) if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Check that state doesn't contain characters that need URL encoding unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"} for _, char := range unsafeChars { if strings.Contains(state, char) { t.Errorf("State contains URL-unsafe character '%s': %s", char, state) } } } // TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works // An attacker obtains their own valid state but tries to use it with victim's session func TestMITM_AttackerCannotSubstituteState(t *testing.T) { cfg := testConfig() // Victim generates a state for their login victimState, victimKey, err := GenerateState(cfg, "victim_data") if err != nil { t.Fatalf("Failed to generate victim state: %v", err) } // Attacker generates their own valid state (they can request this from the server) attackerState, attackerKey, err := GenerateState(cfg, "attacker_data") if err != nil { t.Fatalf("Failed to generate attacker state: %v", err) } // Both states should be valid on their own _, err = VerifyState(cfg, victimState, victimKey) if err != nil { t.Fatalf("Victim state should be valid: err=%v", err) } _, err = VerifyState(cfg, attackerState, attackerKey) if err != nil { t.Fatalf("Attacker state should be valid: err=%v", err) } // MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie // This should FAIL because attackerState was signed with attackerKey, not victimKey _, err = VerifyState(cfg, attackerState, victimKey) if err == nil { t.Error("Expected error when attacker substitutes state") } // MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie // This should also FAIL _, err = VerifyState(cfg, victimState, attackerKey) if err == nil { t.Error("Expected error when attacker uses victim's state") } // The key insight: even though both states are "valid", they are bound to their respective cookies // An attacker cannot mix and match states and cookies t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies") } // TestCSRF_AttackerCannotForgeState verifies CSRF protection // An attacker tries to forge a state parameter without knowing the private key func TestCSRF_AttackerCannotForgeState(t *testing.T) { cfg := testConfig() // Attacker doesn't know the private key, but tries to forge a state // They might try to construct: "malicious_data.random.signature" // Attempt 1: Use a random signature randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678")) forgedState1 := "malicious_data.somefakerandom." + randomSig // Generate a real userAgentKey (attacker might try to get this) _, realKey, err := GenerateState(cfg, "legitimate_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Try to verify forged state _, err = VerifyState(cfg, forgedState1, realKey) if err == nil { t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!") } // Attempt 2: Attacker tries to compute signature without private key // They use: SHA256(data + "." + random + userAgentKey) - missing privateKey attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey) hash := sha256.Sum256([]byte(attackerPayload)) attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) forgedState2 := "malicious_data.fakerandom." + attackerSig _, err = VerifyState(cfg, forgedState2, realKey) if err == nil { t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!") } t.Log("✓ CSRF protection verified: Cannot forge state without private key") } // TestTampering_SignatureDetectsAllModifications verifies tamper detection func TestTampering_SignatureDetectsAllModifications(t *testing.T) { cfg := testConfig() // Generate a valid state originalState, userAgentKey, err := GenerateState(cfg, "original_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Verify original is valid data, err := VerifyState(cfg, originalState, userAgentKey) if err != nil || data != "original_data" { t.Fatalf("Original state should be valid") } parts := strings.Split(originalState, ".") // Test 1: Attacker modifies data but keeps signature tamperedState := "modified_data." + parts[1] + "." + parts[2] _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("TAMPER VULNERABILITY: Modified data not detected!") } // Test 2: Attacker modifies random but keeps signature newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!")) tamperedState = parts[0] + "." + newRandom + "." + parts[2] _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("TAMPER VULNERABILITY: Modified random not detected!") } // Test 3: Attacker tries to recompute signature but doesn't have privateKey // They compute: SHA256(modified_data + "." + random + userAgentKey) attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey) hash := sha256.Sum256([]byte(attackerPayload)) attackerSig := base64.RawURLEncoding.EncodeToString(hash[:]) tamperedState = "modified_data." + parts[1] + "." + attackerSig _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!") } // Test 4: Single bit flip in signature sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2]) sigBytes[0] ^= 0x01 // Flip one bit flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes) tamperedState = parts[0] + "." + parts[1] + "." + flippedSig _, err = VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!") } t.Log("✓ Tamper detection verified: All modifications to state are detected") } // TestReplay_DifferentSessionsCannotReuseState verifies replay protection func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) { cfg := testConfig() // Session 1: User initiates OAuth flow state1, key1, err := GenerateState(cfg, "session1_data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // State is valid for session 1 _, err = VerifyState(cfg, state1, key1) if err != nil { t.Fatalf("State should be valid for session 1") } // Session 2: Same user (or attacker) initiates a new OAuth flow state2, key2, err := GenerateState(cfg, "session1_data") // same data if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Replay Attack: Try to use state1 with key2 _, err = VerifyState(cfg, state1, key2) if err == nil { t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!") } // Even with same data, each session should have unique state+key binding if state1 == state2 { t.Error("REPLAY VULNERABILITY: Same data produces identical states!") } t.Log("✓ Replay protection verified: States are bound to specific session cookies") } // TestConstantTimeComparison verifies that signature comparison is timing-safe // This is a behavioral test - we can't easily test timing, but we can verify the function is used func TestConstantTimeComparison_IsUsed(t *testing.T) { cfg := testConfig() // Generate valid state state, userAgentKey, err := GenerateState(cfg, "test") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Create states with signatures that differ at different positions parts := strings.Split(state, ".") originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2]) testCases := []struct { name string position int }{ {"first byte differs", 0}, {"middle byte differs", 16}, {"last byte differs", 31}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create signature that differs at specific position tamperedSig := make([]byte, len(originalSig)) copy(tamperedSig, originalSig) tamperedSig[tc.position] ^= 0xFF // Flip all bits tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig) tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr // All should fail verification _, err := VerifyState(cfg, tamperedState, userAgentKey) if err == nil { t.Errorf("Tampered signature at position %d should be invalid", tc.position) } // If constant-time comparison is NOT used, early differences would return faster // While we can't easily test timing here, we verify all positions fail equally }) } t.Log("✓ Constant-time comparison: All signature positions validated equally") t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation") } // TestPrivateKey_IsCriticalToSecurity verifies private key is essential func TestPrivateKey_IsCriticalToSecurity(t *testing.T) { cfg1 := &Config{PrivateKey: "secret_key_1"} cfg2 := &Config{PrivateKey: "secret_key_2"} // Generate state with key1 state, userAgentKey, err := GenerateState(cfg1, "data") if err != nil { t.Fatalf("Failed to generate state: %v", err) } // Should verify with key1 _, err = VerifyState(cfg1, state, userAgentKey) if err != nil { t.Fatalf("State should be valid with correct private key") } // Should NOT verify with key2 (different private key) _, err = VerifyState(cfg2, state, userAgentKey) if err == nil { t.Error("SECURITY VULNERABILITY: State verified with different private key!") } // This proves that the private key is cryptographically involved in the signature t.Log("✓ Private key security verified: Different keys produce incompatible signatures") } // TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) { cfg := testConfig() // Generate two states with same data but different userAgentKeys (implicit) state1, key1, err := GenerateState(cfg, "same_data") if err != nil { t.Fatalf("Failed to generate state1: %v", err) } state2, key2, err := GenerateState(cfg, "same_data") if err != nil { t.Fatalf("Failed to generate state2: %v", err) } // The states should be different even with same data (different random and keys) if state1 == state2 { t.Error("States should differ due to different random values") } // Each state should only verify with its own key _, err1 := VerifyState(cfg, state1, key1) _, err2 := VerifyState(cfg, state2, key2) if err1 != nil || err2 != nil { t.Fatal("States should be valid with their own keys") } // Cross-verification should fail _, err1 = VerifyState(cfg, state1, key2) _, err2 = VerifyState(cfg, state2, key1) if err1 == nil || err2 == nil { t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!") } t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key") }