Files
oslstats/pkg/oauth/state.go

118 lines
3.5 KiB
Go

package oauth
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"slices"
"strings"
"github.com/pkg/errors"
)
// GenerateState creates a signed OAuth state string that carries data through
// the authorization round trip, together with a freshly generated userAgentKey.
//
// STATE FLOW:
//   - data: provided at call time, to be retrieved later by VerifyState.
//   - random: 32 random bytes generated on the spot.
//   - userAgentKey: 32 random bytes used as a nonce to prevent MITM; it is
//     intended to be stored on the client as a lax cookie and passed back to
//     VerifyState.
//   - privateKey: cfg.PrivateKey, from configuration.
//
// Construction (random and userAgentKey are unpadded base64url-encoded, and
// userAgentKey is concatenated directly with privateKey, with no separator):
//
//	signature = BASE64URL(SHA256(data + "." + random + userAgentKey + privateKey))
//	state     = data + "." + random + "." + signature
//
// An error is returned in any of these cases:
//   - cfg is nil;
//   - cfg.PrivateKey is empty;
//   - data is empty or contains '.';
//   - the random source fails.
//
// Data containing '.' is rejected because VerifyState splits the state on '.'
// and would never accept the resulting state.
func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) {
	if cfg == nil {
		return "", nil, errors.New("cfg cannot be nil")
	}
	if cfg.PrivateKey == "" {
		return "", nil, errors.New("private key cannot be empty")
	}
	if data == "" {
		return "", nil, errors.New("data cannot be empty")
	}
	if strings.Contains(data, ".") {
		return "", nil, errors.New("data cannot contain '.'")
	}

	randomBytes := make([]byte, 32)
	if _, err = rand.Read(randomBytes); err != nil {
		return "", nil, errors.Wrap(err, "failed to generate random bytes")
	}
	userAgentKey = make([]byte, 32)
	if _, err = rand.Read(userAgentKey); err != nil {
		return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes")
	}

	randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes)
	userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)

	// VerifyState must rebuild exactly this byte sequence, including the
	// missing separator between userAgentKey and privateKey.
	payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey
	hash := sha256.Sum256([]byte(payload))
	signature := base64.RawURLEncoding.EncodeToString(hash[:])

	state = data + "." + randomEncoded + "." + signature
	return state, userAgentKey, nil
}
// VerifyState validates a state string produced by GenerateState against the
// userAgentKey that was returned alongside it. On success it returns the data
// that was embedded in the state.
//
// The state is expected to be data + "." + random + "." + signature, where
// signature is the unpadded base64url encoding of
// SHA256(data + "." + random + base64url(userAgentKey) + cfg.PrivateKey).
//
// It returns an error, with empty data, in any of these cases:
//   - cfg is nil;
//   - cfg.PrivateKey, state or userAgentKey is empty;
//   - state is not exactly three non-empty dot-separated parts;
//   - the signature part is not valid unpadded base64url;
//   - the signature does not match.
func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) {
	switch {
	case cfg == nil:
		return "", errors.New("cfg cannot be nil")
	case cfg.PrivateKey == "":
		return "", errors.New("private key cannot be empty")
	case state == "":
		return "", errors.New("state cannot be empty")
	case len(userAgentKey) == 0:
		return "", errors.New("userAgentKey cannot be empty")
	}

	fields := strings.Split(state, ".")
	if count := len(fields); count != 3 {
		return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", count)
	}
	if slices.Contains(fields, "") {
		return "", errors.New("state parts cannot be empty")
	}
	claimedData, nonce, encodedSig := fields[0], fields[1], fields[2]

	// A malformed signature encoding is reported separately from a mismatch.
	presented, decodeErr := base64.RawURLEncoding.DecodeString(encodedSig)
	if decodeErr != nil {
		return "", errors.Wrap(decodeErr, "failed to decode received signature")
	}

	// Rebuild the exact bytes the signer hashed. The signer hashed the
	// base64url form of userAgentKey, so it is re-encoded here. That form is
	// followed directly by the private key, with no separator.
	expected := sha256.Sum256([]byte(
		claimedData + "." + nonce +
			base64.RawURLEncoding.EncodeToString(userAgentKey) +
			cfg.PrivateKey,
	))

	// Compare raw digest bytes in constant time. This avoids re-encoding the
	// digest and avoids timing differences that depend on the bytes compared.
	if subtle.ConstantTimeCompare(expected[:], presented) != 1 {
		return "", errors.New("invalid state signature")
	}
	return claimedData, nil
}