imported jwt module

This commit is contained in:
2026-01-01 20:33:16 +11:00
parent 72e1513fae
commit c466cd3163
12 changed files with 896 additions and 0 deletions

62
jwt/generator.go Normal file
View File

@@ -0,0 +1,62 @@
package jwt
import (
"database/sql"
"errors"
)
type TokenGenerator struct {
accessExpireAfter int64 // Access Token expiry time in minutes
refreshExpireAfter int64 // Refresh Token expiry time in minutes
freshExpireAfter int64 // Token freshness expiry time in minutes
trustedHost string // Trusted hostname to use for the tokens
secretKey string // Secret key to use for token hashing
dbConn *sql.DB // Database handle for token blacklisting
}
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
// All expiry times should be provided in minutes.
// trustedHost and secretKey strings must be provided.
// dbConn can be nil, but doing this will disable token revocation
func CreateGenerator(
accessExpireAfter int64,
refreshExpireAfter int64,
freshExpireAfter int64,
trustedHost string,
secretKey string,
dbConn *sql.DB,
) (gen *TokenGenerator, err error) {
if accessExpireAfter <= 0 {
return nil, errors.New("accessExpireAfter must be greater than 0")
}
if refreshExpireAfter <= 0 {
return nil, errors.New("refreshExpireAfter must be greater than 0")
}
if freshExpireAfter <= 0 {
return nil, errors.New("freshExpireAfter must be greater than 0")
}
if trustedHost == "" {
return nil, errors.New("trustedHost cannot be an empty string")
}
if secretKey == "" {
return nil, errors.New("secretKey cannot be an empty string")
}
if dbConn != nil {
err := dbConn.Ping()
if err != nil {
return nil, errors.New("Failed to ping database")
}
// TODO: check if jwtblacklist table exists
// TODO: create jwtblacklist table if not existing
}
return &TokenGenerator{
accessExpireAfter: accessExpireAfter,
refreshExpireAfter: refreshExpireAfter,
freshExpireAfter: freshExpireAfter,
trustedHost: trustedHost,
secretKey: secretKey,
dbConn: dbConn,
}, nil
}

90
jwt/generator_test.go Normal file
View File

@@ -0,0 +1,90 @@
package jwt
import (
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestCreateGenerator_Success_NoDB(t *testing.T) {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"secret",
nil,
)
require.NoError(t, err)
require.NotNil(t, gen)
}
func TestCreateGenerator_Success_WithDB(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"secret",
db,
)
require.NoError(t, err)
require.NotNil(t, gen)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCreateGenerator_InvalidInputs(t *testing.T) {
tests := []struct {
name string
fn func() error
}{
{
"access expiry <= 0",
func() error {
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
return err
},
},
{
"refresh expiry <= 0",
func() error {
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
return err
},
},
{
"fresh expiry <= 0",
func() error {
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
return err
},
},
{
"empty trustedHost",
func() error {
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
return err
},
},
{
"empty secretKey",
func() error {
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
return err
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Error(t, tt.fn())
})
}
}

17
jwt/go.mod Normal file
View File

@@ -0,0 +1,17 @@
module git.haelnorr.com/h/golib/jwt
go 1.25.5
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

19
jwt/go.sum Normal file
View File

@@ -0,0 +1,19 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

58
jwt/revoke.go Normal file
View File

@@ -0,0 +1,58 @@
package jwt
import (
"context"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func revoke(ctx context.Context, t Token) error {
db := t.getDB()
if db == nil {
return errors.New("No DB provided, unable to use this function")
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err = tx.Exec(query, jti, exp)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "tx.Commit")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
db := t.getDB()
if db == nil {
return false, errors.New("No DB provided, unable to use this function")
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(query, jti)
if err != nil {
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
revoked := rows.Next()
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
return !revoked, nil
}

80
jwt/revoke_test.go Normal file
View File

@@ -0,0 +1,80 @@
package jwt
import (
"context"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
nil,
)
require.NoError(t, err)
return gen
}
func TestNoDBFail(t *testing.T) {
jti := uuid.New()
exp := time.Now().Add(time.Hour).Unix()
token := AccessToken{
JTI: jti,
EXP: exp,
}
// Revoke should fail due to no DB
err := token.Revoke(context.Background())
require.Error(t, err)
// CheckNotRevoked should fail
_, err = token.CheckNotRevoked(context.Background())
require.Error(t, err)
}
func TestRevokeAndCheckNotRevoked(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
jti := uuid.New()
exp := time.Now().Add(time.Hour).Unix()
token := AccessToken{
JTI: jti,
EXP: exp,
db: gen.dbConn,
}
// Revoke expectations
mock.ExpectBegin()
mock.ExpectExec(`INSERT INTO jwtblacklist`).
WithArgs(jti, exp).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := token.Revoke(context.Background())
require.NoError(t, err)
// CheckNotRevoked expectations (now revoked)
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti).
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit()
valid, err := token.CheckNotRevoked(context.Background())
require.NoError(t, err)
require.False(t, valid)
require.NoError(t, mock.ExpectationsWereMet())
}

79
jwt/tokengen.go Normal file
View File

@@ -0,0 +1,79 @@
package jwt
import (
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pkg/errors"
)
// Generates an access token for the provided subject
func (gen *TokenGenerator) NewAccess(
subjectID int,
fresh bool,
rememberMe bool,
) (tokenString string, expiresIn int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (gen.accessExpireAfter * 60)
var freshExpiresAt int64
if fresh {
freshExpiresAt = issuedAt + (gen.freshExpireAfter * 60)
} else {
freshExpiresAt = issuedAt
}
var ttl string
if rememberMe {
ttl = "exp"
} else {
ttl = "session"
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.MapClaims{
"iss": gen.trustedHost,
"scope": "access",
"ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt,
"exp": expiresAt,
"fresh": freshExpiresAt,
"sub": subjectID,
})
signedToken, err := token.SignedString([]byte(gen.secretKey))
if err != nil {
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, expiresAt, nil
}
// Generates a refresh token for the provided user
func (gen *TokenGenerator) NewRefresh(
subjectID int,
rememberMe bool,
) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (gen.refreshExpireAfter * 60)
var ttl string
if rememberMe {
ttl = "exp"
} else {
ttl = "session"
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.MapClaims{
"iss": gen.trustedHost,
"scope": "refresh",
"ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt,
"exp": expiresAt,
"sub": subjectID,
})
signedToken, err := token.SignedString([]byte(gen.secretKey))
if err != nil {
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, expiresAt, nil
}

38
jwt/tokengen_test.go Normal file
View File

@@ -0,0 +1,38 @@
package jwt
import (
"testing"
"github.com/stretchr/testify/require"
)
func newTestGenerator(t *testing.T) *TokenGenerator {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
nil,
)
require.NoError(t, err)
return gen
}
func TestNewAccessToken(t *testing.T) {
gen := newTestGenerator(t)
tokenStr, exp, err := gen.NewAccess(123, true, false)
require.NoError(t, err)
require.NotEmpty(t, tokenStr)
require.Greater(t, exp, int64(0))
}
func TestNewRefreshToken(t *testing.T) {
gen := newTestGenerator(t)
tokenStr, exp, err := gen.NewRefresh(123, true)
require.NoError(t, err)
require.NotEmpty(t, tokenStr)
require.Greater(t, exp, int64(0))
}

78
jwt/tokens.go Normal file
View File

@@ -0,0 +1,78 @@
package jwt
import (
"context"
"database/sql"
"github.com/google/uuid"
)
type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
getDB() *sql.DB
Revoke(context.Context) error
}
// Access token
type AccessToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at
Scope string // Should be "access"
db *sql.DB
}
// Refresh token
type RefreshToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh"
db *sql.DB
}
func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI
}
func (r RefreshToken) GetJTI() uuid.UUID {
return r.JTI
}
func (a AccessToken) GetEXP() int64 {
return a.EXP
}
func (r RefreshToken) GetEXP() int64 {
return r.EXP
}
func (a AccessToken) GetScope() string {
return a.Scope
}
func (r RefreshToken) GetScope() string {
return r.Scope
}
func (a AccessToken) getDB() *sql.DB {
return a.db
}
func (r RefreshToken) getDB() *sql.DB {
return r.db
}
func (a AccessToken) Revoke(ctx context.Context) error {
return revoke(ctx, a)
}
func (r RefreshToken) Revoke(ctx context.Context) error {
return revoke(ctx, r)
}
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, a)
}
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, r)
}

123
jwt/util.go Normal file
View File

@@ -0,0 +1,123 @@
package jwt
import (
"fmt"
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pkg/errors"
)
// Parse a token, validating its signing sigature and returning the claims
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
return nil, errors.Wrap(err, "jwt.Parse")
}
// Token decoded, parse the claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("Failed to parse claims")
}
return claims, nil
}
// Check if a token is expired. Returns the expiry if not expired
func checkTokenExpired(expiry interface{}) (int64, error) {
// Coerce the expiry to a float64 to avoid scientific notation
expFloat, ok := expiry.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'exp' claim")
}
// Convert to the int64 time we expect :)
expiryTime := int64(expFloat)
// Check if its expired
isExpired := time.Now().After(time.Unix(expiryTime, 0))
if isExpired {
return 0, errors.New("Token has expired")
}
return expiryTime, nil
}
// Check if a token has a valid issuer. Returns the issuer if valid
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
issuerVal, ok := issuer.(string)
if !ok {
return "", errors.New("Missing or invalid 'iss' claim")
}
if issuer != trustedHost {
return "", errors.New("Issuer does not matched trusted host")
}
return issuerVal, nil
}
// Check the scope matches the expected scope. Returns scope if true
func getTokenScope(scope interface{}) (string, error) {
scopeStr, ok := scope.(string)
if !ok {
return "", errors.New("Missing or invalid 'scope' claim")
}
return scopeStr, nil
}
// Get the TTL of the token, either "session" or "exp"
func getTokenTTL(ttl interface{}) (string, error) {
ttlStr, ok := ttl.(string)
if !ok {
return "", errors.New("Missing or invalid 'ttl' claim")
}
if ttlStr != "exp" && ttlStr != "session" {
return "", errors.New("TTL value is not recognised")
}
return ttlStr, nil
}
// Get the time the token was issued at
func getIssuedTime(issued interface{}) (int64, error) {
// Same float64 -> int64 trick as expiry
issuedFloat, ok := issued.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'iat' claim")
}
issuedAt := int64(issuedFloat)
return issuedAt, nil
}
// Get the freshness expiry timestamp
func getFreshTime(fresh interface{}) (int64, error) {
freshUntil, ok := fresh.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'fresh' claim")
}
return int64(freshUntil), nil
}
// Get the subject of the token
func getTokenSubject(sub interface{}) (int, error) {
subject, ok := sub.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'sub' claim")
}
return int(subject), nil
}
// Get the JTI of the token
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
jtiStr, ok := jti.(string)
if !ok {
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
}
jtiUUID, err := uuid.Parse(jtiStr)
if err != nil {
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
}
return jtiUUID, nil
}

145
jwt/validate.go Normal file
View File

@@ -0,0 +1,145 @@
package jwt
import (
"context"
"github.com/pkg/errors"
)
// Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func (gen *TokenGenerator) ValidateAccess(
ctx context.Context,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
return nil, errors.New("Access token string not provided")
}
claims, err := parseToken(gen.secretKey, tokenString)
if err != nil {
return nil, errors.Wrap(err, "parseToken")
}
expiry, err := checkTokenExpired(claims["exp"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenExpired")
}
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenIssuer")
}
ttl, err := getTokenTTL(claims["ttl"])
if err != nil {
return nil, errors.Wrap(err, "getTokenTTL")
}
scope, err := getTokenScope(claims["scope"])
if err != nil {
return nil, errors.Wrap(err, "getTokenScope")
}
if scope != "access" {
return nil, errors.New("Token is not an Access token")
}
issuedAt, err := getIssuedTime(claims["iat"])
if err != nil {
return nil, errors.Wrap(err, "getIssuedTime")
}
subject, err := getTokenSubject(claims["sub"])
if err != nil {
return nil, errors.Wrap(err, "getTokenSubject")
}
fresh, err := getFreshTime(claims["fresh"])
if err != nil {
return nil, errors.Wrap(err, "getFreshTime")
}
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return nil, errors.Wrap(err, "getTokenJTI")
}
token := &AccessToken{
ISS: issuer,
TTL: ttl,
EXP: expiry,
IAT: issuedAt,
SUB: subject,
Fresh: fresh,
JTI: jti,
Scope: scope,
db: gen.dbConn,
}
valid, err := token.CheckNotRevoked(ctx)
if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.dbConn != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil
}
// Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func (gen *TokenGenerator) ValidateRefresh(
ctx context.Context,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
return nil, errors.New("Refresh token string not provided")
}
claims, err := parseToken(gen.secretKey, tokenString)
if err != nil {
return nil, errors.Wrap(err, "parseToken")
}
expiry, err := checkTokenExpired(claims["exp"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenExpired")
}
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenIssuer")
}
ttl, err := getTokenTTL(claims["ttl"])
if err != nil {
return nil, errors.Wrap(err, "getTokenTTL")
}
scope, err := getTokenScope(claims["scope"])
if err != nil {
return nil, errors.Wrap(err, "getTokenScope")
}
if scope != "refresh" {
return nil, errors.New("Token is not an Refresh token")
}
issuedAt, err := getIssuedTime(claims["iat"])
if err != nil {
return nil, errors.Wrap(err, "getIssuedTime")
}
subject, err := getTokenSubject(claims["sub"])
if err != nil {
return nil, errors.Wrap(err, "getTokenSubject")
}
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return nil, errors.Wrap(err, "getTokenJTI")
}
token := &RefreshToken{
ISS: issuer,
TTL: ttl,
EXP: expiry,
IAT: issuedAt,
SUB: subject,
JTI: jti,
Scope: scope,
db: gen.dbConn,
}
valid, err := token.CheckNotRevoked(ctx)
if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.dbConn != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil
}

107
jwt/validate_test.go Normal file
View File

@@ -0,0 +1,107 @@
package jwt
import (
"context"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
db,
)
require.NoError(t, err)
return gen, mock, func() { db.Close() }
}
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti).
WillReturnRows(sqlmock.NewRows([]string{}))
mock.ExpectCommit()
}
func TestValidateAccess_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err)
// We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateAccess(context.Background(), tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope)
}
func TestValidateAccess_NoDB(t *testing.T) {
gen := newGeneratorWithNoDB(t)
tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err)
token, err := gen.ValidateAccess(context.Background(), tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope)
}
func TestValidateRefresh_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err)
expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope)
}
func TestValidateRefresh_NoDB(t *testing.T) {
gen := newGeneratorWithNoDB(t)
tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err)
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope)
}
func TestValidateAccess_EmptyToken(t *testing.T) {
gen := newTestGenerator(t)
_, err := gen.ValidateAccess(context.Background(), "")
require.Error(t, err)
}
func TestValidateRefresh_WrongScope(t *testing.T) {
gen := newTestGenerator(t)
// Create access token but validate as refresh
tokenStr, _, err := gen.NewAccess(1, false, false)
require.NoError(t, err)
_, err = gen.ValidateRefresh(context.Background(), tokenStr)
require.Error(t, err)
}