diff --git a/jwt/.gitignore b/jwt/.gitignore new file mode 100644 index 0000000..4c5f206 --- /dev/null +++ b/jwt/.gitignore @@ -0,0 +1 @@ +.claude/ diff --git a/jwt/LICENSE b/jwt/LICENSE new file mode 100644 index 0000000..fbf1733 --- /dev/null +++ b/jwt/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 haelnorr + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/jwt/README.md b/jwt/README.md index 7ccdda0..184140f 100644 --- a/jwt/README.md +++ b/jwt/README.md @@ -28,6 +28,7 @@ go get git.haelnorr.com/h/golib/jwt package main import ( + "context" "database/sql" "git.haelnorr.com/h/golib/jwt" _ "github.com/lib/pq" @@ -38,8 +39,10 @@ func main() { db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db") defer db.Close() - // Wrap database connection - dbConn := jwt.NewDBConnection(db) + // Create a transaction getter function + txGetter := func(ctx context.Context) (jwt.DBTransaction, error) { + return db.Begin() + } // Create token generator gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{ @@ -48,13 +51,13 @@ func main() { FreshExpireAfter: 5, // 5 minutes TrustedHost: "example.com", SecretKey: "your-secret-key", - DBConn: dbConn, + DB: db, DBType: jwt.DatabaseType{ Type: jwt.DatabasePostgreSQL, Version: "15", }, TableConfig: jwt.DefaultTableConfig(), - }) + }, txGetter) if err != nil { panic(err) } @@ -64,7 +67,7 @@ func main() { refreshToken, _, _ := gen.NewRefresh(42, false) // Validate token - tx, _ := dbConn.BeginTx(context.Background(), nil) + tx, _ := db.Begin() token, _ := gen.ValidateAccess(tx, accessToken) // Revoke token diff --git a/jwt/database.go b/jwt/database.go index 6ea387f..d921a10 100644 --- a/jwt/database.go +++ b/jwt/database.go @@ -1,5 +1,23 @@ package jwt +import ( + "context" + "database/sql" +) + +// DBTransaction represents a database transaction that can execute queries. +// This interface is compatible with *sql.Tx and can be implemented by ORM transactions +// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc. +type DBTransaction interface { + Exec(query string, args ...any) (sql.Result, error) + Query(query string, args ...any) (*sql.Rows, error) + Commit() error + Rollback() error +} + +// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected +type BeginTX func(ctx context.Context) (DBTransaction, error) + // DatabaseType specifies the database system and version being used. type DatabaseType struct { Type string // Database type: "postgres", "mysql", "sqlite", "mariadb" diff --git a/jwt/doc.go b/jwt/doc.go index 5e93dee..f4c0050 100644 --- a/jwt/doc.go +++ b/jwt/doc.go @@ -39,7 +39,7 @@ // accessToken, accessExp, err := gen.NewAccess(userID, true, false) // refreshToken, refreshExp, err := gen.NewRefresh(userID, false) // -// Validate tokens: +// Validate tokens (using standard library): // // tx, _ := db.Begin() // token, err := gen.ValidateAccess(tx, accessToken) @@ -48,6 +48,13 @@ // } // tx.Commit() // +// Validate tokens (using ORM like GORM): +// +// tx := gormDB.Begin() +// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken) +// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken) +// tx.Commit() +// // Revoke tokens: // // tx, _ := db.Begin() @@ -84,21 +91,29 @@ // The package works with popular ORMs by using raw SQL queries. For GORM and Bun, // wrap the underlying *sql.DB with NewDBConnection() when creating the generator: // -// // GORM example +// // GORM example - can use GORM transactions directly // gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) // sqlDB, _ := gormDB.DB() // gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{ // // ... config ... // DB: sqlDB, // }) +// // Use GORM transaction +// tx := gormDB.Begin() +// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString) +// tx.Commit() // -// // Bun example +// // Bun example - can use Bun transactions directly // sqlDB, _ := sql.Open("postgres", dsn) // bunDB := bun.NewDB(sqlDB, pgdialect.New()) // gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{ // // ... config ... // DB: sqlDB, // }) +// // Use Bun transaction +// tx, _ := bunDB.BeginTx(context.Background(), nil) +// token, _ := gen.ValidateAccess(tx, tokenString) +// tx.Commit() // // # Token Freshness // diff --git a/jwt/generator.go b/jwt/generator.go index a350c1e..3f382a0 100644 --- a/jwt/generator.go +++ b/jwt/generator.go @@ -15,7 +15,7 @@ type TokenGenerator struct { freshExpireAfter int64 // Token freshness expiry time in minutes trustedHost string // Trusted hostname to use for the tokens secretKey string // Secret key to use for token hashing - db *sql.DB // Database connection for token blacklisting + beginTx BeginTX // Database transaction getter for token blacklisting tableConfig TableConfig // Table configuration tableManager *TableManager // Table lifecycle manager } @@ -51,7 +51,7 @@ type GeneratorConfig struct { } // CreateGenerator creates and returns a new TokenGenerator using the provided configuration. -func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) { +func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) { if config.AccessExpireAfter <= 0 { return nil, errors.New("accessExpireAfter must be greater than 0") } @@ -102,7 +102,7 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) { freshExpireAfter: config.FreshExpireAfter, trustedHost: config.TrustedHost, secretKey: config.SecretKey, - db: config.DB, + beginTx: txGetter, tableConfig: config.TableConfig, tableManager: tableManager, }, nil @@ -112,16 +112,21 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) { // This method should be called periodically if automatic cleanup is not enabled, // or can be called on-demand regardless of automatic cleanup settings. func (gen *TokenGenerator) Cleanup(ctx context.Context) error { - if gen.db == nil { + if gen.beginTx == nil { return errors.New("No DB provided, unable to use this function") } + tx, err := gen.beginTx(ctx) + if err != nil { + return pkgerrors.Wrap(err, "failed to begin transaction") + } + tableName := gen.tableConfig.TableName currentTime := time.Now().Unix() query := "DELETE FROM " + tableName + " WHERE exp < ?" - _, err := gen.db.ExecContext(ctx, query, currentTime) + _, err = tx.Exec(query, currentTime) if err != nil { return pkgerrors.Wrap(err, "failed to cleanup expired tokens") } diff --git a/jwt/generator_test.go b/jwt/generator_test.go index d963414..e209e5b 100644 --- a/jwt/generator_test.go +++ b/jwt/generator_test.go @@ -18,7 +18,7 @@ func TestCreateGenerator_Success_NoDB(t *testing.T) { DB: nil, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: DefaultTableConfig(), - }) + }, nil) require.NoError(t, err) require.NotNil(t, gen) @@ -33,6 +33,10 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) { config.AutoCreate = false config.EnableAutoCleanup = false + txGetter := func(ctx context.Context) (DBTransaction, error) { + return db.Begin() + } + gen, err := CreateGenerator(GeneratorConfig{ AccessExpireAfter: 15, RefreshExpireAfter: 60, @@ -42,7 +46,7 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) { DB: db, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: config, - }) + }, txGetter) require.NoError(t, err) require.NotNil(t, gen) @@ -67,6 +71,10 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) { mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist"). WillReturnResult(sqlmock.NewResult(0, 0)) + txGetter := func(ctx context.Context) (DBTransaction, error) { + return db.Begin() + } + gen, err := CreateGenerator(GeneratorConfig{ AccessExpireAfter: 15, RefreshExpireAfter: 60, @@ -76,7 +84,7 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) { DB: db, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: DefaultTableConfig(), - }) + }, txGetter) require.NoError(t, err) require.NotNil(t, gen) @@ -142,7 +150,7 @@ func TestCreateGenerator_InvalidInputs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := CreateGenerator(tt.config) + _, err := CreateGenerator(tt.config, nil) require.Error(t, err) }) } @@ -158,7 +166,7 @@ func TestCleanup_NoDB(t *testing.T) { DB: nil, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: DefaultTableConfig(), - }) + }, nil) require.NoError(t, err) err = gen.Cleanup(context.Background()) @@ -175,6 +183,10 @@ func TestCleanup_Success(t *testing.T) { config.AutoCreate = false config.EnableAutoCleanup = false + txGetter := func(ctx context.Context) (DBTransaction, error) { + return db.Begin() + } + gen, err := CreateGenerator(GeneratorConfig{ AccessExpireAfter: 15, RefreshExpireAfter: 60, @@ -184,10 +196,11 @@ func TestCleanup_Success(t *testing.T) { DB: db, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: config, - }) + }, txGetter) require.NoError(t, err) - // Mock DELETE query + // Mock transaction begin and DELETE query + mock.ExpectBegin() mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp"). WillReturnResult(sqlmock.NewResult(0, 5)) diff --git a/jwt/revoke.go b/jwt/revoke.go index 1c9534d..d44a7f4 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -1,8 +1,6 @@ package jwt import ( - "context" - "database/sql" "fmt" "github.com/pkg/errors" @@ -11,8 +9,8 @@ import ( // revoke is an internal method that adds a token to the blacklist database. // Once revoked, the token will fail validation checks even if it hasn't expired. // This operation must be performed within a database transaction. -func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error { - if gen.db == nil { +func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error { + if gen.beginTx == nil { return errors.New("No DB provided, unable to use this function") } @@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error { sub := t.GetSUB() query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName) - _, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub) + _, err := tx.Exec(query, jti.String(), exp, sub) if err != nil { return errors.Wrap(err, "tx.ExecContext") } @@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error { // checkNotRevoked is an internal method that queries the blacklist to verify // a token hasn't been revoked. Returns true if the token is valid (not blacklisted), // false if it has been revoked. This operation must be performed within a database transaction. -func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) { - if gen.db == nil { +func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) { + if gen.beginTx == nil { return false, errors.New("No DB provided, unable to use this function") } @@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) { jti := t.GetJTI() query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName) - rows, err := tx.QueryContext(context.Background(), query, jti.String()) + rows, err := tx.Query(query, jti.String()) if err != nil { return false, errors.Wrap(err, "tx.QueryContext") } diff --git a/jwt/revoke_test.go b/jwt/revoke_test.go index 1a32c25..273af5b 100644 --- a/jwt/revoke_test.go +++ b/jwt/revoke_test.go @@ -1,6 +1,7 @@ package jwt import ( + "context" "database/sql" "testing" "time" @@ -20,13 +21,13 @@ func newGeneratorWithNoDB(t *testing.T) *TokenGenerator { DB: nil, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: DefaultTableConfig(), - }) + }, nil) require.NoError(t, err) return gen } -func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) { +func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) { db, mock, err := sqlmock.New() require.NoError(t, err) @@ -34,6 +35,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun config.AutoCreate = false config.EnableAutoCleanup = false + txGetter := func(ctx context.Context) (DBTransaction, error) { + return db.Begin() + } + gen, err := CreateGenerator(GeneratorConfig{ AccessExpireAfter: 15, RefreshExpireAfter: 60, @@ -43,10 +48,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun DB: db, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: config, - }) + }, txGetter) require.NoError(t, err) - return gen, mock, func() { db.Close() } + return gen, db, mock, func() { db.Close() } } func TestNoDBFail(t *testing.T) { @@ -73,7 +78,7 @@ func TestNoDBFail(t *testing.T) { } func TestRevokeAndCheckNotRevoked(t *testing.T) { - gen, mock, cleanup := newGeneratorWithMockDB(t) + gen, db, mock, cleanup := newGeneratorWithMockDB(t) defer cleanup() jti := uuid.New() @@ -97,7 +102,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) mock.ExpectCommit() - tx, err := gen.db.Begin() + tx, err := db.Begin() defer tx.Rollback() require.NoError(t, err) diff --git a/jwt/tokengen_test.go b/jwt/tokengen_test.go index 3b3aec8..209a32b 100644 --- a/jwt/tokengen_test.go +++ b/jwt/tokengen_test.go @@ -16,7 +16,7 @@ func newTestGenerator(t *testing.T) *TokenGenerator { DB: nil, DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"}, TableConfig: DefaultTableConfig(), - }) + }, nil) require.NoError(t, err) return gen } diff --git a/jwt/tokens.go b/jwt/tokens.go index 69898ce..d99199f 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,8 +1,6 @@ package jwt import ( - "database/sql" - "github.com/google/uuid" ) @@ -23,12 +21,14 @@ type Token interface { // Revoke adds this token to the blacklist, preventing future use. // Must be called within a database transaction context. - Revoke(*sql.Tx) error + // Accepts any transaction type that implements DBTransaction interface. + Revoke(DBTransaction) error // CheckNotRevoked verifies that this token has not been blacklisted. // Returns true if the token is valid, false if revoked. // Must be called within a database transaction context. - CheckNotRevoked(*sql.Tx) (bool, error) + // Accepts any transaction type that implements DBTransaction interface. + CheckNotRevoked(DBTransaction) (bool, error) } // AccessToken represents a JWT access token with all its claims. @@ -84,15 +84,15 @@ func (a AccessToken) GetScope() string { func (r RefreshToken) GetScope() string { return r.Scope } -func (a AccessToken) Revoke(tx *sql.Tx) error { +func (a AccessToken) Revoke(tx DBTransaction) error { return a.gen.revoke(tx, a) } -func (r RefreshToken) Revoke(tx *sql.Tx) error { +func (r RefreshToken) Revoke(tx DBTransaction) error { return r.gen.revoke(tx, r) } -func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { +func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) { return a.gen.checkNotRevoked(tx, a) } -func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { +func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) { return r.gen.checkNotRevoked(tx, r) } diff --git a/jwt/validate.go b/jwt/validate.go index 56b18b0..178cec7 100644 --- a/jwt/validate.go +++ b/jwt/validate.go @@ -1,8 +1,6 @@ package jwt import ( - "database/sql" - "github.com/pkg/errors" ) @@ -20,14 +18,15 @@ import ( // revocation check is skipped. // // Parameters: -// - tx: Database transaction for checking token revocation status +// - tx: Database transaction for checking token revocation status. +// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface. // - tokenString: The JWT token string to validate // // Returns: // - *AccessToken: The validated token with all claims, or nil if validation fails // - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.) func (gen *TokenGenerator) ValidateAccess( - tx *sql.Tx, + tx DBTransaction, tokenString string, ) (*AccessToken, error) { if tokenString == "" { @@ -86,10 +85,10 @@ func (gen *TokenGenerator) ValidateAccess( } valid, err := token.CheckNotRevoked(tx) - if err != nil && gen.db != nil { + if err != nil && gen.beginTx != nil { return nil, errors.Wrap(err, "token.CheckNotRevoked") } - if !valid && gen.db != nil { + if !valid && gen.beginTx != nil { return nil, errors.New("Token has been revoked") } return token, nil @@ -109,14 +108,15 @@ func (gen *TokenGenerator) ValidateAccess( // revocation check is skipped. // // Parameters: -// - tx: Database transaction for checking token revocation status +// - tx: Database transaction for checking token revocation status. +// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface. // - tokenString: The JWT token string to validate // // Returns: // - *RefreshToken: The validated token with all claims, or nil if validation fails // - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.) func (gen *TokenGenerator) ValidateRefresh( - tx *sql.Tx, + tx DBTransaction, tokenString string, ) (*RefreshToken, error) { if tokenString == "" { @@ -170,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh( } valid, err := token.CheckNotRevoked(tx) - if err != nil && gen.db != nil { + if err != nil && gen.beginTx != nil { return nil, errors.Wrap(err, "token.CheckNotRevoked") } - if !valid && gen.db != nil { + if !valid && gen.beginTx != nil { return nil, errors.New("Token has been revoked") } return token, nil diff --git a/jwt/validate_test.go b/jwt/validate_test.go index 5eab343..338e78b 100644 --- a/jwt/validate_test.go +++ b/jwt/validate_test.go @@ -17,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) { } func TestValidateAccess_Success(t *testing.T) { - gen, mock, cleanup := newGeneratorWithMockDB(t) + gen, db, mock, cleanup := newGeneratorWithMockDB(t) defer cleanup() tokenStr, _, err := gen.NewAccess(42, true, false) @@ -26,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) { // We don't know the JTI beforehand; match any arg expectNotRevoked(mock, sqlmock.AnyArg()) - tx, err := gen.db.Begin() + tx, err := db.Begin() require.NoError(t, err) defer tx.Rollback() @@ -53,7 +53,7 @@ func TestValidateAccess_NoDB(t *testing.T) { } func TestValidateRefresh_Success(t *testing.T) { - gen, mock, cleanup := newGeneratorWithMockDB(t) + gen, db, mock, cleanup := newGeneratorWithMockDB(t) defer cleanup() tokenStr, _, err := gen.NewRefresh(42, false) @@ -61,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) { expectNotRevoked(mock, sqlmock.AnyArg()) - tx, err := gen.db.Begin() + tx, err := db.Begin() require.NoError(t, err) defer tx.Rollback()