refactor to improve database operability

This commit is contained in:
2026-01-11 22:21:44 +11:00
parent 1b25e2f0a5
commit ae4094d426
13 changed files with 136 additions and 57 deletions

1
jwt/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
.claude/

21
jwt/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 haelnorr
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -28,6 +28,7 @@ go get git.haelnorr.com/h/golib/jwt
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"git.haelnorr.com/h/golib/jwt" "git.haelnorr.com/h/golib/jwt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@@ -38,8 +39,10 @@ func main() {
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db") db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
defer db.Close() defer db.Close()
// Wrap database connection // Create a transaction getter function
dbConn := jwt.NewDBConnection(db) txGetter := func(ctx context.Context) (jwt.DBTransaction, error) {
return db.Begin()
}
// Create token generator // Create token generator
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{ gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
@@ -48,13 +51,13 @@ func main() {
FreshExpireAfter: 5, // 5 minutes FreshExpireAfter: 5, // 5 minutes
TrustedHost: "example.com", TrustedHost: "example.com",
SecretKey: "your-secret-key", SecretKey: "your-secret-key",
DBConn: dbConn, DB: db,
DBType: jwt.DatabaseType{ DBType: jwt.DatabaseType{
Type: jwt.DatabasePostgreSQL, Type: jwt.DatabasePostgreSQL,
Version: "15", Version: "15",
}, },
TableConfig: jwt.DefaultTableConfig(), TableConfig: jwt.DefaultTableConfig(),
}) }, txGetter)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -64,7 +67,7 @@ func main() {
refreshToken, _, _ := gen.NewRefresh(42, false) refreshToken, _, _ := gen.NewRefresh(42, false)
// Validate token // Validate token
tx, _ := dbConn.BeginTx(context.Background(), nil) tx, _ := db.Begin()
token, _ := gen.ValidateAccess(tx, accessToken) token, _ := gen.ValidateAccess(tx, accessToken)
// Revoke token // Revoke token

View File

@@ -1,5 +1,23 @@
package jwt package jwt
import (
"context"
"database/sql"
)
// DBTransaction represents a database transaction that can execute queries.
// This interface is compatible with *sql.Tx and can be implemented by ORM transactions
// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc.
type DBTransaction interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Commit() error
Rollback() error
}
// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected
type BeginTX func(ctx context.Context) (DBTransaction, error)
// DatabaseType specifies the database system and version being used. // DatabaseType specifies the database system and version being used.
type DatabaseType struct { type DatabaseType struct {
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb" Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"

View File

@@ -39,7 +39,7 @@
// accessToken, accessExp, err := gen.NewAccess(userID, true, false) // accessToken, accessExp, err := gen.NewAccess(userID, true, false)
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false) // refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
// //
// Validate tokens: // Validate tokens (using standard library):
// //
// tx, _ := db.Begin() // tx, _ := db.Begin()
// token, err := gen.ValidateAccess(tx, accessToken) // token, err := gen.ValidateAccess(tx, accessToken)
@@ -48,6 +48,13 @@
// } // }
// tx.Commit() // tx.Commit()
// //
// Validate tokens (using ORM like GORM):
//
// tx := gormDB.Begin()
// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken)
// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken)
// tx.Commit()
//
// Revoke tokens: // Revoke tokens:
// //
// tx, _ := db.Begin() // tx, _ := db.Begin()
@@ -84,21 +91,29 @@
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun, // The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator: // wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
// //
// // GORM example // // GORM example - can use GORM transactions directly
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) // gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
// sqlDB, _ := gormDB.DB() // sqlDB, _ := gormDB.DB()
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{ // gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ... // // ... config ...
// DB: sqlDB, // DB: sqlDB,
// }) // })
// // Use GORM transaction
// tx := gormDB.Begin()
// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString)
// tx.Commit()
// //
// // Bun example // // Bun example - can use Bun transactions directly
// sqlDB, _ := sql.Open("postgres", dsn) // sqlDB, _ := sql.Open("postgres", dsn)
// bunDB := bun.NewDB(sqlDB, pgdialect.New()) // bunDB := bun.NewDB(sqlDB, pgdialect.New())
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{ // gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ... // // ... config ...
// DB: sqlDB, // DB: sqlDB,
// }) // })
// // Use Bun transaction
// tx, _ := bunDB.BeginTx(context.Background(), nil)
// token, _ := gen.ValidateAccess(tx, tokenString)
// tx.Commit()
// //
// # Token Freshness // # Token Freshness
// //

View File

@@ -15,7 +15,7 @@ type TokenGenerator struct {
freshExpireAfter int64 // Token freshness expiry time in minutes freshExpireAfter int64 // Token freshness expiry time in minutes
trustedHost string // Trusted hostname to use for the tokens trustedHost string // Trusted hostname to use for the tokens
secretKey string // Secret key to use for token hashing secretKey string // Secret key to use for token hashing
db *sql.DB // Database connection for token blacklisting beginTx BeginTX // Database transaction getter for token blacklisting
tableConfig TableConfig // Table configuration tableConfig TableConfig // Table configuration
tableManager *TableManager // Table lifecycle manager tableManager *TableManager // Table lifecycle manager
} }
@@ -51,7 +51,7 @@ type GeneratorConfig struct {
} }
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration. // CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) { func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) {
if config.AccessExpireAfter <= 0 { if config.AccessExpireAfter <= 0 {
return nil, errors.New("accessExpireAfter must be greater than 0") return nil, errors.New("accessExpireAfter must be greater than 0")
} }
@@ -102,7 +102,7 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
freshExpireAfter: config.FreshExpireAfter, freshExpireAfter: config.FreshExpireAfter,
trustedHost: config.TrustedHost, trustedHost: config.TrustedHost,
secretKey: config.SecretKey, secretKey: config.SecretKey,
db: config.DB, beginTx: txGetter,
tableConfig: config.TableConfig, tableConfig: config.TableConfig,
tableManager: tableManager, tableManager: tableManager,
}, nil }, nil
@@ -112,16 +112,21 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
// This method should be called periodically if automatic cleanup is not enabled, // This method should be called periodically if automatic cleanup is not enabled,
// or can be called on-demand regardless of automatic cleanup settings. // or can be called on-demand regardless of automatic cleanup settings.
func (gen *TokenGenerator) Cleanup(ctx context.Context) error { func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
if gen.db == nil { if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function") return errors.New("No DB provided, unable to use this function")
} }
tx, err := gen.beginTx(ctx)
if err != nil {
return pkgerrors.Wrap(err, "failed to begin transaction")
}
tableName := gen.tableConfig.TableName tableName := gen.tableConfig.TableName
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
query := "DELETE FROM " + tableName + " WHERE exp < ?" query := "DELETE FROM " + tableName + " WHERE exp < ?"
_, err := gen.db.ExecContext(ctx, query, currentTime) _, err = tx.Exec(query, currentTime)
if err != nil { if err != nil {
return pkgerrors.Wrap(err, "failed to cleanup expired tokens") return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
} }

View File

@@ -18,7 +18,7 @@ func TestCreateGenerator_Success_NoDB(t *testing.T) {
DB: nil, DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(), TableConfig: DefaultTableConfig(),
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, gen) require.NotNil(t, gen)
@@ -33,6 +33,10 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
config.AutoCreate = false config.AutoCreate = false
config.EnableAutoCleanup = false config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{ gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15, AccessExpireAfter: 15,
RefreshExpireAfter: 60, RefreshExpireAfter: 60,
@@ -42,7 +46,7 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
DB: db, DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config, TableConfig: config,
}) }, txGetter)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, gen) require.NotNil(t, gen)
@@ -67,6 +71,10 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist"). mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
WillReturnResult(sqlmock.NewResult(0, 0)) WillReturnResult(sqlmock.NewResult(0, 0))
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{ gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15, AccessExpireAfter: 15,
RefreshExpireAfter: 60, RefreshExpireAfter: 60,
@@ -76,7 +84,7 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
DB: db, DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(), TableConfig: DefaultTableConfig(),
}) }, txGetter)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, gen) require.NotNil(t, gen)
@@ -142,7 +150,7 @@ func TestCreateGenerator_InvalidInputs(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := CreateGenerator(tt.config) _, err := CreateGenerator(tt.config, nil)
require.Error(t, err) require.Error(t, err)
}) })
} }
@@ -158,7 +166,7 @@ func TestCleanup_NoDB(t *testing.T) {
DB: nil, DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(), TableConfig: DefaultTableConfig(),
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
err = gen.Cleanup(context.Background()) err = gen.Cleanup(context.Background())
@@ -175,6 +183,10 @@ func TestCleanup_Success(t *testing.T) {
config.AutoCreate = false config.AutoCreate = false
config.EnableAutoCleanup = false config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{ gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15, AccessExpireAfter: 15,
RefreshExpireAfter: 60, RefreshExpireAfter: 60,
@@ -184,10 +196,11 @@ func TestCleanup_Success(t *testing.T) {
DB: db, DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config, TableConfig: config,
}) }, txGetter)
require.NoError(t, err) require.NoError(t, err)
// Mock DELETE query // Mock transaction begin and DELETE query
mock.ExpectBegin()
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp"). mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
WillReturnResult(sqlmock.NewResult(0, 5)) WillReturnResult(sqlmock.NewResult(0, 5))

View File

@@ -1,8 +1,6 @@
package jwt package jwt
import ( import (
"context"
"database/sql"
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -11,8 +9,8 @@ import (
// revoke is an internal method that adds a token to the blacklist database. // revoke is an internal method that adds a token to the blacklist database.
// Once revoked, the token will fail validation checks even if it hasn't expired. // Once revoked, the token will fail validation checks even if it hasn't expired.
// This operation must be performed within a database transaction. // This operation must be performed within a database transaction.
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error { func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
if gen.db == nil { if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function") return errors.New("No DB provided, unable to use this function")
} }
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
sub := t.GetSUB() sub := t.GetSUB()
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName) query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub) _, err := tx.Exec(query, jti.String(), exp, sub)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.ExecContext") return errors.Wrap(err, "tx.ExecContext")
} }
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
// checkNotRevoked is an internal method that queries the blacklist to verify // checkNotRevoked is an internal method that queries the blacklist to verify
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted), // a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
// false if it has been revoked. This operation must be performed within a database transaction. // false if it has been revoked. This operation must be performed within a database transaction.
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) { func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
if gen.db == nil { if gen.beginTx == nil {
return false, errors.New("No DB provided, unable to use this function") return false, errors.New("No DB provided, unable to use this function")
} }
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
jti := t.GetJTI() jti := t.GetJTI()
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName) query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
rows, err := tx.QueryContext(context.Background(), query, jti.String()) rows, err := tx.Query(query, jti.String())
if err != nil { if err != nil {
return false, errors.Wrap(err, "tx.QueryContext") return false, errors.Wrap(err, "tx.QueryContext")
} }

View File

@@ -1,6 +1,7 @@
package jwt package jwt
import ( import (
"context"
"database/sql" "database/sql"
"testing" "testing"
"time" "time"
@@ -20,13 +21,13 @@ func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
DB: nil, DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(), TableConfig: DefaultTableConfig(),
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
return gen return gen
} }
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) { func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
require.NoError(t, err) require.NoError(t, err)
@@ -34,6 +35,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
config.AutoCreate = false config.AutoCreate = false
config.EnableAutoCleanup = false config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{ gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15, AccessExpireAfter: 15,
RefreshExpireAfter: 60, RefreshExpireAfter: 60,
@@ -43,10 +48,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
DB: db, DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config, TableConfig: config,
}) }, txGetter)
require.NoError(t, err) require.NoError(t, err)
return gen, mock, func() { db.Close() } return gen, db, mock, func() { db.Close() }
} }
func TestNoDBFail(t *testing.T) { func TestNoDBFail(t *testing.T) {
@@ -73,7 +78,7 @@ func TestNoDBFail(t *testing.T) {
} }
func TestRevokeAndCheckNotRevoked(t *testing.T) { func TestRevokeAndCheckNotRevoked(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t) gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup() defer cleanup()
jti := uuid.New() jti := uuid.New()
@@ -97,7 +102,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit() mock.ExpectCommit()
tx, err := gen.db.Begin() tx, err := db.Begin()
defer tx.Rollback() defer tx.Rollback()
require.NoError(t, err) require.NoError(t, err)

View File

@@ -16,7 +16,7 @@ func newTestGenerator(t *testing.T) *TokenGenerator {
DB: nil, DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(), TableConfig: DefaultTableConfig(),
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
return gen return gen
} }

View File

@@ -1,8 +1,6 @@
package jwt package jwt
import ( import (
"database/sql"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -23,12 +21,14 @@ type Token interface {
// Revoke adds this token to the blacklist, preventing future use. // Revoke adds this token to the blacklist, preventing future use.
// Must be called within a database transaction context. // Must be called within a database transaction context.
Revoke(*sql.Tx) error // Accepts any transaction type that implements DBTransaction interface.
Revoke(DBTransaction) error
// CheckNotRevoked verifies that this token has not been blacklisted. // CheckNotRevoked verifies that this token has not been blacklisted.
// Returns true if the token is valid, false if revoked. // Returns true if the token is valid, false if revoked.
// Must be called within a database transaction context. // Must be called within a database transaction context.
CheckNotRevoked(*sql.Tx) (bool, error) // Accepts any transaction type that implements DBTransaction interface.
CheckNotRevoked(DBTransaction) (bool, error)
} }
// AccessToken represents a JWT access token with all its claims. // AccessToken represents a JWT access token with all its claims.
@@ -84,15 +84,15 @@ func (a AccessToken) GetScope() string {
func (r RefreshToken) GetScope() string { func (r RefreshToken) GetScope() string {
return r.Scope return r.Scope
} }
func (a AccessToken) Revoke(tx *sql.Tx) error { func (a AccessToken) Revoke(tx DBTransaction) error {
return a.gen.revoke(tx, a) return a.gen.revoke(tx, a)
} }
func (r RefreshToken) Revoke(tx *sql.Tx) error { func (r RefreshToken) Revoke(tx DBTransaction) error {
return r.gen.revoke(tx, r) return r.gen.revoke(tx, r)
} }
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return a.gen.checkNotRevoked(tx, a) return a.gen.checkNotRevoked(tx, a)
} }
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return r.gen.checkNotRevoked(tx, r) return r.gen.checkNotRevoked(tx, r)
} }

View File

@@ -1,8 +1,6 @@
package jwt package jwt
import ( import (
"database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -20,14 +18,15 @@ import (
// revocation check is skipped. // revocation check is skipped.
// //
// Parameters: // Parameters:
// - tx: Database transaction for checking token revocation status // - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate // - tokenString: The JWT token string to validate
// //
// Returns: // Returns:
// - *AccessToken: The validated token with all claims, or nil if validation fails // - *AccessToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.) // - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateAccess( func (gen *TokenGenerator) ValidateAccess(
tx *sql.Tx, tx DBTransaction,
tokenString string, tokenString string,
) (*AccessToken, error) { ) (*AccessToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -86,10 +85,10 @@ func (gen *TokenGenerator) ValidateAccess(
} }
valid, err := token.CheckNotRevoked(tx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.db != nil { if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }
if !valid && gen.db != nil { if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked") return nil, errors.New("Token has been revoked")
} }
return token, nil return token, nil
@@ -109,14 +108,15 @@ func (gen *TokenGenerator) ValidateAccess(
// revocation check is skipped. // revocation check is skipped.
// //
// Parameters: // Parameters:
// - tx: Database transaction for checking token revocation status // - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate // - tokenString: The JWT token string to validate
// //
// Returns: // Returns:
// - *RefreshToken: The validated token with all claims, or nil if validation fails // - *RefreshToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.) // - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateRefresh( func (gen *TokenGenerator) ValidateRefresh(
tx *sql.Tx, tx DBTransaction,
tokenString string, tokenString string,
) (*RefreshToken, error) { ) (*RefreshToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -170,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
} }
valid, err := token.CheckNotRevoked(tx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.db != nil { if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }
if !valid && gen.db != nil { if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked") return nil, errors.New("Token has been revoked")
} }
return token, nil return token, nil

View File

@@ -17,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
} }
func TestValidateAccess_Success(t *testing.T) { func TestValidateAccess_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t) gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup() defer cleanup()
tokenStr, _, err := gen.NewAccess(42, true, false) tokenStr, _, err := gen.NewAccess(42, true, false)
@@ -26,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg // We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback() defer tx.Rollback()
@@ -53,7 +53,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
} }
func TestValidateRefresh_Success(t *testing.T) { func TestValidateRefresh_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t) gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup() defer cleanup()
tokenStr, _, err := gen.NewRefresh(42, false) tokenStr, _, err := gen.NewRefresh(42, false)
@@ -61,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback() defer tx.Rollback()