refactor to improve database operability
This commit is contained in:
1
jwt/.gitignore
vendored
Normal file
1
jwt/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.claude/
|
||||
21
jwt/LICENSE
Normal file
21
jwt/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 haelnorr
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -28,6 +28,7 @@ go get git.haelnorr.com/h/golib/jwt
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
_ "github.com/lib/pq"
|
||||
@@ -38,8 +39,10 @@ func main() {
|
||||
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||
defer db.Close()
|
||||
|
||||
// Wrap database connection
|
||||
dbConn := jwt.NewDBConnection(db)
|
||||
// Create a transaction getter function
|
||||
txGetter := func(ctx context.Context) (jwt.DBTransaction, error) {
|
||||
return db.Begin()
|
||||
}
|
||||
|
||||
// Create token generator
|
||||
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
@@ -48,13 +51,13 @@ func main() {
|
||||
FreshExpireAfter: 5, // 5 minutes
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "your-secret-key",
|
||||
DBConn: dbConn,
|
||||
DB: db,
|
||||
DBType: jwt.DatabaseType{
|
||||
Type: jwt.DatabasePostgreSQL,
|
||||
Version: "15",
|
||||
},
|
||||
TableConfig: jwt.DefaultTableConfig(),
|
||||
})
|
||||
}, txGetter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -64,7 +67,7 @@ func main() {
|
||||
refreshToken, _, _ := gen.NewRefresh(42, false)
|
||||
|
||||
// Validate token
|
||||
tx, _ := dbConn.BeginTx(context.Background(), nil)
|
||||
tx, _ := db.Begin()
|
||||
token, _ := gen.ValidateAccess(tx, accessToken)
|
||||
|
||||
// Revoke token
|
||||
|
||||
@@ -1,5 +1,23 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// DBTransaction represents a database transaction that can execute queries.
|
||||
// This interface is compatible with *sql.Tx and can be implemented by ORM transactions
|
||||
// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc.
|
||||
type DBTransaction interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected
|
||||
type BeginTX func(ctx context.Context) (DBTransaction, error)
|
||||
|
||||
// DatabaseType specifies the database system and version being used.
|
||||
type DatabaseType struct {
|
||||
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"
|
||||
|
||||
21
jwt/doc.go
21
jwt/doc.go
@@ -39,7 +39,7 @@
|
||||
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
|
||||
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
|
||||
//
|
||||
// Validate tokens:
|
||||
// Validate tokens (using standard library):
|
||||
//
|
||||
// tx, _ := db.Begin()
|
||||
// token, err := gen.ValidateAccess(tx, accessToken)
|
||||
@@ -48,6 +48,13 @@
|
||||
// }
|
||||
// tx.Commit()
|
||||
//
|
||||
// Validate tokens (using ORM like GORM):
|
||||
//
|
||||
// tx := gormDB.Begin()
|
||||
// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken)
|
||||
// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken)
|
||||
// tx.Commit()
|
||||
//
|
||||
// Revoke tokens:
|
||||
//
|
||||
// tx, _ := db.Begin()
|
||||
@@ -84,21 +91,29 @@
|
||||
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
|
||||
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
|
||||
//
|
||||
// // GORM example
|
||||
// // GORM example - can use GORM transactions directly
|
||||
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
// sqlDB, _ := gormDB.DB()
|
||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// // ... config ...
|
||||
// DB: sqlDB,
|
||||
// })
|
||||
// // Use GORM transaction
|
||||
// tx := gormDB.Begin()
|
||||
// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString)
|
||||
// tx.Commit()
|
||||
//
|
||||
// // Bun example
|
||||
// // Bun example - can use Bun transactions directly
|
||||
// sqlDB, _ := sql.Open("postgres", dsn)
|
||||
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
|
||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// // ... config ...
|
||||
// DB: sqlDB,
|
||||
// })
|
||||
// // Use Bun transaction
|
||||
// tx, _ := bunDB.BeginTx(context.Background(), nil)
|
||||
// token, _ := gen.ValidateAccess(tx, tokenString)
|
||||
// tx.Commit()
|
||||
//
|
||||
// # Token Freshness
|
||||
//
|
||||
|
||||
@@ -15,7 +15,7 @@ type TokenGenerator struct {
|
||||
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||
trustedHost string // Trusted hostname to use for the tokens
|
||||
secretKey string // Secret key to use for token hashing
|
||||
db *sql.DB // Database connection for token blacklisting
|
||||
beginTx BeginTX // Database transaction getter for token blacklisting
|
||||
tableConfig TableConfig // Table configuration
|
||||
tableManager *TableManager // Table lifecycle manager
|
||||
}
|
||||
@@ -51,7 +51,7 @@ type GeneratorConfig struct {
|
||||
}
|
||||
|
||||
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||
func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
||||
func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) {
|
||||
if config.AccessExpireAfter <= 0 {
|
||||
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
||||
freshExpireAfter: config.FreshExpireAfter,
|
||||
trustedHost: config.TrustedHost,
|
||||
secretKey: config.SecretKey,
|
||||
db: config.DB,
|
||||
beginTx: txGetter,
|
||||
tableConfig: config.TableConfig,
|
||||
tableManager: tableManager,
|
||||
}, nil
|
||||
@@ -112,16 +112,21 @@ func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
||||
// This method should be called periodically if automatic cleanup is not enabled,
|
||||
// or can be called on-demand regardless of automatic cleanup settings.
|
||||
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
|
||||
if gen.db == nil {
|
||||
if gen.beginTx == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
tx, err := gen.beginTx(ctx)
|
||||
if err != nil {
|
||||
return pkgerrors.Wrap(err, "failed to begin transaction")
|
||||
}
|
||||
|
||||
tableName := gen.tableConfig.TableName
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
query := "DELETE FROM " + tableName + " WHERE exp < ?"
|
||||
|
||||
_, err := gen.db.ExecContext(ctx, query, currentTime)
|
||||
_, err = tx.Exec(query, currentTime)
|
||||
if err != nil {
|
||||
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
}, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
@@ -33,6 +33,10 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||
return db.Begin()
|
||||
}
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
@@ -42,7 +46,7 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
}, txGetter)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
@@ -67,6 +71,10 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
|
||||
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||
return db.Begin()
|
||||
}
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
@@ -76,7 +84,7 @@ func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
}, txGetter)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
@@ -142,7 +150,7 @@ func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := CreateGenerator(tt.config)
|
||||
_, err := CreateGenerator(tt.config, nil)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -158,7 +166,7 @@ func TestCleanup_NoDB(t *testing.T) {
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gen.Cleanup(context.Background())
|
||||
@@ -175,6 +183,10 @@ func TestCleanup_Success(t *testing.T) {
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||
return db.Begin()
|
||||
}
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
@@ -184,10 +196,11 @@ func TestCleanup_Success(t *testing.T) {
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
}, txGetter)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mock DELETE query
|
||||
// Mock transaction begin and DELETE query
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
|
||||
WillReturnResult(sqlmock.NewResult(0, 5))
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -11,8 +9,8 @@ import (
|
||||
// revoke is an internal method that adds a token to the blacklist database.
|
||||
// Once revoked, the token will fail validation checks even if it hasn't expired.
|
||||
// This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
if gen.db == nil {
|
||||
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
|
||||
if gen.beginTx == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
sub := t.GetSUB()
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
|
||||
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
|
||||
_, err := tx.Exec(query, jti.String(), exp, sub)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.ExecContext")
|
||||
}
|
||||
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
// checkNotRevoked is an internal method that queries the blacklist to verify
|
||||
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
|
||||
// false if it has been revoked. This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
if gen.db == nil {
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
|
||||
if gen.beginTx == nil {
|
||||
return false, errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
|
||||
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
|
||||
rows, err := tx.QueryContext(context.Background(), query, jti.String())
|
||||
rows, err := tx.Query(query, jti.String())
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.QueryContext")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -20,13 +21,13 @@ func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen
|
||||
}
|
||||
|
||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -34,6 +35,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
txGetter := func(ctx context.Context) (DBTransaction, error) {
|
||||
return db.Begin()
|
||||
}
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
@@ -43,10 +48,10 @@ func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, fun
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
}, txGetter)
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen, mock, func() { db.Close() }
|
||||
return gen, db, mock, func() { db.Close() }
|
||||
}
|
||||
|
||||
func TestNoDBFail(t *testing.T) {
|
||||
@@ -73,7 +78,7 @@ func TestNoDBFail(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
jti := uuid.New()
|
||||
@@ -97,7 +102,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := gen.db.Begin()
|
||||
tx, err := db.Begin()
|
||||
defer tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ func newTestGenerator(t *testing.T) *TokenGenerator {
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
return gen
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -23,12 +21,14 @@ type Token interface {
|
||||
|
||||
// Revoke adds this token to the blacklist, preventing future use.
|
||||
// Must be called within a database transaction context.
|
||||
Revoke(*sql.Tx) error
|
||||
// Accepts any transaction type that implements DBTransaction interface.
|
||||
Revoke(DBTransaction) error
|
||||
|
||||
// CheckNotRevoked verifies that this token has not been blacklisted.
|
||||
// Returns true if the token is valid, false if revoked.
|
||||
// Must be called within a database transaction context.
|
||||
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||
// Accepts any transaction type that implements DBTransaction interface.
|
||||
CheckNotRevoked(DBTransaction) (bool, error)
|
||||
}
|
||||
|
||||
// AccessToken represents a JWT access token with all its claims.
|
||||
@@ -84,15 +84,15 @@ func (a AccessToken) GetScope() string {
|
||||
func (r RefreshToken) GetScope() string {
|
||||
return r.Scope
|
||||
}
|
||||
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||
func (a AccessToken) Revoke(tx DBTransaction) error {
|
||||
return a.gen.revoke(tx, a)
|
||||
}
|
||||
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||
func (r RefreshToken) Revoke(tx DBTransaction) error {
|
||||
return r.gen.revoke(tx, r)
|
||||
}
|
||||
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
|
||||
return a.gen.checkNotRevoked(tx, a)
|
||||
}
|
||||
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
|
||||
return r.gen.checkNotRevoked(tx, r)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -20,14 +18,15 @@ import (
|
||||
// revocation check is skipped.
|
||||
//
|
||||
// Parameters:
|
||||
// - tx: Database transaction for checking token revocation status
|
||||
// - tx: Database transaction for checking token revocation status.
|
||||
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
|
||||
// - tokenString: The JWT token string to validate
|
||||
//
|
||||
// Returns:
|
||||
// - *AccessToken: The validated token with all claims, or nil if validation fails
|
||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||
func (gen *TokenGenerator) ValidateAccess(
|
||||
tx *sql.Tx,
|
||||
tx DBTransaction,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -86,10 +85,10 @@ func (gen *TokenGenerator) ValidateAccess(
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.db != nil {
|
||||
if err != nil && gen.beginTx != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.db != nil {
|
||||
if !valid && gen.beginTx != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
@@ -109,14 +108,15 @@ func (gen *TokenGenerator) ValidateAccess(
|
||||
// revocation check is skipped.
|
||||
//
|
||||
// Parameters:
|
||||
// - tx: Database transaction for checking token revocation status
|
||||
// - tx: Database transaction for checking token revocation status.
|
||||
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
|
||||
// - tokenString: The JWT token string to validate
|
||||
//
|
||||
// Returns:
|
||||
// - *RefreshToken: The validated token with all claims, or nil if validation fails
|
||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||
func (gen *TokenGenerator) ValidateRefresh(
|
||||
tx *sql.Tx,
|
||||
tx DBTransaction,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -170,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.db != nil {
|
||||
if err != nil && gen.beginTx != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.db != nil {
|
||||
if !valid && gen.beginTx != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
|
||||
@@ -17,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
||||
}
|
||||
|
||||
func TestValidateAccess_Success(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||
@@ -26,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
|
||||
// We don't know the JTI beforehand; match any arg
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.db.Begin()
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateRefresh_Success(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||
@@ -61,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
|
||||
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.db.Begin()
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user