refactor to improve database operability
This commit is contained in:
1
jwt/.gitignore
vendored
Normal file
1
jwt/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
.claude/
|
||||||
21
jwt/LICENSE
Normal file
21
jwt/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 haelnorr
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -28,6 +28,7 @@ go get git.haelnorr.com/h/golib/jwt
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"git.haelnorr.com/h/golib/jwt"
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
@@ -38,8 +39,10 @@ func main() {
|
|||||||
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Wrap database connection
|
// Create a transaction getter function
|
||||||
dbConn := jwt.NewDBConnection(db)
|
txGetter := func(ctx context.Context) (jwt.DBTransaction, error) {
|
||||||
|
return db.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
// Create token generator
|
// Create token generator
|
||||||
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||||
@@ -48,13 +51,13 @@ func main() {
|
|||||||
FreshExpireAfter: 5, // 5 minutes
|
FreshExpireAfter: 5, // 5 minutes
|
||||||
TrustedHost: "example.com",
|
TrustedHost: "example.com",
|
||||||
SecretKey: "your-secret-key",
|
SecretKey: "your-secret-key",
|
||||||
DBConn: dbConn,
|
DB: db,
|
||||||
DBType: jwt.DatabaseType{
|
DBType: jwt.DatabaseType{
|
||||||
Type: jwt.DatabasePostgreSQL,
|
Type: jwt.DatabasePostgreSQL,
|
||||||
Version: "15",
|
Version: "15",
|
||||||
},
|
},
|
||||||
TableConfig: jwt.DefaultTableConfig(),
|
TableConfig: jwt.DefaultTableConfig(),
|
||||||
})
|
}, txGetter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -64,7 +67,7 @@ func main() {
|
|||||||
refreshToken, _, _ := gen.NewRefresh(42, false)
|
refreshToken, _, _ := gen.NewRefresh(42, false)
|
||||||
|
|
||||||
// Validate token
|
// Validate token
|
||||||
tx, _ := dbConn.BeginTx(context.Background(), nil)
|
tx, _ := db.Begin()
|
||||||
token, _ := gen.ValidateAccess(tx, accessToken)
|
token, _ := gen.ValidateAccess(tx, accessToken)
|
||||||
|
|
||||||
// Revoke token
|
// Revoke token
|
||||||
|
|||||||
@@ -1,5 +1,23 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DBTransaction represents a database transaction that can execute queries.
|
||||||
|
// This interface is compatible with *sql.Tx and can be implemented by ORM transactions
|
||||||
|
// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc.
|
||||||
|
type DBTransaction interface {
|
||||||
|
Exec(query string, args ...any) (sql.Result, error)
|
||||||
|
Query(query string, args ...any) (*sql.Rows, error)
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected
|
||||||
|
type BeginTX func(ctx context.Context) (DBTransaction, error)
|
||||||
|
|
||||||
// DatabaseType specifies the database system and version being used.
|
// DatabaseType specifies the database system and version being used.
|
||||||
type DatabaseType struct {
|
type DatabaseType struct {
|
||||||
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"
|
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"
|
||||||
|
|||||||
21
jwt/doc.go
21
jwt/doc.go
@@ -39,7 +39,7 @@
|
|||||||
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
|
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
|
||||||
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
|
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
|
||||||
//
|
//
|
||||||
// Validate tokens:
|
// Validate tokens (using standard library):
|
||||||
//
|
//
|
||||||
// tx, _ := db.Begin()
|
// tx, _ := db.Begin()
|
||||||
// token, err := gen.ValidateAccess(tx, accessToken)
|
// token, err := gen.ValidateAccess(tx, accessToken)
|
||||||
@@ -48,6 +48,13 @@
|
|||||||
// }
|
// }
|
||||||
// tx.Commit()
|
// tx.Commit()
|
||||||
//
|
//
|
||||||
|
// Validate tokens (using ORM like GORM):
|
||||||
|
//
|
||||||
|
// tx := gormDB.Begin()
|
||||||
|
// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken)
|
||||||
|
// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken)
|
||||||
|
// tx.Commit()
|
||||||
|
//
|
||||||
// Revoke tokens:
|
// Revoke tokens:
|
||||||
//
|
//
|
||||||
// tx, _ := db.Begin()
|
// tx, _ := db.Begin()
|
||||||
@@ -84,21 +91,29 @@
|
|||||||
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
|
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
|
||||||
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
|
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
|
||||||
//
|
//
|
||||||
// // GORM example
|
// // GORM example - can use GORM transactions directly
|
||||||
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||||
// sqlDB, _ := gormDB.DB()
|
// sqlDB, _ := gormDB.DB()
|
||||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||||
// // ... config ...
|
// // ... config ...
|
||||||
// DB: sqlDB,
|
// DB: sqlDB,
|
||||||
// })
|
// })
|
||||||
|
// // Use GORM transaction
|
||||||
|
// tx := gormDB.Begin()
|
||||||
|
// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString)
|
||||||
|
// tx.Commit()
|
||||||
//
|
//
|
||||||
// // Bun example
|
// // Bun example - can use Bun transactions directly
|
||||||
// sqlDB, _ := sql.Open("postgres", dsn)
|
// sqlDB, _ := sql.Open("postgres", dsn)
|
||||||
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
|
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
|
||||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||||
// // ... config ...
|
// // ... config ...
|
||||||
// DB: sqlDB,
|
// DB: sqlDB,
|
||||||
// })
|
// })
|
||||||
|
// // Use Bun transaction
|
||||||
|
// tx, _ := bunDB.BeginTx(context.Background(), nil)
|
||||||
|
// token, _ := gen.ValidateAccess(tx, tokenString)
|
||||||
|
// tx.Commit()
|
||||||
//
|
//
|
||||||
// # Token Freshness
|
// # Token Freshness
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ type TokenGenerator struct {
|
|||||||
freshExpireAfter int64 // Token freshness expiry time in minutes
|
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||||
trustedHost string // Trusted hostname to use for the tokens
|
trustedHost string // Trusted hostname to use for the tokens
|
||||||
secretKey string // Secret key to use for token hashing
|
secretKey string // Secret key to use for token hashing
|
||||||
db *sql.DB // Database connection for token blacklisting
|
beginTx BeginTX // Database transaction getter for token blacklisting
|
||||||
tableConfig TableConfig // Table configuration
|
tableConfig TableConfig // Table configuration
|
||||||
tableManager *TableManager // Table lifecycle manager
|
tableManager *TableManager // Table lifecycle manager
|
||||||
}
|
}
|
||||||
@@ -51,7 +51,7 @@ type GeneratorConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||||
func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) {
|
||||||
if config.AccessExpireAfter <= 0 {
|
if config.AccessExpireAfter <= 0 {
|
||||||
return nil, errors.New("accessExpireAfter must be greater than 0")
|
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||||
}
|
}
|
||||||
@@ -102,7 +102,7 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
|||||||
freshExpireAfter: config.FreshExpireAfter,
|
freshExpireAfter: config.FreshExpireAfter,
|
||||||
trustedHost: config.TrustedHost,
|
trustedHost: config.TrustedHost,
|
||||||
secretKey: config.SecretKey,
|
secretKey: config.SecretKey,
|
||||||
db: config.DB,
|
beginTx: txGetter,
|
||||||
tableConfig: config.TableConfig,
|
tableConfig: config.TableConfig,
|
||||||
tableManager: tableManager,
|
tableManager: tableManager,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -112,16 +112,21 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
|||||||
// This method should be called periodically if automatic cleanup is not enabled,
|
// This method should be called periodically if automatic cleanup is not enabled,
|
||||||
// or can be called on-demand regardless of automatic cleanup settings.
|
// or can be called on-demand regardless of automatic cleanup settings.
|
||||||
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
|
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
|
||||||
if gen.db == nil {
|
if gen.beginTx == nil {
|
||||||
return errors.New("No DB provided, unable to use this function")
|
return errors.New("No DB provided, unable to use this function")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx, err := gen.beginTx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return pkgerrors.Wrap(err, "failed to begin transaction")
|
||||||
|
}
|
||||||
|
|
||||||
tableName := gen.tableConfig.TableName
|
tableName := gen.tableConfig.TableName
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
query := "DELETE FROM " + tableName + " WHERE exp < ?"
|
query := "DELETE FROM " + tableName + " WHERE exp < ?"
|
||||||
|
|
||||||
_, err := gen.db.ExecContext(ctx, query, currentTime)
|
_, err = tx.Exec(query, currentTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
|
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
|||||||
DB: nil,
|
DB: nil,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: DefaultTableConfig(),
|
TableConfig: DefaultTableConfig(),
|
||||||
})
|
}, nil)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, gen)
|
require.NotNil(t, gen)
|
||||||
@@ -33,6 +33,10 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
|||||||
config.AutoCreate = false
|
config.AutoCreate = false
|
||||||
config.EnableAutoCleanup = false
|
config.EnableAutoCleanup = false
|
||||||
|
|
||||||
|
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||||
|
return db.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
gen, err := CreateGenerator(GeneratorConfig{
|
gen, err := CreateGenerator(GeneratorConfig{
|
||||||
AccessExpireAfter: 15,
|
AccessExpireAfter: 15,
|
||||||
RefreshExpireAfter: 60,
|
RefreshExpireAfter: 60,
|
||||||
@@ -42,7 +46,7 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
|||||||
DB: db,
|
DB: db,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: config,
|
TableConfig: config,
|
||||||
})
|
}, txGetter)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, gen)
|
require.NotNil(t, gen)
|
||||||
@@ -67,6 +71,10 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
|
|||||||
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
|
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
|
||||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||||
|
|
||||||
|
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||||
|
return db.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
gen, err := CreateGenerator(GeneratorConfig{
|
gen, err := CreateGenerator(GeneratorConfig{
|
||||||
AccessExpireAfter: 15,
|
AccessExpireAfter: 15,
|
||||||
RefreshExpireAfter: 60,
|
RefreshExpireAfter: 60,
|
||||||
@@ -76,7 +84,7 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
|
|||||||
DB: db,
|
DB: db,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: DefaultTableConfig(),
|
TableConfig: DefaultTableConfig(),
|
||||||
})
|
}, txGetter)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, gen)
|
require.NotNil(t, gen)
|
||||||
@@ -142,7 +150,7 @@ func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := CreateGenerator(tt.config)
|
_, err := CreateGenerator(tt.config, nil)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -158,7 +166,7 @@ func TestCleanup_NoDB(t *testing.T) {
|
|||||||
DB: nil,
|
DB: nil,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: DefaultTableConfig(),
|
TableConfig: DefaultTableConfig(),
|
||||||
})
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = gen.Cleanup(context.Background())
|
err = gen.Cleanup(context.Background())
|
||||||
@@ -175,6 +183,10 @@ func TestCleanup_Success(t *testing.T) {
|
|||||||
config.AutoCreate = false
|
config.AutoCreate = false
|
||||||
config.EnableAutoCleanup = false
|
config.EnableAutoCleanup = false
|
||||||
|
|
||||||
|
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||||
|
return db.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
gen, err := CreateGenerator(GeneratorConfig{
|
gen, err := CreateGenerator(GeneratorConfig{
|
||||||
AccessExpireAfter: 15,
|
AccessExpireAfter: 15,
|
||||||
RefreshExpireAfter: 60,
|
RefreshExpireAfter: 60,
|
||||||
@@ -184,10 +196,11 @@ func TestCleanup_Success(t *testing.T) {
|
|||||||
DB: db,
|
DB: db,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: config,
|
TableConfig: config,
|
||||||
})
|
}, txGetter)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Mock DELETE query
|
// Mock transaction begin and DELETE query
|
||||||
|
mock.ExpectBegin()
|
||||||
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
|
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
|
||||||
WillReturnResult(sqlmock.NewResult(0, 5))
|
WillReturnResult(sqlmock.NewResult(0, 5))
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -11,8 +9,8 @@ import (
|
|||||||
// revoke is an internal method that adds a token to the blacklist database.
|
// revoke is an internal method that adds a token to the blacklist database.
|
||||||
// Once revoked, the token will fail validation checks even if it hasn't expired.
|
// Once revoked, the token will fail validation checks even if it hasn't expired.
|
||||||
// This operation must be performed within a database transaction.
|
// This operation must be performed within a database transaction.
|
||||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
|
||||||
if gen.db == nil {
|
if gen.beginTx == nil {
|
||||||
return errors.New("No DB provided, unable to use this function")
|
return errors.New("No DB provided, unable to use this function")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
|||||||
sub := t.GetSUB()
|
sub := t.GetSUB()
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
|
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
|
||||||
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
|
_, err := tx.Exec(query, jti.String(), exp, sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.ExecContext")
|
return errors.Wrap(err, "tx.ExecContext")
|
||||||
}
|
}
|
||||||
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
|||||||
// checkNotRevoked is an internal method that queries the blacklist to verify
|
// checkNotRevoked is an internal method that queries the blacklist to verify
|
||||||
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
|
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
|
||||||
// false if it has been revoked. This operation must be performed within a database transaction.
|
// false if it has been revoked. This operation must be performed within a database transaction.
|
||||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
|
||||||
if gen.db == nil {
|
if gen.beginTx == nil {
|
||||||
return false, errors.New("No DB provided, unable to use this function")
|
return false, errors.New("No DB provided, unable to use this function")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
|||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
|
|
||||||
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
|
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
|
||||||
rows, err := tx.QueryContext(context.Background(), query, jti.String())
|
rows, err := tx.Query(query, jti.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "tx.QueryContext")
|
return false, errors.Wrap(err, "tx.QueryContext")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,13 +21,13 @@ func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
|||||||
DB: nil,
|
DB: nil,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: DefaultTableConfig(),
|
TableConfig: DefaultTableConfig(),
|
||||||
})
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return gen
|
return gen
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -34,6 +35,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
|
|||||||
config.AutoCreate = false
|
config.AutoCreate = false
|
||||||
config.EnableAutoCleanup = false
|
config.EnableAutoCleanup = false
|
||||||
|
|
||||||
|
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||||
|
return db.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
gen, err := CreateGenerator(GeneratorConfig{
|
gen, err := CreateGenerator(GeneratorConfig{
|
||||||
AccessExpireAfter: 15,
|
AccessExpireAfter: 15,
|
||||||
RefreshExpireAfter: 60,
|
RefreshExpireAfter: 60,
|
||||||
@@ -43,10 +48,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
|
|||||||
DB: db,
|
DB: db,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: config,
|
TableConfig: config,
|
||||||
})
|
}, txGetter)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return gen, mock, func() { db.Close() }
|
return gen, db, mock, func() { db.Close() }
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoDBFail(t *testing.T) {
|
func TestNoDBFail(t *testing.T) {
|
||||||
@@ -73,7 +78,7 @@ func TestNoDBFail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
jti := uuid.New()
|
jti := uuid.New()
|
||||||
@@ -97,7 +102,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
|||||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||||
mock.ExpectCommit()
|
mock.ExpectCommit()
|
||||||
|
|
||||||
tx, err := gen.db.Begin()
|
tx, err := db.Begin()
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func newTestGenerator(t *testing.T) *TokenGenerator {
|
|||||||
DB: nil,
|
DB: nil,
|
||||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||||
TableConfig: DefaultTableConfig(),
|
TableConfig: DefaultTableConfig(),
|
||||||
})
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return gen
|
return gen
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,12 +21,14 @@ type Token interface {
|
|||||||
|
|
||||||
// Revoke adds this token to the blacklist, preventing future use.
|
// Revoke adds this token to the blacklist, preventing future use.
|
||||||
// Must be called within a database transaction context.
|
// Must be called within a database transaction context.
|
||||||
Revoke(*sql.Tx) error
|
// Accepts any transaction type that implements DBTransaction interface.
|
||||||
|
Revoke(DBTransaction) error
|
||||||
|
|
||||||
// CheckNotRevoked verifies that this token has not been blacklisted.
|
// CheckNotRevoked verifies that this token has not been blacklisted.
|
||||||
// Returns true if the token is valid, false if revoked.
|
// Returns true if the token is valid, false if revoked.
|
||||||
// Must be called within a database transaction context.
|
// Must be called within a database transaction context.
|
||||||
CheckNotRevoked(*sql.Tx) (bool, error)
|
// Accepts any transaction type that implements DBTransaction interface.
|
||||||
|
CheckNotRevoked(DBTransaction) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccessToken represents a JWT access token with all its claims.
|
// AccessToken represents a JWT access token with all its claims.
|
||||||
@@ -84,15 +84,15 @@ func (a AccessToken) GetScope() string {
|
|||||||
func (r RefreshToken) GetScope() string {
|
func (r RefreshToken) GetScope() string {
|
||||||
return r.Scope
|
return r.Scope
|
||||||
}
|
}
|
||||||
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
func (a AccessToken) Revoke(tx DBTransaction) error {
|
||||||
return a.gen.revoke(tx, a)
|
return a.gen.revoke(tx, a)
|
||||||
}
|
}
|
||||||
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
func (r RefreshToken) Revoke(tx DBTransaction) error {
|
||||||
return r.gen.revoke(tx, r)
|
return r.gen.revoke(tx, r)
|
||||||
}
|
}
|
||||||
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
|
||||||
return a.gen.checkNotRevoked(tx, a)
|
return a.gen.checkNotRevoked(tx, a)
|
||||||
}
|
}
|
||||||
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
|
||||||
return r.gen.checkNotRevoked(tx, r)
|
return r.gen.checkNotRevoked(tx, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,14 +18,15 @@ import (
|
|||||||
// revocation check is skipped.
|
// revocation check is skipped.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - tx: Database transaction for checking token revocation status
|
// - tx: Database transaction for checking token revocation status.
|
||||||
|
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
|
||||||
// - tokenString: The JWT token string to validate
|
// - tokenString: The JWT token string to validate
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *AccessToken: The validated token with all claims, or nil if validation fails
|
// - *AccessToken: The validated token with all claims, or nil if validation fails
|
||||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||||
func (gen *TokenGenerator) ValidateAccess(
|
func (gen *TokenGenerator) ValidateAccess(
|
||||||
tx *sql.Tx,
|
tx DBTransaction,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*AccessToken, error) {
|
) (*AccessToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -86,10 +85,10 @@ func (gen *TokenGenerator) ValidateAccess(
|
|||||||
}
|
}
|
||||||
|
|
||||||
valid, err := token.CheckNotRevoked(tx)
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
if err != nil && gen.db != nil {
|
if err != nil && gen.beginTx != nil {
|
||||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
}
|
}
|
||||||
if !valid && gen.db != nil {
|
if !valid && gen.beginTx != nil {
|
||||||
return nil, errors.New("Token has been revoked")
|
return nil, errors.New("Token has been revoked")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
@@ -109,14 +108,15 @@ func (gen *TokenGenerator) ValidateAccess(
|
|||||||
// revocation check is skipped.
|
// revocation check is skipped.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - tx: Database transaction for checking token revocation status
|
// - tx: Database transaction for checking token revocation status.
|
||||||
|
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
|
||||||
// - tokenString: The JWT token string to validate
|
// - tokenString: The JWT token string to validate
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *RefreshToken: The validated token with all claims, or nil if validation fails
|
// - *RefreshToken: The validated token with all claims, or nil if validation fails
|
||||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||||
func (gen *TokenGenerator) ValidateRefresh(
|
func (gen *TokenGenerator) ValidateRefresh(
|
||||||
tx *sql.Tx,
|
tx DBTransaction,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*RefreshToken, error) {
|
) (*RefreshToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -170,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
|
|||||||
}
|
}
|
||||||
|
|
||||||
valid, err := token.CheckNotRevoked(tx)
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
if err != nil && gen.db != nil {
|
if err != nil && gen.beginTx != nil {
|
||||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
}
|
}
|
||||||
if !valid && gen.db != nil {
|
if !valid && gen.beginTx != nil {
|
||||||
return nil, errors.New("Token has been revoked")
|
return nil, errors.New("Token has been revoked")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateAccess_Success(t *testing.T) {
|
func TestValidateAccess_Success(t *testing.T) {
|
||||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
@@ -26,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
|
|||||||
// We don't know the JTI beforehand; match any arg
|
// We don't know the JTI beforehand; match any arg
|
||||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
tx, err := gen.db.Begin()
|
tx, err := db.Begin()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateRefresh_Success(t *testing.T) {
|
func TestValidateRefresh_Success(t *testing.T) {
|
||||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
@@ -61,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
|
|||||||
|
|
||||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
tx, err := gen.db.Begin()
|
tx, err := db.Begin()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user