added oauth flow to get authorization code
This commit is contained in:
@@ -53,7 +53,7 @@ func setupHttpServer(
|
|||||||
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
|
err = addRoutes(httpServer, &fs, config, bun, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "addRoutes")
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/handlers"
|
"git.haelnorr.com/h/oslstats/internal/handlers"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
@@ -17,8 +16,7 @@ import (
|
|||||||
func addRoutes(
|
func addRoutes(
|
||||||
server *hws.Server,
|
server *hws.Server,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
config *config.Config,
|
cfg *config.Config,
|
||||||
logger *hlog.Logger,
|
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
) error {
|
) error {
|
||||||
@@ -34,6 +32,16 @@ func addRoutes(
|
|||||||
Method: hws.MethodGET,
|
Method: hws.MethodGET,
|
||||||
Handler: handlers.Index(server),
|
Handler: handlers.Index(server),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Path: "/login",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handlers.Login(server, cfg)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/auth/callback",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handlers.Callback(server, cfg)),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the routes with the server
|
// Register the routes with the server
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/golib/hwsauth"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@@ -15,6 +17,8 @@ type Config struct {
|
|||||||
HWS *hws.Config
|
HWS *hws.Config
|
||||||
HWSAuth *hwsauth.Config
|
HWSAuth *hwsauth.Config
|
||||||
HLOG *hlog.Config
|
HLOG *hlog.Config
|
||||||
|
Discord *discord.Config
|
||||||
|
OAuth *oauth.Config
|
||||||
Flags *Flags
|
Flags *Flags
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,6 +36,8 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
|
|||||||
hws.NewEZConfIntegration(),
|
hws.NewEZConfIntegration(),
|
||||||
hwsauth.NewEZConfIntegration(),
|
hwsauth.NewEZConfIntegration(),
|
||||||
db.NewEZConfIntegration(),
|
db.NewEZConfIntegration(),
|
||||||
|
discord.NewEZConfIntegration(),
|
||||||
|
oauth.NewEZConfIntegration(),
|
||||||
)
|
)
|
||||||
if err := loader.ParseEnvVars(); err != nil {
|
if err := loader.ParseEnvVars(); err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
|
return nil, nil, errors.Wrap(err, "loader.ParseEnvVars")
|
||||||
@@ -65,11 +71,23 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
|
|||||||
return nil, nil, errors.New("DB Config not loaded")
|
return nil, nil, errors.New("DB Config not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
discordcfg, ok := loader.GetConfig("discord")
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("Dicord Config not loaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthcfg, ok := loader.GetConfig("oauth")
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("OAuth Config not loaded")
|
||||||
|
}
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
DB: dbcfg.(*db.Config),
|
DB: dbcfg.(*db.Config),
|
||||||
HWS: hwscfg.(*hws.Config),
|
HWS: hwscfg.(*hws.Config),
|
||||||
HWSAuth: hwsauthcfg.(*hwsauth.Config),
|
HWSAuth: hwsauthcfg.(*hwsauth.Config),
|
||||||
HLOG: hlogcfg.(*hlog.Config),
|
HLOG: hlogcfg.(*hlog.Config),
|
||||||
|
Discord: discordcfg.(*discord.Config),
|
||||||
|
OAuth: oauthcfg.(*oauth.Config),
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,5 +37,5 @@ func (e EZConfIntegration) GroupName() string {
|
|||||||
|
|
||||||
// NewEZConfIntegration creates a new EZConf integration helper
|
// NewEZConfIntegration creates a new EZConf integration helper
|
||||||
func NewEZConfIntegration() EZConfIntegration {
|
func NewEZConfIntegration() EZConfIntegration {
|
||||||
return EZConfIntegration{name: "db", configFunc: ConfigFromEnv}
|
return EZConfIntegration{name: "DB", configFunc: ConfigFromEnv}
|
||||||
}
|
}
|
||||||
|
|||||||
50
internal/discord/config.go
Normal file
50
internal/discord/config.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package discord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
ClientID string // ENV DISCORD_CLIENT_ID: Discord application client ID (required)
|
||||||
|
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
|
||||||
|
OAuthScopes string // Authorisation scopes for OAuth
|
||||||
|
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigFromEnv() (any, error) {
|
||||||
|
cfg := &Config{
|
||||||
|
ClientID: env.String("DISCORD_CLIENT_ID", ""),
|
||||||
|
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
|
||||||
|
OAuthScopes: getOAuthScopes(),
|
||||||
|
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check required fields
|
||||||
|
if cfg.ClientID == "" {
|
||||||
|
return nil, errors.New("Envar not set: DISCORD_CLIENT_ID")
|
||||||
|
}
|
||||||
|
if cfg.ClientSecret == "" {
|
||||||
|
return nil, errors.New("Envar not set: DISCORD_CLIENT_SECRET")
|
||||||
|
}
|
||||||
|
if cfg.RedirectPath == "" {
|
||||||
|
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOAuthScopes() string {
|
||||||
|
list := []string{
|
||||||
|
"connections",
|
||||||
|
"email",
|
||||||
|
"guilds",
|
||||||
|
"gdm.join",
|
||||||
|
"guilds.members.read",
|
||||||
|
"identify",
|
||||||
|
}
|
||||||
|
scopes := strings.Join(list, "+")
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
41
internal/discord/ezconf.go
Normal file
41
internal/discord/ezconf.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package discord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||||
|
type EZConfIntegration struct {
|
||||||
|
configFunc func() (any, error)
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackagePath returns the path to the config package for source parsing
|
||||||
|
func (e EZConfIntegration) PackagePath() string {
|
||||||
|
_, filename, _, _ := runtime.Caller(0)
|
||||||
|
// Return directory of this file
|
||||||
|
return filename[:len(filename)-len("/ezconf.go")]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||||
|
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||||
|
return func() (any, error) {
|
||||||
|
return e.configFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name to use when registering with ezconf
|
||||||
|
func (e EZConfIntegration) Name() string {
|
||||||
|
return strings.ToLower(e.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupName returns the display name for grouping environment variables
|
||||||
|
func (e EZConfIntegration) GroupName() string {
|
||||||
|
return e.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEZConfIntegration creates a new EZConf integration helper
|
||||||
|
func NewEZConfIntegration() EZConfIntegration {
|
||||||
|
return EZConfIntegration{name: "Discord", configFunc: ConfigFromEnv}
|
||||||
|
}
|
||||||
39
internal/discord/oauth.go
Normal file
39
internal/discord/oauth.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package discord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
AccessToken string
|
||||||
|
TokenType string
|
||||||
|
ExpiresIn int
|
||||||
|
RefreshToken string
|
||||||
|
Scope string
|
||||||
|
}
|
||||||
|
|
||||||
|
const oauthurl string = "https://discord.com/oauth2/authorize"
|
||||||
|
|
||||||
|
func GetOAuthLink(cfg *Config, state string, trustedHost string) (string, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return "", errors.New("cfg cannot be nil")
|
||||||
|
}
|
||||||
|
if state == "" {
|
||||||
|
return "", errors.New("state cannot be empty")
|
||||||
|
}
|
||||||
|
if trustedHost == "" {
|
||||||
|
return "", errors.New("trustedHost cannot be empty")
|
||||||
|
}
|
||||||
|
values := url.Values{}
|
||||||
|
values.Add("response_type", "code")
|
||||||
|
values.Add("client_id", cfg.ClientID)
|
||||||
|
values.Add("scope", cfg.OAuthScopes)
|
||||||
|
values.Add("state", state)
|
||||||
|
values.Add("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath))
|
||||||
|
values.Add("prompt", "none")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil
|
||||||
|
}
|
||||||
61
internal/handlers/callback.go
Normal file
61
internal/handlers/callback.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Callback(server *hws.Server, cfg *config.Config) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if state == "" && code == "" {
|
||||||
|
http.Redirect(w, r, "/", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := verifyState(cfg.OAuth, w, r, state)
|
||||||
|
if err != nil {
|
||||||
|
err = server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Message: "OAuth state verification failed",
|
||||||
|
Error: err,
|
||||||
|
Level: hws.ErrorLevel("debug"),
|
||||||
|
RenderErrorPage: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch data {
|
||||||
|
case "login":
|
||||||
|
w.Write([]byte(code))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyState(cfg *oauth.Config, w http.ResponseWriter, r *http.Request, state string) (string, error) {
|
||||||
|
if r == nil {
|
||||||
|
return "", errors.New("request cannot be nil")
|
||||||
|
}
|
||||||
|
if state == "" {
|
||||||
|
return "", errors.New("state param field is empty")
|
||||||
|
}
|
||||||
|
uak, err := oauth.GetStateCookie(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "oauth.GetStateCookie")
|
||||||
|
}
|
||||||
|
data, err := oauth.VerifyState(cfg, state, uak)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "oauth.VerifyState")
|
||||||
|
}
|
||||||
|
oauth.DeleteStateCookie(w)
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
48
internal/handlers/login.go
Normal file
48
internal/handlers/login.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Login(server *hws.Server, cfg *config.Config) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||||
|
if err != nil {
|
||||||
|
err = server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Failed to generate state token",
|
||||||
|
Error: err,
|
||||||
|
Level: hws.ErrorLevel("error"),
|
||||||
|
RenderErrorPage: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||||
|
|
||||||
|
link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost)
|
||||||
|
if err != nil {
|
||||||
|
err = server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "An error occured trying to generate the login link",
|
||||||
|
Error: err,
|
||||||
|
Level: hws.ErrorLevel("error"),
|
||||||
|
RenderErrorPage: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, link, http.StatusSeeOther)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -38,33 +38,6 @@
|
|||||||
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
--default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
||||||
--default-font-family: var(--font-sans);
|
--default-font-family: var(--font-sans);
|
||||||
--default-mono-font-family: var(--font-mono);
|
--default-mono-font-family: var(--font-mono);
|
||||||
--color-rosewater: var(--rosewater);
|
|
||||||
--color-flamingo: var(--flamingo);
|
|
||||||
--color-pink: var(--pink);
|
|
||||||
--color-mauve: var(--mauve);
|
|
||||||
--color-red: var(--red);
|
|
||||||
--color-dark-red: var(--dark-red);
|
|
||||||
--color-maroon: var(--maroon);
|
|
||||||
--color-peach: var(--peach);
|
|
||||||
--color-yellow: var(--yellow);
|
|
||||||
--color-green: var(--green);
|
|
||||||
--color-teal: var(--teal);
|
|
||||||
--color-sky: var(--sky);
|
|
||||||
--color-sapphire: var(--sapphire);
|
|
||||||
--color-blue: var(--blue);
|
|
||||||
--color-lavender: var(--lavender);
|
|
||||||
--color-text: var(--text);
|
|
||||||
--color-subtext1: var(--subtext1);
|
|
||||||
--color-subtext0: var(--subtext0);
|
|
||||||
--color-overlay2: var(--overlay2);
|
|
||||||
--color-overlay1: var(--overlay1);
|
|
||||||
--color-overlay0: var(--overlay0);
|
|
||||||
--color-surface2: var(--surface2);
|
|
||||||
--color-surface1: var(--surface1);
|
|
||||||
--color-surface0: var(--surface0);
|
|
||||||
--color-base: var(--base);
|
|
||||||
--color-mantle: var(--mantle);
|
|
||||||
--color-crust: var(--crust);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@layer base {
|
@layer base {
|
||||||
|
|||||||
23
pkg/oauth/config.go
Normal file
23
pkg/oauth/config.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
PrivateKey string // ENV OAUTH_PRIVATE_KEY: Private key for signing OAuth state tokens (required)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigFromEnv() (any, error) {
|
||||||
|
cfg := &Config{
|
||||||
|
PrivateKey: env.String("OAUTH_PRIVATE_KEY", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check required fields
|
||||||
|
if cfg.PrivateKey == "" {
|
||||||
|
return nil, errors.New("Envar not set: OAUTH_PRIVATE_KEY")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
45
pkg/oauth/cookies.go
Normal file
45
pkg/oauth/cookies.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetStateCookie(w http.ResponseWriter, uak []byte, ssl bool) {
|
||||||
|
encodedUak := base64.RawURLEncoding.EncodeToString(uak)
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "oauth_uak",
|
||||||
|
Value: encodedUak,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 300,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: ssl,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetStateCookie(r *http.Request) ([]byte, error) {
|
||||||
|
if r == nil {
|
||||||
|
return nil, errors.New("Request cannot be nil")
|
||||||
|
}
|
||||||
|
cookie, err := r.Cookie("oauth_uak")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
uak, err := base64.RawURLEncoding.DecodeString(cookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to decode userAgentKey from cookie")
|
||||||
|
}
|
||||||
|
return uak, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteStateCookie(w http.ResponseWriter) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "oauth_uak",
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
})
|
||||||
|
}
|
||||||
41
pkg/oauth/ezconf.go
Normal file
41
pkg/oauth/ezconf.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||||
|
type EZConfIntegration struct {
|
||||||
|
configFunc func() (any, error)
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackagePath returns the path to the config package for source parsing
|
||||||
|
func (e EZConfIntegration) PackagePath() string {
|
||||||
|
_, filename, _, _ := runtime.Caller(0)
|
||||||
|
// Return directory of this file
|
||||||
|
return filename[:len(filename)-len("/ezconf.go")]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||||
|
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||||
|
return func() (any, error) {
|
||||||
|
return e.configFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name to use when registering with ezconf
|
||||||
|
func (e EZConfIntegration) Name() string {
|
||||||
|
return strings.ToLower(e.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupName returns the display name for grouping environment variables
|
||||||
|
func (e EZConfIntegration) GroupName() string {
|
||||||
|
return e.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEZConfIntegration creates a new EZConf integration helper
|
||||||
|
func NewEZConfIntegration() EZConfIntegration {
|
||||||
|
return EZConfIntegration{name: "OAuth", configFunc: ConfigFromEnv}
|
||||||
|
}
|
||||||
117
pkg/oauth/state.go
Normal file
117
pkg/oauth/state.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// STATE FLOW:
|
||||||
|
// data provided at call time to be retrieved later
|
||||||
|
// random value generated on the spot
|
||||||
|
// userAgentKey - nonce used to prevent MITM, stored as lax cookie on client
|
||||||
|
// privateKey - from config
|
||||||
|
|
||||||
|
func GenerateState(cfg *Config, data string) (state string, userAgentKey []byte, err error) {
|
||||||
|
// signature = BASE64_SHA256(data + "." + random + userAgentKey + privateKey)
|
||||||
|
// state = data + "." + random + "." + signature
|
||||||
|
if cfg == nil {
|
||||||
|
return "", nil, errors.New("cfg cannot be nil")
|
||||||
|
}
|
||||||
|
if cfg.PrivateKey == "" {
|
||||||
|
return "", nil, errors.New("private key cannot be empty")
|
||||||
|
}
|
||||||
|
if data == "" {
|
||||||
|
return "", nil, errors.New("data cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate 32 random bytes for random component
|
||||||
|
randomBytes := make([]byte, 32)
|
||||||
|
_, err = rand.Read(randomBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errors.Wrap(err, "failed to generate random bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate 32 random bytes for userAgentKey
|
||||||
|
userAgentKey = make([]byte, 32)
|
||||||
|
_, err = rand.Read(userAgentKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errors.Wrap(err, "failed to generate userAgentKey bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode random and userAgentKey to base64
|
||||||
|
randomEncoded := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||||
|
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||||
|
|
||||||
|
// Create payload for signing: data + "." + random + userAgentKey + privateKey
|
||||||
|
// Note: userAgentKey is concatenated directly with privateKey (no separator)
|
||||||
|
payload := data + "." + randomEncoded + userAgentKeyEncoded + cfg.PrivateKey
|
||||||
|
|
||||||
|
// Generate signature
|
||||||
|
hash := sha256.Sum256([]byte(payload))
|
||||||
|
signature := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
// Construct state: data + "." + random + "." + signature
|
||||||
|
state = data + "." + randomEncoded + "." + signature
|
||||||
|
|
||||||
|
return state, userAgentKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyState(cfg *Config, state string, userAgentKey []byte) (data string, err error) {
|
||||||
|
// Validate inputs
|
||||||
|
if cfg == nil {
|
||||||
|
return "", errors.New("cfg cannot be nil")
|
||||||
|
}
|
||||||
|
if cfg.PrivateKey == "" {
|
||||||
|
return "", errors.New("private key cannot be empty")
|
||||||
|
}
|
||||||
|
if state == "" {
|
||||||
|
return "", errors.New("state cannot be empty")
|
||||||
|
}
|
||||||
|
if len(userAgentKey) == 0 {
|
||||||
|
return "", errors.New("userAgentKey cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split state into parts
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return "", errors.Errorf("state must have exactly 3 parts (data.random.signature), got %d parts", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for empty parts
|
||||||
|
if slices.Contains(parts, "") {
|
||||||
|
return "", errors.New("state parts cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
data = parts[0]
|
||||||
|
random := parts[1]
|
||||||
|
receivedSignature := parts[2]
|
||||||
|
|
||||||
|
// Encode userAgentKey to base64 for payload reconstruction
|
||||||
|
userAgentKeyEncoded := base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||||
|
|
||||||
|
// Reconstruct payload (same as generation): data + "." + random + userAgentKeyEncoded + privateKey
|
||||||
|
payload := data + "." + random + userAgentKeyEncoded + cfg.PrivateKey
|
||||||
|
|
||||||
|
// Generate expected hash
|
||||||
|
hash := sha256.Sum256([]byte(payload))
|
||||||
|
|
||||||
|
// Decode received signature to bytes
|
||||||
|
receivedBytes, err := base64.RawURLEncoding.DecodeString(receivedSignature)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "failed to decode received signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare hash bytes directly with decoded signature using constant-time comparison
|
||||||
|
// This is more efficient than encoding hash and then decoding both for comparison
|
||||||
|
if subtle.ConstantTimeCompare(hash[:], receivedBytes) == 1 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errors.New("invalid state signature")
|
||||||
|
}
|
||||||
817
pkg/oauth/state_test.go
Normal file
817
pkg/oauth/state_test.go
Normal file
@@ -0,0 +1,817 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to create a test config
|
||||||
|
func testConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
PrivateKey: "test_private_key_for_testing_12345",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_Success tests the happy path of state generation
|
||||||
|
func TestGenerateState_Success(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
data := "test_data_payload"
|
||||||
|
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, data)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == "" {
|
||||||
|
t.Error("Expected non-empty state")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userAgentKey) != 32 {
|
||||||
|
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify state format: data.random.signature
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Errorf("Expected state to have 3 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify data is preserved
|
||||||
|
if parts[0] != data {
|
||||||
|
t.Errorf("Expected data to be '%s', got '%s'", data, parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify random part is base64 encoded
|
||||||
|
randomBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected random part to be valid base64: %v", err)
|
||||||
|
}
|
||||||
|
if len(randomBytes) != 32 {
|
||||||
|
t.Errorf("Expected random to be 32 bytes when decoded, got %d", len(randomBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature part is base64 encoded
|
||||||
|
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected signature part to be valid base64: %v", err)
|
||||||
|
}
|
||||||
|
if len(sigBytes) != 32 {
|
||||||
|
t.Errorf("Expected signature to be 32 bytes (SHA256), got %d", len(sigBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_NilConfig tests that nil config returns error
|
||||||
|
func TestGenerateState_NilConfig(t *testing.T) {
|
||||||
|
_, _, err := GenerateState(nil, "test_data")
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for nil config, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "cfg cannot be nil") {
|
||||||
|
t.Errorf("Expected error message about nil config, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_EmptyPrivateKey tests that empty private key returns error
|
||||||
|
func TestGenerateState_EmptyPrivateKey(t *testing.T) {
|
||||||
|
cfg := &Config{PrivateKey: ""}
|
||||||
|
_, _, err := GenerateState(cfg, "test_data")
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for empty private key, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "private key cannot be empty") {
|
||||||
|
t.Errorf("Expected error message about empty private key, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_EmptyData tests that empty data returns error
|
||||||
|
func TestGenerateState_EmptyData(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
_, _, err := GenerateState(cfg, "")
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for empty data, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "data cannot be empty") {
|
||||||
|
t.Errorf("Expected error message about empty data, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_Randomness tests that multiple calls generate different states
|
||||||
|
func TestGenerateState_Randomness(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
data := "same_data"
|
||||||
|
|
||||||
|
state1, _, err1 := GenerateState(cfg, data)
|
||||||
|
state2, _, err2 := GenerateState(cfg, data)
|
||||||
|
|
||||||
|
if err1 != nil || err2 != nil {
|
||||||
|
t.Fatalf("Unexpected errors: %v, %v", err1, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state1 == state2 {
|
||||||
|
t.Error("Expected different states for multiple calls, got identical states")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateState_DifferentData tests states with different data payloads
|
||||||
|
func TestGenerateState_DifferentData(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
testCases := []string{
|
||||||
|
"simple",
|
||||||
|
"with-dashes",
|
||||||
|
"with_underscores",
|
||||||
|
"123456789",
|
||||||
|
"MixedCase123",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, data := range testCases {
|
||||||
|
t.Run(data, func(t *testing.T) {
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, data)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error for data '%s': %v", data, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(state, data+".") {
|
||||||
|
t.Errorf("Expected state to start with '%s.', got: %s", data, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userAgentKey) != 32 {
|
||||||
|
t.Errorf("Expected userAgentKey to be 32 bytes, got %d", len(userAgentKey))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_Success tests the happy path of state verification
|
||||||
|
func TestVerifyState_Success(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
data := "test_data"
|
||||||
|
|
||||||
|
// Generate state
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify state
|
||||||
|
extractedData, err := VerifyState(cfg, state, userAgentKey)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if extractedData != data {
|
||||||
|
t.Errorf("Expected extracted data to be '%s', got '%s'", data, extractedData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_NilConfig tests that nil config returns error
|
||||||
|
func TestVerifyState_NilConfig(t *testing.T) {
|
||||||
|
_, err := VerifyState(nil, "state", []byte("key"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for nil config, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "cfg cannot be nil") {
|
||||||
|
t.Errorf("Expected error message about nil config, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_EmptyPrivateKey tests that empty private key returns error
|
||||||
|
func TestVerifyState_EmptyPrivateKey(t *testing.T) {
|
||||||
|
cfg := &Config{PrivateKey: ""}
|
||||||
|
_, err := VerifyState(cfg, "state", []byte("key"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for empty private key, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "private key cannot be empty") {
|
||||||
|
t.Errorf("Expected error message about empty private key, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_EmptyState tests that empty state returns error
|
||||||
|
func TestVerifyState_EmptyState(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
_, err := VerifyState(cfg, "", []byte("key"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for empty state, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "state cannot be empty") {
|
||||||
|
t.Errorf("Expected error message about empty state, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_EmptyUserAgentKey tests that empty userAgentKey returns error
|
||||||
|
func TestVerifyState_EmptyUserAgentKey(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
_, err := VerifyState(cfg, "data.random.signature", []byte{})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for empty userAgentKey, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "userAgentKey cannot be empty") {
|
||||||
|
t.Errorf("Expected error message about empty userAgentKey, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_WrongUserAgentKey tests MITM protection
|
||||||
|
func TestVerifyState_WrongUserAgentKey(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate first state
|
||||||
|
state, _, err := GenerateState(cfg, "test_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a different userAgentKey
|
||||||
|
_, wrongKey, err := GenerateState(cfg, "other_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate second state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to verify with wrong key
|
||||||
|
_, err = VerifyState(cfg, state, wrongKey)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "invalid state signature") {
|
||||||
|
t.Errorf("Expected error about invalid signature, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_TamperedData tests tampering detection
|
||||||
|
func TestVerifyState_TamperedData(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate state
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, "original_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with the data portion
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
parts[0] = "tampered_data"
|
||||||
|
tamperedState := strings.Join(parts, ".")
|
||||||
|
|
||||||
|
// Try to verify tampered state
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for tampered state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_TamperedRandom tests tampering with random portion
|
||||||
|
func TestVerifyState_TamperedRandom(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate state
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, "test_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with the random portion
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
parts[1] = base64.RawURLEncoding.EncodeToString([]byte("tampered_random_value_here12"))
|
||||||
|
tamperedState := strings.Join(parts, ".")
|
||||||
|
|
||||||
|
// Try to verify tampered state
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for tampered state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_TamperedSignature tests tampering with signature
|
||||||
|
func TestVerifyState_TamperedSignature(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate state
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, "test_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with the signature portion
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
// Create a different valid base64 string
|
||||||
|
parts[2] = base64.RawURLEncoding.EncodeToString(sha256.New().Sum([]byte("tampered")))
|
||||||
|
tamperedState := strings.Join(parts, ".")
|
||||||
|
|
||||||
|
// Try to verify tampered state
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for tampered signature")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_MalformedState_TwoParts tests state with only 2 parts
|
||||||
|
func TestVerifyState_MalformedState_TwoParts(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
malformedState := "data.random"
|
||||||
|
|
||||||
|
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for malformed state")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
|
||||||
|
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_MalformedState_FourParts tests state with 4 parts
|
||||||
|
func TestVerifyState_MalformedState_FourParts(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
malformedState := "data.random.signature.extra"
|
||||||
|
|
||||||
|
_, err := VerifyState(cfg, malformedState, []byte("key123456789012345678901234567890"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for malformed state")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "must have exactly 3 parts") {
|
||||||
|
t.Errorf("Expected error about incorrect number of parts, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_EmptyStateParts tests state with empty parts
|
||||||
|
func TestVerifyState_EmptyStateParts(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
state string
|
||||||
|
}{
|
||||||
|
{"empty data", ".random.signature"},
|
||||||
|
{"empty random", "data..signature"},
|
||||||
|
{"empty signature", "data.random."},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := VerifyState(cfg, tc.state, []byte("key123456789012345678901234567890"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for state with empty parts")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "state parts cannot be empty") {
|
||||||
|
t.Errorf("Expected error about empty parts, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_InvalidBase64Signature tests state with invalid base64 in signature
|
||||||
|
func TestVerifyState_InvalidBase64Signature(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
invalidState := "data.random.invalid@base64!"
|
||||||
|
|
||||||
|
_, err := VerifyState(cfg, invalidState, []byte("key123456789012345678901234567890"))
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for invalid base64 signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "failed to decode") {
|
||||||
|
t.Errorf("Expected error about decoding signature, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyState_DifferentPrivateKey tests that different private keys fail verification
|
||||||
|
func TestVerifyState_DifferentPrivateKey(t *testing.T) {
|
||||||
|
cfg1 := &Config{PrivateKey: "private_key_1"}
|
||||||
|
cfg2 := &Config{PrivateKey: "private_key_2"}
|
||||||
|
|
||||||
|
// Generate with first config
|
||||||
|
state, userAgentKey, err := GenerateState(cfg1, "test_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to verify with second config
|
||||||
|
_, err = VerifyState(cfg2, state, userAgentKey)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for mismatched private key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoundTrip tests complete round trip with various data payloads
|
||||||
|
func TestRoundTrip(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
testCases := []string{
|
||||||
|
"simple",
|
||||||
|
"with-dashes-and-numbers-123",
|
||||||
|
"MixedCaseData",
|
||||||
|
"user_token_abc123",
|
||||||
|
"link_resource_xyz789",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, data := range testCases {
|
||||||
|
t.Run(data, func(t *testing.T) {
|
||||||
|
// Generate
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
extractedData, err := VerifyState(cfg, state, userAgentKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to verify state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if extractedData != data {
|
||||||
|
t.Errorf("Expected extracted data '%s', got '%s'", data, extractedData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentGeneration tests that concurrent state generation works correctly
|
||||||
|
func TestConcurrentGeneration(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
data := "concurrent_test"
|
||||||
|
|
||||||
|
const numGoroutines = 10
|
||||||
|
results := make(chan string, numGoroutines)
|
||||||
|
errors := make(chan error, numGoroutines)
|
||||||
|
|
||||||
|
// Generate states concurrently
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func() {
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, data)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify immediately
|
||||||
|
_, verifyErr := VerifyState(cfg, state, userAgentKey)
|
||||||
|
if verifyErr != nil {
|
||||||
|
errors <- verifyErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results <- state
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect results
|
||||||
|
states := make(map[string]bool)
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
select {
|
||||||
|
case state := <-results:
|
||||||
|
if states[state] {
|
||||||
|
t.Errorf("Duplicate state generated: %s", state)
|
||||||
|
}
|
||||||
|
states[state] = true
|
||||||
|
case err := <-errors:
|
||||||
|
t.Errorf("Concurrent generation error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(states) != numGoroutines {
|
||||||
|
t.Errorf("Expected %d unique states, got %d", numGoroutines, len(states))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStateFormatCompatibility ensures state is URL-safe
|
||||||
|
func TestStateFormatCompatibility(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
data := "url_safe_test"
|
||||||
|
|
||||||
|
state, _, err := GenerateState(cfg, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that state doesn't contain characters that need URL encoding
|
||||||
|
unsafeChars := []string{"+", "/", "=", " ", "&", "?", "#"}
|
||||||
|
for _, char := range unsafeChars {
|
||||||
|
if strings.Contains(state, char) {
|
||||||
|
t.Errorf("State contains URL-unsafe character '%s': %s", char, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMITM_AttackerCannotSubstituteState verifies MITM protection actually works
|
||||||
|
// An attacker obtains their own valid state but tries to use it with victim's session
|
||||||
|
func TestMITM_AttackerCannotSubstituteState(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Victim generates a state for their login
|
||||||
|
victimState, victimKey, err := GenerateState(cfg, "victim_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate victim state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attacker generates their own valid state (they can request this from the server)
|
||||||
|
attackerState, attackerKey, err := GenerateState(cfg, "attacker_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate attacker state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both states should be valid on their own
|
||||||
|
_, err = VerifyState(cfg, victimState, victimKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Victim state should be valid: err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = VerifyState(cfg, attackerState, attackerKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Attacker state should be valid: err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MITM Attack Scenario 1: Attacker substitutes their state but victim has their cookie
|
||||||
|
// This should FAIL because attackerState was signed with attackerKey, not victimKey
|
||||||
|
_, err = VerifyState(cfg, attackerState, victimKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when attacker substitutes state")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MITM Attack Scenario 2: Attacker uses victim's state but has their own cookie
|
||||||
|
// This should also FAIL
|
||||||
|
_, err = VerifyState(cfg, victimState, attackerKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when attacker uses victim's state")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The key insight: even though both states are "valid", they are bound to their respective cookies
|
||||||
|
// An attacker cannot mix and match states and cookies
|
||||||
|
t.Log("✓ MITM protection verified: States are cryptographically bound to their userAgentKey cookies")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCSRF_AttackerCannotForgeState verifies CSRF protection
|
||||||
|
// An attacker tries to forge a state parameter without knowing the private key
|
||||||
|
func TestCSRF_AttackerCannotForgeState(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Attacker doesn't know the private key, but tries to forge a state
|
||||||
|
// They might try to construct: "malicious_data.random.signature"
|
||||||
|
|
||||||
|
// Attempt 1: Use a random signature
|
||||||
|
randomSig := base64.RawURLEncoding.EncodeToString([]byte("random_signature_attempt_12345678"))
|
||||||
|
forgedState1 := "malicious_data.somefakerandom." + randomSig
|
||||||
|
|
||||||
|
// Generate a real userAgentKey (attacker might try to get this)
|
||||||
|
_, realKey, err := GenerateState(cfg, "legitimate_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to verify forged state
|
||||||
|
_, err = VerifyState(cfg, forgedState1, realKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("CSRF VULNERABILITY: Attacker forged a valid state without private key!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt 2: Attacker tries to compute signature without private key
|
||||||
|
// They use: SHA256(data + "." + random + userAgentKey) - missing privateKey
|
||||||
|
attackerPayload := "malicious_data.fakerandom" + base64.RawURLEncoding.EncodeToString(realKey)
|
||||||
|
hash := sha256.Sum256([]byte(attackerPayload))
|
||||||
|
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
forgedState2 := "malicious_data.fakerandom." + attackerSig
|
||||||
|
|
||||||
|
_, err = VerifyState(cfg, forgedState2, realKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("CSRF VULNERABILITY: Attacker forged valid state without private key!")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("✓ CSRF protection verified: Cannot forge state without private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTampering_SignatureDetectsAllModifications verifies tamper detection
|
||||||
|
func TestTampering_SignatureDetectsAllModifications(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate a valid state
|
||||||
|
originalState, userAgentKey, err := GenerateState(cfg, "original_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify original is valid
|
||||||
|
data, err := VerifyState(cfg, originalState, userAgentKey)
|
||||||
|
if err != nil || data != "original_data" {
|
||||||
|
t.Fatalf("Original state should be valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(originalState, ".")
|
||||||
|
|
||||||
|
// Test 1: Attacker modifies data but keeps signature
|
||||||
|
tamperedState := "modified_data." + parts[1] + "." + parts[2]
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("TAMPER VULNERABILITY: Modified data not detected!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Attacker modifies random but keeps signature
|
||||||
|
newRandom := base64.RawURLEncoding.EncodeToString([]byte("new_random_value_32bytes_long!!"))
|
||||||
|
tamperedState = parts[0] + "." + newRandom + "." + parts[2]
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("TAMPER VULNERABILITY: Modified random not detected!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Attacker tries to recompute signature but doesn't have privateKey
|
||||||
|
// They compute: SHA256(modified_data + "." + random + userAgentKey)
|
||||||
|
attackerPayload := "modified_data." + parts[1] + base64.RawURLEncoding.EncodeToString(userAgentKey)
|
||||||
|
hash := sha256.Sum256([]byte(attackerPayload))
|
||||||
|
attackerSig := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
tamperedState = "modified_data." + parts[1] + "." + attackerSig
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("TAMPER VULNERABILITY: Attacker recomputed signature without private key!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 4: Single bit flip in signature
|
||||||
|
sigBytes, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
sigBytes[0] ^= 0x01 // Flip one bit
|
||||||
|
flippedSig := base64.RawURLEncoding.EncodeToString(sigBytes)
|
||||||
|
tamperedState = parts[0] + "." + parts[1] + "." + flippedSig
|
||||||
|
_, err = VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("TAMPER VULNERABILITY: Single bit flip in signature not detected!")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("✓ Tamper detection verified: All modifications to state are detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReplay_DifferentSessionsCannotReuseState verifies replay protection
|
||||||
|
func TestReplay_DifferentSessionsCannotReuseState(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Session 1: User initiates OAuth flow
|
||||||
|
state1, key1, err := GenerateState(cfg, "session1_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// State is valid for session 1
|
||||||
|
_, err = VerifyState(cfg, state1, key1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("State should be valid for session 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session 2: Same user (or attacker) initiates a new OAuth flow
|
||||||
|
state2, key2, err := GenerateState(cfg, "session1_data") // same data
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replay Attack: Try to use state1 with key2
|
||||||
|
_, err = VerifyState(cfg, state1, key2)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("REPLAY VULNERABILITY: State from session 1 was accepted in session 2!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Even with same data, each session should have unique state+key binding
|
||||||
|
if state1 == state2 {
|
||||||
|
t.Error("REPLAY VULNERABILITY: Same data produces identical states!")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("✓ Replay protection verified: States are bound to specific session cookies")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConstantTimeComparison verifies that signature comparison is timing-safe
|
||||||
|
// This is a behavioral test - we can't easily test timing, but we can verify the function is used
|
||||||
|
func TestConstantTimeComparison_IsUsed(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate valid state
|
||||||
|
state, userAgentKey, err := GenerateState(cfg, "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create states with signatures that differ at different positions
|
||||||
|
parts := strings.Split(state, ".")
|
||||||
|
originalSig, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
position int
|
||||||
|
}{
|
||||||
|
{"first byte differs", 0},
|
||||||
|
{"middle byte differs", 16},
|
||||||
|
{"last byte differs", 31},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create signature that differs at specific position
|
||||||
|
tamperedSig := make([]byte, len(originalSig))
|
||||||
|
copy(tamperedSig, originalSig)
|
||||||
|
tamperedSig[tc.position] ^= 0xFF // Flip all bits
|
||||||
|
|
||||||
|
tamperedSigStr := base64.RawURLEncoding.EncodeToString(tamperedSig)
|
||||||
|
tamperedState := parts[0] + "." + parts[1] + "." + tamperedSigStr
|
||||||
|
|
||||||
|
// All should fail verification
|
||||||
|
_, err := VerifyState(cfg, tamperedState, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Tampered signature at position %d should be invalid", tc.position)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If constant-time comparison is NOT used, early differences would return faster
|
||||||
|
// While we can't easily test timing here, we verify all positions fail equally
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("✓ Constant-time comparison: All signature positions validated equally")
|
||||||
|
t.Log(" Note: crypto/subtle.ConstantTimeCompare is used in implementation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPrivateKey_IsCriticalToSecurity verifies private key is essential
|
||||||
|
func TestPrivateKey_IsCriticalToSecurity(t *testing.T) {
|
||||||
|
cfg1 := &Config{PrivateKey: "secret_key_1"}
|
||||||
|
cfg2 := &Config{PrivateKey: "secret_key_2"}
|
||||||
|
|
||||||
|
// Generate state with key1
|
||||||
|
state, userAgentKey, err := GenerateState(cfg1, "data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should verify with key1
|
||||||
|
_, err = VerifyState(cfg1, state, userAgentKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("State should be valid with correct private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT verify with key2 (different private key)
|
||||||
|
_, err = VerifyState(cfg2, state, userAgentKey)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("SECURITY VULNERABILITY: State verified with different private key!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This proves that the private key is cryptographically involved in the signature
|
||||||
|
t.Log("✓ Private key security verified: Different keys produce incompatible signatures")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUserAgentKey_ProperlyIntegratedInSignature verifies userAgentKey is in payload
|
||||||
|
func TestUserAgentKey_ProperlyIntegratedInSignature(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
// Generate two states with same data but different userAgentKeys (implicit)
|
||||||
|
state1, key1, err := GenerateState(cfg, "same_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state2, key2, err := GenerateState(cfg, "same_data")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate state2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The states should be different even with same data (different random and keys)
|
||||||
|
if state1 == state2 {
|
||||||
|
t.Error("States should differ due to different random values")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each state should only verify with its own key
|
||||||
|
_, err1 := VerifyState(cfg, state1, key1)
|
||||||
|
_, err2 := VerifyState(cfg, state2, key2)
|
||||||
|
|
||||||
|
if err1 != nil || err2 != nil {
|
||||||
|
t.Fatal("States should be valid with their own keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cross-verification should fail
|
||||||
|
_, err1 = VerifyState(cfg, state1, key2)
|
||||||
|
_, err2 = VerifyState(cfg, state2, key1)
|
||||||
|
|
||||||
|
if err1 == nil || err2 == nil {
|
||||||
|
t.Error("SECURITY VULNERABILITY: userAgentKey not properly integrated in signature!")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("✓ UserAgentKey integration verified: Each state bound to its specific key")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user