From c466cd3163106a28c5f33f34646b522bb67d5983 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 1 Jan 2026 20:33:16 +1100 Subject: [PATCH] imported jwt module --- jwt/generator.go | 62 ++++++++++++++++++ jwt/generator_test.go | 90 ++++++++++++++++++++++++++ jwt/go.mod | 17 +++++ jwt/go.sum | 19 ++++++ jwt/revoke.go | 58 +++++++++++++++++ jwt/revoke_test.go | 80 +++++++++++++++++++++++ jwt/tokengen.go | 79 +++++++++++++++++++++++ jwt/tokengen_test.go | 38 +++++++++++ jwt/tokens.go | 78 +++++++++++++++++++++++ jwt/util.go | 123 +++++++++++++++++++++++++++++++++++ jwt/validate.go | 145 ++++++++++++++++++++++++++++++++++++++++++ jwt/validate_test.go | 107 +++++++++++++++++++++++++++++++ 12 files changed, 896 insertions(+) create mode 100644 jwt/generator.go create mode 100644 jwt/generator_test.go create mode 100644 jwt/go.mod create mode 100644 jwt/go.sum create mode 100644 jwt/revoke.go create mode 100644 jwt/revoke_test.go create mode 100644 jwt/tokengen.go create mode 100644 jwt/tokengen_test.go create mode 100644 jwt/tokens.go create mode 100644 jwt/util.go create mode 100644 jwt/validate.go create mode 100644 jwt/validate_test.go diff --git a/jwt/generator.go b/jwt/generator.go new file mode 100644 index 0000000..6f79574 --- /dev/null +++ b/jwt/generator.go @@ -0,0 +1,62 @@ +package jwt + +import ( + "database/sql" + "errors" +) + +type TokenGenerator struct { + accessExpireAfter int64 // Access Token expiry time in minutes + refreshExpireAfter int64 // Refresh Token expiry time in minutes + freshExpireAfter int64 // Token freshness expiry time in minutes + trustedHost string // Trusted hostname to use for the tokens + secretKey string // Secret key to use for token hashing + dbConn *sql.DB // Database handle for token blacklisting +} + +// CreateGenerator creates and returns a new TokenGenerator using the provided configuration. +// All expiry times should be provided in minutes. +// trustedHost and secretKey strings must be provided. +// dbConn can be nil, but doing this will disable token revocation +func CreateGenerator( + accessExpireAfter int64, + refreshExpireAfter int64, + freshExpireAfter int64, + trustedHost string, + secretKey string, + dbConn *sql.DB, +) (gen *TokenGenerator, err error) { + if accessExpireAfter <= 0 { + return nil, errors.New("accessExpireAfter must be greater than 0") + } + if refreshExpireAfter <= 0 { + return nil, errors.New("refreshExpireAfter must be greater than 0") + } + if freshExpireAfter <= 0 { + return nil, errors.New("freshExpireAfter must be greater than 0") + } + if trustedHost == "" { + return nil, errors.New("trustedHost cannot be an empty string") + } + if secretKey == "" { + return nil, errors.New("secretKey cannot be an empty string") + } + + if dbConn != nil { + err := dbConn.Ping() + if err != nil { + return nil, errors.New("Failed to ping database") + } + // TODO: check if jwtblacklist table exists + // TODO: create jwtblacklist table if not existing + } + + return &TokenGenerator{ + accessExpireAfter: accessExpireAfter, + refreshExpireAfter: refreshExpireAfter, + freshExpireAfter: freshExpireAfter, + trustedHost: trustedHost, + secretKey: secretKey, + dbConn: dbConn, + }, nil +} diff --git a/jwt/generator_test.go b/jwt/generator_test.go new file mode 100644 index 0000000..42d75da --- /dev/null +++ b/jwt/generator_test.go @@ -0,0 +1,90 @@ +package jwt + +import ( + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestCreateGenerator_Success_NoDB(t *testing.T) { + gen, err := CreateGenerator( + 15, + 60, + 5, + "example.com", + "secret", + nil, + ) + + require.NoError(t, err) + require.NotNil(t, gen) +} + +func TestCreateGenerator_Success_WithDB(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + gen, err := CreateGenerator( + 15, + 60, + 5, + "example.com", + "secret", + db, + ) + + require.NoError(t, err) + require.NotNil(t, gen) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestCreateGenerator_InvalidInputs(t *testing.T) { + tests := []struct { + name string + fn func() error + }{ + { + "access expiry <= 0", + func() error { + _, err := CreateGenerator(0, 1, 1, "h", "s", nil) + return err + }, + }, + { + "refresh expiry <= 0", + func() error { + _, err := CreateGenerator(1, 0, 1, "h", "s", nil) + return err + }, + }, + { + "fresh expiry <= 0", + func() error { + _, err := CreateGenerator(1, 1, 0, "h", "s", nil) + return err + }, + }, + { + "empty trustedHost", + func() error { + _, err := CreateGenerator(1, 1, 1, "", "s", nil) + return err + }, + }, + { + "empty secretKey", + func() error { + _, err := CreateGenerator(1, 1, 1, "h", "", nil) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Error(t, tt.fn()) + }) + } +} diff --git a/jwt/go.mod b/jwt/go.mod new file mode 100644 index 0000000..c0a4fb5 --- /dev/null +++ b/jwt/go.mod @@ -0,0 +1,17 @@ +module git.haelnorr.com/h/golib/jwt + +go 1.25.5 + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/jwt/go.sum b/jwt/go.sum new file mode 100644 index 0000000..8f4beef --- /dev/null +++ b/jwt/go.sum @@ -0,0 +1,19 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwt/revoke.go b/jwt/revoke.go new file mode 100644 index 0000000..424c8fb --- /dev/null +++ b/jwt/revoke.go @@ -0,0 +1,58 @@ +package jwt + +import ( + "context" + + "github.com/pkg/errors" +) + +// Revoke a token by adding it to the database +func revoke(ctx context.Context, t Token) error { + db := t.getDB() + if db == nil { + return errors.New("No DB provided, unable to use this function") + } + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "db.BeginTx") + } + defer tx.Rollback() + jti := t.GetJTI() + exp := t.GetEXP() + query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` + _, err = tx.Exec(query, jti, exp) + if err != nil { + return errors.Wrap(err, "tx.Exec") + } + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "tx.Commit") + } + return nil +} + +// Check if a token has been revoked. Returns true if not revoked. +func checkNotRevoked(ctx context.Context, t Token) (bool, error) { + db := t.getDB() + if db == nil { + return false, errors.New("No DB provided, unable to use this function") + } + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return false, errors.Wrap(err, "db.BeginTx") + } + defer tx.Rollback() + jti := t.GetJTI() + query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` + rows, err := tx.Query(query, jti) + if err != nil { + return false, errors.Wrap(err, "tx.Query") + } + defer rows.Close() + revoked := rows.Next() + err = tx.Commit() + if err != nil { + return false, errors.Wrap(err, "tx.Commit") + } + return !revoked, nil +} diff --git a/jwt/revoke_test.go b/jwt/revoke_test.go new file mode 100644 index 0000000..62372e3 --- /dev/null +++ b/jwt/revoke_test.go @@ -0,0 +1,80 @@ +package jwt + +import ( + "context" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func newGeneratorWithNoDB(t *testing.T) *TokenGenerator { + gen, err := CreateGenerator( + 15, + 60, + 5, + "example.com", + "supersecret", + nil, + ) + require.NoError(t, err) + + return gen +} + +func TestNoDBFail(t *testing.T) { + jti := uuid.New() + exp := time.Now().Add(time.Hour).Unix() + + token := AccessToken{ + JTI: jti, + EXP: exp, + } + + // Revoke should fail due to no DB + err := token.Revoke(context.Background()) + require.Error(t, err) + + // CheckNotRevoked should fail + _, err = token.CheckNotRevoked(context.Background()) + require.Error(t, err) +} + +func TestRevokeAndCheckNotRevoked(t *testing.T) { + gen, mock, cleanup := newGeneratorWithMockDB(t) + defer cleanup() + + jti := uuid.New() + exp := time.Now().Add(time.Hour).Unix() + + token := AccessToken{ + JTI: jti, + EXP: exp, + db: gen.dbConn, + } + + // Revoke expectations + mock.ExpectBegin() + mock.ExpectExec(`INSERT INTO jwtblacklist`). + WithArgs(jti, exp). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := token.Revoke(context.Background()) + require.NoError(t, err) + + // CheckNotRevoked expectations (now revoked) + mock.ExpectBegin() + mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`). + WithArgs(jti). + WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) + mock.ExpectCommit() + + valid, err := token.CheckNotRevoked(context.Background()) + require.NoError(t, err) + require.False(t, valid) + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/jwt/tokengen.go b/jwt/tokengen.go new file mode 100644 index 0000000..a65aa19 --- /dev/null +++ b/jwt/tokengen.go @@ -0,0 +1,79 @@ +package jwt + +import ( + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +// Generates an access token for the provided subject +func (gen *TokenGenerator) NewAccess( + subjectID int, + fresh bool, + rememberMe bool, +) (tokenString string, expiresIn int64, err error) { + issuedAt := time.Now().Unix() + expiresAt := issuedAt + (gen.accessExpireAfter * 60) + var freshExpiresAt int64 + if fresh { + freshExpiresAt = issuedAt + (gen.freshExpireAfter * 60) + } else { + freshExpiresAt = issuedAt + } + var ttl string + if rememberMe { + ttl = "exp" + } else { + ttl = "session" + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + jwt.MapClaims{ + "iss": gen.trustedHost, + "scope": "access", + "ttl": ttl, + "jti": uuid.New(), + "iat": issuedAt, + "exp": expiresAt, + "fresh": freshExpiresAt, + "sub": subjectID, + }) + + signedToken, err := token.SignedString([]byte(gen.secretKey)) + if err != nil { + return "", 0, errors.Wrap(err, "token.SignedString") + } + return signedToken, expiresAt, nil +} + +// Generates a refresh token for the provided user +func (gen *TokenGenerator) NewRefresh( + subjectID int, + rememberMe bool, +) (tokenStr string, exp int64, err error) { + issuedAt := time.Now().Unix() + expiresAt := issuedAt + (gen.refreshExpireAfter * 60) + var ttl string + if rememberMe { + ttl = "exp" + } else { + ttl = "session" + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + jwt.MapClaims{ + "iss": gen.trustedHost, + "scope": "refresh", + "ttl": ttl, + "jti": uuid.New(), + "iat": issuedAt, + "exp": expiresAt, + "sub": subjectID, + }) + + signedToken, err := token.SignedString([]byte(gen.secretKey)) + if err != nil { + return "", 0, errors.Wrap(err, "token.SignedString") + } + return signedToken, expiresAt, nil +} diff --git a/jwt/tokengen_test.go b/jwt/tokengen_test.go new file mode 100644 index 0000000..2c9f80c --- /dev/null +++ b/jwt/tokengen_test.go @@ -0,0 +1,38 @@ +package jwt + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func newTestGenerator(t *testing.T) *TokenGenerator { + gen, err := CreateGenerator( + 15, + 60, + 5, + "example.com", + "supersecret", + nil, + ) + require.NoError(t, err) + return gen +} + +func TestNewAccessToken(t *testing.T) { + gen := newTestGenerator(t) + + tokenStr, exp, err := gen.NewAccess(123, true, false) + require.NoError(t, err) + require.NotEmpty(t, tokenStr) + require.Greater(t, exp, int64(0)) +} + +func TestNewRefreshToken(t *testing.T) { + gen := newTestGenerator(t) + + tokenStr, exp, err := gen.NewRefresh(123, true) + require.NoError(t, err) + require.NotEmpty(t, tokenStr) + require.Greater(t, exp, int64(0)) +} diff --git a/jwt/tokens.go b/jwt/tokens.go new file mode 100644 index 0000000..8754ccf --- /dev/null +++ b/jwt/tokens.go @@ -0,0 +1,78 @@ +package jwt + +import ( + "context" + "database/sql" + + "github.com/google/uuid" +) + +type Token interface { + GetJTI() uuid.UUID + GetEXP() int64 + GetScope() string + getDB() *sql.DB + Revoke(context.Context) error +} + +// Access token +type AccessToken struct { + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens + Fresh int64 // Time freshness expiring at + Scope string // Should be "access" + db *sql.DB +} + +// Refresh token +type RefreshToken struct { + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens + Scope string // Should be "refresh" + db *sql.DB +} + +func (a AccessToken) GetJTI() uuid.UUID { + return a.JTI +} +func (r RefreshToken) GetJTI() uuid.UUID { + return r.JTI +} +func (a AccessToken) GetEXP() int64 { + return a.EXP +} +func (r RefreshToken) GetEXP() int64 { + return r.EXP +} +func (a AccessToken) GetScope() string { + return a.Scope +} +func (r RefreshToken) GetScope() string { + return r.Scope +} +func (a AccessToken) getDB() *sql.DB { + return a.db +} +func (r RefreshToken) getDB() *sql.DB { + return r.db +} +func (a AccessToken) Revoke(ctx context.Context) error { + return revoke(ctx, a) +} +func (r RefreshToken) Revoke(ctx context.Context) error { + return revoke(ctx, r) +} +func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) { + return checkNotRevoked(ctx, a) +} +func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) { + return checkNotRevoked(ctx, r) +} diff --git a/jwt/util.go b/jwt/util.go new file mode 100644 index 0000000..7a4fd8d --- /dev/null +++ b/jwt/util.go @@ -0,0 +1,123 @@ +package jwt + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +// Parse a token, validating its signing sigature and returning the claims +func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secretKey), nil + }) + if err != nil { + return nil, errors.Wrap(err, "jwt.Parse") + } + // Token decoded, parse the claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("Failed to parse claims") + } + return claims, nil +} + +// Check if a token is expired. Returns the expiry if not expired +func checkTokenExpired(expiry interface{}) (int64, error) { + // Coerce the expiry to a float64 to avoid scientific notation + expFloat, ok := expiry.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'exp' claim") + } + // Convert to the int64 time we expect :) + expiryTime := int64(expFloat) + + // Check if its expired + isExpired := time.Now().After(time.Unix(expiryTime, 0)) + if isExpired { + return 0, errors.New("Token has expired") + } + return expiryTime, nil +} + +// Check if a token has a valid issuer. Returns the issuer if valid +func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) { + issuerVal, ok := issuer.(string) + if !ok { + return "", errors.New("Missing or invalid 'iss' claim") + } + if issuer != trustedHost { + return "", errors.New("Issuer does not matched trusted host") + } + return issuerVal, nil +} + +// Check the scope matches the expected scope. Returns scope if true +func getTokenScope(scope interface{}) (string, error) { + scopeStr, ok := scope.(string) + if !ok { + return "", errors.New("Missing or invalid 'scope' claim") + } + return scopeStr, nil +} + +// Get the TTL of the token, either "session" or "exp" +func getTokenTTL(ttl interface{}) (string, error) { + ttlStr, ok := ttl.(string) + if !ok { + return "", errors.New("Missing or invalid 'ttl' claim") + } + if ttlStr != "exp" && ttlStr != "session" { + return "", errors.New("TTL value is not recognised") + } + return ttlStr, nil +} + +// Get the time the token was issued at +func getIssuedTime(issued interface{}) (int64, error) { + // Same float64 -> int64 trick as expiry + issuedFloat, ok := issued.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'iat' claim") + } + issuedAt := int64(issuedFloat) + return issuedAt, nil +} + +// Get the freshness expiry timestamp +func getFreshTime(fresh interface{}) (int64, error) { + freshUntil, ok := fresh.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'fresh' claim") + } + return int64(freshUntil), nil +} + +// Get the subject of the token +func getTokenSubject(sub interface{}) (int, error) { + subject, ok := sub.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'sub' claim") + } + return int(subject), nil +} + +// Get the JTI of the token +func getTokenJTI(jti interface{}) (uuid.UUID, error) { + jtiStr, ok := jti.(string) + if !ok { + return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim") + } + jtiUUID, err := uuid.Parse(jtiStr) + if err != nil { + return uuid.UUID{}, errors.New("JTI is not a valid UUID") + } + return jtiUUID, nil +} diff --git a/jwt/validate.go b/jwt/validate.go new file mode 100644 index 0000000..c64a6f8 --- /dev/null +++ b/jwt/validate.go @@ -0,0 +1,145 @@ +package jwt + +import ( + "context" + "github.com/pkg/errors" +) + +// Parse an access token and return a struct with all the claims. Does validation on +// all the claims, including checking if it is expired, has a valid issuer, and +// has the correct scope. +func (gen *TokenGenerator) ValidateAccess( + ctx context.Context, + tokenString string, +) (*AccessToken, error) { + if tokenString == "" { + return nil, errors.New("Access token string not provided") + } + claims, err := parseToken(gen.secretKey, tokenString) + if err != nil { + return nil, errors.Wrap(err, "parseToken") + } + expiry, err := checkTokenExpired(claims["exp"]) + if err != nil { + return nil, errors.Wrap(err, "checkTokenExpired") + } + issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"]) + if err != nil { + return nil, errors.Wrap(err, "checkTokenIssuer") + } + ttl, err := getTokenTTL(claims["ttl"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenTTL") + } + scope, err := getTokenScope(claims["scope"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenScope") + } + if scope != "access" { + return nil, errors.New("Token is not an Access token") + } + issuedAt, err := getIssuedTime(claims["iat"]) + if err != nil { + return nil, errors.Wrap(err, "getIssuedTime") + } + subject, err := getTokenSubject(claims["sub"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenSubject") + } + fresh, err := getFreshTime(claims["fresh"]) + if err != nil { + return nil, errors.Wrap(err, "getFreshTime") + } + jti, err := getTokenJTI(claims["jti"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenJTI") + } + + token := &AccessToken{ + ISS: issuer, + TTL: ttl, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + Fresh: fresh, + JTI: jti, + Scope: scope, + db: gen.dbConn, + } + + valid, err := token.CheckNotRevoked(ctx) + if err != nil && gen.dbConn != nil { + return nil, errors.Wrap(err, "token.CheckNotRevoked") + } + if !valid && gen.dbConn != nil { + return nil, errors.New("Token has been revoked") + } + return token, nil +} + +// Parse a refresh token and return a struct with all the claims. Does validation on +// all the claims, including checking if it is expired, has a valid issuer, and +// has the correct scope. +func (gen *TokenGenerator) ValidateRefresh( + ctx context.Context, + tokenString string, +) (*RefreshToken, error) { + if tokenString == "" { + return nil, errors.New("Refresh token string not provided") + } + claims, err := parseToken(gen.secretKey, tokenString) + if err != nil { + return nil, errors.Wrap(err, "parseToken") + } + expiry, err := checkTokenExpired(claims["exp"]) + if err != nil { + return nil, errors.Wrap(err, "checkTokenExpired") + } + issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"]) + if err != nil { + return nil, errors.Wrap(err, "checkTokenIssuer") + } + ttl, err := getTokenTTL(claims["ttl"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenTTL") + } + scope, err := getTokenScope(claims["scope"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenScope") + } + if scope != "refresh" { + return nil, errors.New("Token is not an Refresh token") + } + issuedAt, err := getIssuedTime(claims["iat"]) + if err != nil { + return nil, errors.Wrap(err, "getIssuedTime") + } + subject, err := getTokenSubject(claims["sub"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenSubject") + } + jti, err := getTokenJTI(claims["jti"]) + if err != nil { + return nil, errors.Wrap(err, "getTokenJTI") + } + + token := &RefreshToken{ + ISS: issuer, + TTL: ttl, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + JTI: jti, + Scope: scope, + db: gen.dbConn, + } + + valid, err := token.CheckNotRevoked(ctx) + if err != nil && gen.dbConn != nil { + return nil, errors.Wrap(err, "token.CheckNotRevoked") + } + if !valid && gen.dbConn != nil { + return nil, errors.New("Token has been revoked") + } + return token, nil +} diff --git a/jwt/validate_test.go b/jwt/validate_test.go new file mode 100644 index 0000000..79d7808 --- /dev/null +++ b/jwt/validate_test.go @@ -0,0 +1,107 @@ +package jwt + +import ( + "context" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + gen, err := CreateGenerator( + 15, + 60, + 5, + "example.com", + "supersecret", + db, + ) + require.NoError(t, err) + + return gen, mock, func() { db.Close() } +} + +func expectNotRevoked(mock sqlmock.Sqlmock, jti any) { + mock.ExpectBegin() + mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`). + WithArgs(jti). + WillReturnRows(sqlmock.NewRows([]string{})) + mock.ExpectCommit() +} + +func TestValidateAccess_Success(t *testing.T) { + gen, mock, cleanup := newGeneratorWithMockDB(t) + defer cleanup() + + tokenStr, _, err := gen.NewAccess(42, true, false) + require.NoError(t, err) + + // We don't know the JTI beforehand; match any arg + expectNotRevoked(mock, sqlmock.AnyArg()) + + token, err := gen.ValidateAccess(context.Background(), tokenStr) + require.NoError(t, err) + require.Equal(t, 42, token.SUB) + require.Equal(t, "access", token.Scope) +} + +func TestValidateAccess_NoDB(t *testing.T) { + gen := newGeneratorWithNoDB(t) + + tokenStr, _, err := gen.NewAccess(42, true, false) + require.NoError(t, err) + + token, err := gen.ValidateAccess(context.Background(), tokenStr) + require.NoError(t, err) + require.Equal(t, 42, token.SUB) + require.Equal(t, "access", token.Scope) +} + +func TestValidateRefresh_Success(t *testing.T) { + gen, mock, cleanup := newGeneratorWithMockDB(t) + defer cleanup() + + tokenStr, _, err := gen.NewRefresh(42, false) + require.NoError(t, err) + + expectNotRevoked(mock, sqlmock.AnyArg()) + + token, err := gen.ValidateRefresh(context.Background(), tokenStr) + require.NoError(t, err) + require.Equal(t, 42, token.SUB) + require.Equal(t, "refresh", token.Scope) +} + +func TestValidateRefresh_NoDB(t *testing.T) { + gen := newGeneratorWithNoDB(t) + + tokenStr, _, err := gen.NewRefresh(42, false) + require.NoError(t, err) + + token, err := gen.ValidateRefresh(context.Background(), tokenStr) + require.NoError(t, err) + require.Equal(t, 42, token.SUB) + require.Equal(t, "refresh", token.Scope) +} + +func TestValidateAccess_EmptyToken(t *testing.T) { + gen := newTestGenerator(t) + + _, err := gen.ValidateAccess(context.Background(), "") + require.Error(t, err) +} + +func TestValidateRefresh_WrongScope(t *testing.T) { + gen := newTestGenerator(t) + + // Create access token but validate as refresh + tokenStr, _, err := gen.NewAccess(1, false, false) + require.NoError(t, err) + + _, err = gen.ValidateRefresh(context.Background(), tokenStr) + require.Error(t, err) +}