refactor to improve database operability

This commit is contained in:
2026-01-11 22:21:44 +11:00
parent 1b25e2f0a5
commit ae4094d426
13 changed files with 136 additions and 57 deletions

1
jwt/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
.claude/

21
jwt/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 haelnorr
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -28,6 +28,7 @@ go get git.haelnorr.com/h/golib/jwt
package main
import (
"context"
"database/sql"
"git.haelnorr.com/h/golib/jwt"
_ "github.com/lib/pq"
@@ -38,8 +39,10 @@ func main() {
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
defer db.Close()
// Wrap database connection
dbConn := jwt.NewDBConnection(db)
// Create a transaction getter function
txGetter := func(ctx context.Context) (jwt.DBTransaction, error) {
return db.Begin()
}
// Create token generator
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
@@ -48,13 +51,13 @@ func main() {
FreshExpireAfter: 5, // 5 minutes
TrustedHost: "example.com",
SecretKey: "your-secret-key",
DBConn: dbConn,
DB: db,
DBType: jwt.DatabaseType{
Type: jwt.DatabasePostgreSQL,
Version: "15",
},
TableConfig: jwt.DefaultTableConfig(),
})
}, txGetter)
if err != nil {
panic(err)
}
@@ -64,7 +67,7 @@ func main() {
refreshToken, _, _ := gen.NewRefresh(42, false)
// Validate token
tx, _ := dbConn.BeginTx(context.Background(), nil)
tx, _ := db.Begin()
token, _ := gen.ValidateAccess(tx, accessToken)
// Revoke token

View File

@@ -1,5 +1,23 @@
package jwt
import (
"context"
"database/sql"
)
// DBTransaction represents a database transaction that can execute queries.
// This interface is compatible with *sql.Tx and can be implemented by ORM transactions
// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc.
type DBTransaction interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Commit() error
Rollback() error
}
// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected
type BeginTX func(ctx context.Context) (DBTransaction, error)
// DatabaseType specifies the database system and version being used.
type DatabaseType struct {
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"

View File

@@ -39,7 +39,7 @@
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
//
// Validate tokens:
// Validate tokens (using standard library):
//
// tx, _ := db.Begin()
// token, err := gen.ValidateAccess(tx, accessToken)
@@ -48,6 +48,13 @@
// }
// tx.Commit()
//
// Validate tokens (using ORM like GORM):
//
// tx := gormDB.Begin()
// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken)
// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken)
// tx.Commit()
//
// Revoke tokens:
//
// tx, _ := db.Begin()
@@ -84,21 +91,29 @@
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
//
// // GORM example
// // GORM example - can use GORM transactions directly
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
// sqlDB, _ := gormDB.DB()
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ...
// DB: sqlDB,
// })
// // Use GORM transaction
// tx := gormDB.Begin()
// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString)
// tx.Commit()
//
// // Bun example
// // Bun example - can use Bun transactions directly
// sqlDB, _ := sql.Open("postgres", dsn)
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ...
// DB: sqlDB,
// })
// // Use Bun transaction
// tx, _ := bunDB.BeginTx(context.Background(), nil)
// token, _ := gen.ValidateAccess(tx, tokenString)
// tx.Commit()
//
// # Token Freshness
//

View File

@@ -15,7 +15,7 @@ type TokenGenerator struct {
freshExpireAfter int64 // Token freshness expiry time in minutes
trustedHost string // Trusted hostname to use for the tokens
secretKey string // Secret key to use for token hashing
db *sql.DB // Database connection for token blacklisting
beginTx BeginTX // Database transaction getter for token blacklisting
tableConfig TableConfig // Table configuration
tableManager *TableManager // Table lifecycle manager
}
@@ -51,7 +51,7 @@ type GeneratorConfig struct {
}
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) {
if config.AccessExpireAfter <= 0 {
return nil, errors.New("accessExpireAfter must be greater than 0")
}
@@ -102,7 +102,7 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
freshExpireAfter: config.FreshExpireAfter,
trustedHost: config.TrustedHost,
secretKey: config.SecretKey,
db: config.DB,
beginTx: txGetter,
tableConfig: config.TableConfig,
tableManager: tableManager,
}, nil
@@ -112,16 +112,21 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
// This method should be called periodically if automatic cleanup is not enabled,
// or can be called on-demand regardless of automatic cleanup settings.
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
if gen.db == nil {
if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function")
}
tx, err := gen.beginTx(ctx)
if err != nil {
return pkgerrors.Wrap(err, "failed to begin transaction")
}
tableName := gen.tableConfig.TableName
currentTime := time.Now().Unix()
query := "DELETE FROM " + tableName + " WHERE exp < ?"
_, err := gen.db.ExecContext(ctx, query, currentTime)
_, err = tx.Exec(query, currentTime)
if err != nil {
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
}

View File

@@ -18,7 +18,7 @@ func TestCreateGenerator_Success_NoDB(t *testing.T) {
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
})
}, nil)
require.NoError(t, err)
require.NotNil(t, gen)
@@ -33,6 +33,10 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
@@ -42,7 +46,7 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
})
}, txGetter)
require.NoError(t, err)
require.NotNil(t, gen)
@@ -67,6 +71,10 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
WillReturnResult(sqlmock.NewResult(0, 0))
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
@@ -76,7 +84,7 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
})
}, txGetter)
require.NoError(t, err)
require.NotNil(t, gen)
@@ -142,7 +150,7 @@ func TestCreateGenerator_InvalidInputs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := CreateGenerator(tt.config)
_, err := CreateGenerator(tt.config, nil)
require.Error(t, err)
})
}
@@ -158,7 +166,7 @@ func TestCleanup_NoDB(t *testing.T) {
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
})
}, nil)
require.NoError(t, err)
err = gen.Cleanup(context.Background())
@@ -175,6 +183,10 @@ func TestCleanup_Success(t *testing.T) {
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
@@ -184,10 +196,11 @@ func TestCleanup_Success(t *testing.T) {
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
})
}, txGetter)
require.NoError(t, err)
// Mock DELETE query
// Mock transaction begin and DELETE query
mock.ExpectBegin()
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
WillReturnResult(sqlmock.NewResult(0, 5))

View File

@@ -1,8 +1,6 @@
package jwt
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
@@ -11,8 +9,8 @@ import (
// revoke is an internal method that adds a token to the blacklist database.
// Once revoked, the token will fail validation checks even if it hasn't expired.
// This operation must be performed within a database transaction.
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
if gen.db == nil {
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function")
}
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
sub := t.GetSUB()
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
_, err := tx.Exec(query, jti.String(), exp, sub)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
// checkNotRevoked is an internal method that queries the blacklist to verify
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
// false if it has been revoked. This operation must be performed within a database transaction.
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
if gen.db == nil {
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
if gen.beginTx == nil {
return false, errors.New("No DB provided, unable to use this function")
}
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
jti := t.GetJTI()
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
rows, err := tx.QueryContext(context.Background(), query, jti.String())
rows, err := tx.Query(query, jti.String())
if err != nil {
return false, errors.Wrap(err, "tx.QueryContext")
}

View File

@@ -1,6 +1,7 @@
package jwt
import (
"context"
"database/sql"
"testing"
"time"
@@ -20,13 +21,13 @@ func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
})
}, nil)
require.NoError(t, err)
return gen
}
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
@@ -34,6 +35,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
@@ -43,10 +48,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
})
}, txGetter)
require.NoError(t, err)
return gen, mock, func() { db.Close() }
return gen, db, mock, func() { db.Close() }
}
func TestNoDBFail(t *testing.T) {
@@ -73,7 +78,7 @@ func TestNoDBFail(t *testing.T) {
}
func TestRevokeAndCheckNotRevoked(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
jti := uuid.New()
@@ -97,7 +102,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit()
tx, err := gen.db.Begin()
tx, err := db.Begin()
defer tx.Rollback()
require.NoError(t, err)

View File

@@ -16,7 +16,7 @@ func newTestGenerator(t *testing.T) *TokenGenerator {
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
})
}, nil)
require.NoError(t, err)
return gen
}

View File

@@ -1,8 +1,6 @@
package jwt
import (
"database/sql"
"github.com/google/uuid"
)
@@ -23,12 +21,14 @@ type Token interface {
// Revoke adds this token to the blacklist, preventing future use.
// Must be called within a database transaction context.
Revoke(*sql.Tx) error
// Accepts any transaction type that implements DBTransaction interface.
Revoke(DBTransaction) error
// CheckNotRevoked verifies that this token has not been blacklisted.
// Returns true if the token is valid, false if revoked.
// Must be called within a database transaction context.
CheckNotRevoked(*sql.Tx) (bool, error)
// Accepts any transaction type that implements DBTransaction interface.
CheckNotRevoked(DBTransaction) (bool, error)
}
// AccessToken represents a JWT access token with all its claims.
@@ -84,15 +84,15 @@ func (a AccessToken) GetScope() string {
func (r RefreshToken) GetScope() string {
return r.Scope
}
func (a AccessToken) Revoke(tx *sql.Tx) error {
func (a AccessToken) Revoke(tx DBTransaction) error {
return a.gen.revoke(tx, a)
}
func (r RefreshToken) Revoke(tx *sql.Tx) error {
func (r RefreshToken) Revoke(tx DBTransaction) error {
return r.gen.revoke(tx, r)
}
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return a.gen.checkNotRevoked(tx, a)
}
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return r.gen.checkNotRevoked(tx, r)
}

View File

@@ -1,8 +1,6 @@
package jwt
import (
"database/sql"
"github.com/pkg/errors"
)
@@ -20,14 +18,15 @@ import (
// revocation check is skipped.
//
// Parameters:
// - tx: Database transaction for checking token revocation status
// - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate
//
// Returns:
// - *AccessToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateAccess(
tx *sql.Tx,
tx DBTransaction,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
@@ -86,10 +85,10 @@ func (gen *TokenGenerator) ValidateAccess(
}
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.db != nil {
if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.db != nil {
if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil
@@ -109,14 +108,15 @@ func (gen *TokenGenerator) ValidateAccess(
// revocation check is skipped.
//
// Parameters:
// - tx: Database transaction for checking token revocation status
// - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate
//
// Returns:
// - *RefreshToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateRefresh(
tx *sql.Tx,
tx DBTransaction,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
@@ -170,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
}
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.db != nil {
if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.db != nil {
if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil

View File

@@ -17,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
}
func TestValidateAccess_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewAccess(42, true, false)
@@ -26,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.db.Begin()
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
@@ -53,7 +53,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
}
func TestValidateRefresh_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewRefresh(42, false)
@@ -61,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.db.Begin()
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()