Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers and using standard library *sql.DB and *sql.Tx types directly. Changes: - Removed DBConnection and DBTransaction interfaces from database.go - Removed NewDBConnection() wrapper function - Updated TokenGenerator to use *sql.DB instead of DBConnection - Updated all validation and revocation methods to accept *sql.Tx - Updated TableManager to work with *sql.DB directly - Updated all tests to use db.Begin() instead of custom wrappers - Fixed GeneratorConfig.DB field (was DBConn) - Updated documentation in doc.go with correct API usage Benefits: - Simpler API with fewer abstractions - Works directly with database/sql standard library - Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB) - Easier to understand and maintain - No unnecessary wrapper layers Breaking changes: - GeneratorConfig.DBConn renamed to GeneratorConfig.DB - Removed NewDBConnection() function - pass *sql.DB directly - ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction - Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
102
jwt/README.md
Normal file
102
jwt/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# JWT Package
|
||||
|
||||
[](https://pkg.go.dev/git.haelnorr.com/h/golib/jwt)
|
||||
|
||||
JWT (JSON Web Token) generation and validation with database-backed token revocation support.
|
||||
|
||||
## Features
|
||||
|
||||
- 🔐 Access and refresh token generation
|
||||
- ✅ Token validation with expiration checking
|
||||
- 🚫 Token revocation via database blacklist
|
||||
- 🗄️ Multi-database support (PostgreSQL, MySQL, SQLite, MariaDB)
|
||||
- 🔧 Compatible with database/sql, GORM, and Bun
|
||||
- 🤖 Automatic table creation and management
|
||||
- 🧹 Database-native automatic cleanup
|
||||
- 🔄 Token freshness tracking
|
||||
- 💾 "Remember me" functionality
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get git.haelnorr.com/h/golib/jwt
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Open database
|
||||
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||
defer db.Close()
|
||||
|
||||
// Wrap database connection
|
||||
dbConn := jwt.NewDBConnection(db)
|
||||
|
||||
// Create token generator
|
||||
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
AccessExpireAfter: 15, // 15 minutes
|
||||
RefreshExpireAfter: 1440, // 24 hours
|
||||
FreshExpireAfter: 5, // 5 minutes
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "your-secret-key",
|
||||
DBConn: dbConn,
|
||||
DBType: jwt.DatabaseType{
|
||||
Type: jwt.DatabasePostgreSQL,
|
||||
Version: "15",
|
||||
},
|
||||
TableConfig: jwt.DefaultTableConfig(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Generate tokens
|
||||
accessToken, _, _ := gen.NewAccess(42, true, false)
|
||||
refreshToken, _, _ := gen.NewRefresh(42, false)
|
||||
|
||||
// Validate token
|
||||
tx, _ := dbConn.BeginTx(context.Background(), nil)
|
||||
token, _ := gen.ValidateAccess(tx, accessToken)
|
||||
|
||||
// Revoke token
|
||||
token.Revoke(tx)
|
||||
tx.Commit()
|
||||
}
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
Comprehensive documentation is available in the [Wiki](https://git.haelnorr.com/h/golib/wiki/JWT).
|
||||
|
||||
### Key Topics
|
||||
|
||||
- [Configuration](https://git.haelnorr.com/h/golib/wiki/JWT#configuration)
|
||||
- [Token Generation](https://git.haelnorr.com/h/golib/wiki/JWT#token-generation)
|
||||
- [Token Validation](https://git.haelnorr.com/h/golib/wiki/JWT#token-validation)
|
||||
- [Token Revocation](https://git.haelnorr.com/h/golib/wiki/JWT#token-revocation)
|
||||
- [Cleanup](https://git.haelnorr.com/h/golib/wiki/JWT#cleanup)
|
||||
- [Using with ORMs](https://git.haelnorr.com/h/golib/wiki/JWT#using-with-orms)
|
||||
|
||||
## Supported Databases
|
||||
|
||||
- PostgreSQL
|
||||
- MySQL
|
||||
- MariaDB
|
||||
- SQLite
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file in the repository root.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please open an issue or submit a pull request.
|
||||
@@ -6,7 +6,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Get the value of the access and refresh tokens
|
||||
// GetTokenCookies extracts access and refresh tokens from HTTP request cookies.
|
||||
// Returns empty strings for any cookies that don't exist.
|
||||
//
|
||||
// Returns:
|
||||
// - acc: The access token value from the "access" cookie (empty if not found)
|
||||
// - ref: The refresh token value from the "refresh" cookie (empty if not found)
|
||||
func GetTokenCookies(
|
||||
r *http.Request,
|
||||
) (acc string, ref string) {
|
||||
@@ -25,7 +30,16 @@ func GetTokenCookies(
|
||||
return accStr, refStr
|
||||
}
|
||||
|
||||
// Set a token with the provided details
|
||||
// setToken is an internal helper that sets a token cookie with the specified parameters.
|
||||
// The cookie is HttpOnly for security and uses SameSite=Lax mode.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: HTTP response writer to set the cookie on
|
||||
// - token: The token value to store in the cookie
|
||||
// - scope: The cookie name ("access" or "refresh")
|
||||
// - exp: Unix timestamp when the token expires
|
||||
// - rememberme: If true, sets cookie expiration; if false, cookie is session-only
|
||||
// - useSSL: If true, marks cookie as Secure (HTTPS only)
|
||||
func setToken(
|
||||
w http.ResponseWriter,
|
||||
token string,
|
||||
@@ -48,7 +62,21 @@ func setToken(
|
||||
http.SetCookie(w, tokenCookie)
|
||||
}
|
||||
|
||||
// Generate new tokens for the subject and set them as cookies
|
||||
// SetTokenCookies generates new access and refresh tokens for a user and sets them as HTTP cookies.
|
||||
// This is a convenience function that combines token generation with cookie setting.
|
||||
// Cookies are HttpOnly and use SameSite=Lax for security.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: HTTP response writer to set cookies on
|
||||
// - r: HTTP request (unused but kept for API consistency)
|
||||
// - tokenGen: The TokenGenerator to use for creating tokens
|
||||
// - subject: The user ID to generate tokens for
|
||||
// - fresh: If true, marks the access token as fresh for sensitive operations
|
||||
// - rememberMe: If true, tokens persist beyond browser session
|
||||
// - useSSL: If true, marks cookies as Secure (HTTPS only)
|
||||
//
|
||||
// Returns an error if token generation fails. Cookies are only set if both tokens
|
||||
// are generated successfully.
|
||||
func SetTokenCookies(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
|
||||
48
jwt/database.go
Normal file
48
jwt/database.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package jwt
|
||||
|
||||
// DatabaseType specifies the database system and version being used.
|
||||
type DatabaseType struct {
|
||||
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"
|
||||
Version string // Version string, e.g., "15.3", "8.0.32", "3.42.0"
|
||||
}
|
||||
|
||||
// Predefined database type constants for easy configuration and validation.
|
||||
const (
|
||||
DatabasePostgreSQL = "postgres"
|
||||
DatabaseMySQL = "mysql"
|
||||
DatabaseSQLite = "sqlite"
|
||||
DatabaseMariaDB = "mariadb"
|
||||
)
|
||||
|
||||
// TableConfig configures the JWT blacklist table.
|
||||
type TableConfig struct {
|
||||
// TableName is the name of the blacklist table.
|
||||
// Default: "jwtblacklist"
|
||||
TableName string
|
||||
|
||||
// AutoCreate determines whether to automatically create the table if it doesn't exist.
|
||||
// Default: true
|
||||
AutoCreate bool
|
||||
|
||||
// EnableAutoCleanup configures database-native automatic cleanup of expired tokens.
|
||||
// For PostgreSQL: Creates a cleanup function (requires external scheduler or pg_cron)
|
||||
// For MySQL/MariaDB: Creates a database event
|
||||
// For SQLite: No automatic cleanup (manual only)
|
||||
// Default: true
|
||||
EnableAutoCleanup bool
|
||||
|
||||
// CleanupInterval specifies how often automatic cleanup should run (in hours).
|
||||
// Only used if EnableAutoCleanup is true.
|
||||
// Default: 24 (daily cleanup)
|
||||
CleanupInterval int
|
||||
}
|
||||
|
||||
// DefaultTableConfig returns a TableConfig with sensible defaults.
|
||||
func DefaultTableConfig() TableConfig {
|
||||
return TableConfig{
|
||||
TableName: "jwtblacklist",
|
||||
AutoCreate: true,
|
||||
EnableAutoCleanup: true,
|
||||
CleanupInterval: 24,
|
||||
}
|
||||
}
|
||||
135
jwt/doc.go
Normal file
135
jwt/doc.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Package jwt provides JWT (JSON Web Token) generation and validation with token revocation support.
|
||||
//
|
||||
// This package implements JWT access and refresh tokens with the ability to revoke tokens
|
||||
// using a database-backed blacklist. It supports multiple database backends including
|
||||
// PostgreSQL, MySQL, SQLite, and MariaDB, and works with both standard library database/sql
|
||||
// and popular ORMs like GORM and Bun.
|
||||
//
|
||||
// # Features
|
||||
//
|
||||
// - Access and refresh token generation
|
||||
// - Token validation with expiration checking
|
||||
// - Token revocation via database blacklist
|
||||
// - Support for multiple database types (PostgreSQL, MySQL, SQLite, MariaDB)
|
||||
// - Compatible with database/sql, GORM, and Bun ORMs
|
||||
// - Automatic table creation and management
|
||||
// - Database-native automatic cleanup (PostgreSQL functions, MySQL events)
|
||||
// - Manual cleanup method for on-demand token cleanup
|
||||
// - Token freshness tracking for sensitive operations
|
||||
// - "Remember me" functionality with session vs persistent tokens
|
||||
//
|
||||
// # Basic Usage
|
||||
//
|
||||
// Create a token generator with database support:
|
||||
//
|
||||
// db, _ := sql.Open("postgres", "connection_string")
|
||||
// gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// AccessExpireAfter: 15, // 15 minutes
|
||||
// RefreshExpireAfter: 1440, // 24 hours
|
||||
// FreshExpireAfter: 5, // 5 minutes
|
||||
// TrustedHost: "example.com",
|
||||
// SecretKey: "your-secret-key",
|
||||
// DB: db,
|
||||
// DBType: jwt.DatabaseType{Type: jwt.DatabasePostgreSQL, Version: "15"},
|
||||
// TableConfig: jwt.DefaultTableConfig(),
|
||||
// })
|
||||
//
|
||||
// Generate tokens:
|
||||
//
|
||||
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
|
||||
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
|
||||
//
|
||||
// Validate tokens:
|
||||
//
|
||||
// tx, _ := db.Begin()
|
||||
// token, err := gen.ValidateAccess(tx, accessToken)
|
||||
// if err != nil {
|
||||
// // Token is invalid or revoked
|
||||
// }
|
||||
// tx.Commit()
|
||||
//
|
||||
// Revoke tokens:
|
||||
//
|
||||
// tx, _ := db.Begin()
|
||||
// err := token.Revoke(tx)
|
||||
// tx.Commit()
|
||||
//
|
||||
// # Database Configuration
|
||||
//
|
||||
// The package automatically creates a blacklist table with the following schema:
|
||||
//
|
||||
// CREATE TABLE jwtblacklist (
|
||||
// jti UUID PRIMARY KEY, -- Token unique identifier
|
||||
// exp BIGINT NOT NULL, -- Expiration timestamp
|
||||
// sub INT NOT NULL, -- Subject (user) ID
|
||||
// created_at TIMESTAMP -- When token was blacklisted
|
||||
// );
|
||||
//
|
||||
// # Cleanup
|
||||
//
|
||||
// For PostgreSQL, the package creates a cleanup function that can be called manually
|
||||
// or scheduled with pg_cron:
|
||||
//
|
||||
// SELECT cleanup_jwtblacklist();
|
||||
//
|
||||
// For MySQL/MariaDB, the package creates a database event that runs automatically
|
||||
// (requires event_scheduler to be enabled).
|
||||
//
|
||||
// Manual cleanup can be performed at any time:
|
||||
//
|
||||
// err := gen.Cleanup(context.Background())
|
||||
//
|
||||
// # Using with ORMs
|
||||
//
|
||||
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
|
||||
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
|
||||
//
|
||||
// // GORM example
|
||||
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
// sqlDB, _ := gormDB.DB()
|
||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// // ... config ...
|
||||
// DB: sqlDB,
|
||||
// })
|
||||
//
|
||||
// // Bun example
|
||||
// sqlDB, _ := sql.Open("postgres", dsn)
|
||||
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
|
||||
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// // ... config ...
|
||||
// DB: sqlDB,
|
||||
// })
|
||||
//
|
||||
// # Token Freshness
|
||||
//
|
||||
// Tokens can be marked as "fresh" for sensitive operations. Fresh tokens are typically
|
||||
// required for actions like changing passwords or email addresses:
|
||||
//
|
||||
// token, err := gen.ValidateAccess(exec, tokenString)
|
||||
// if time.Now().Unix() > token.Fresh {
|
||||
// // Token is not fresh, require re-authentication
|
||||
// }
|
||||
//
|
||||
// # Custom Table Names
|
||||
//
|
||||
// You can customize the blacklist table name:
|
||||
//
|
||||
// config := jwt.DefaultTableConfig()
|
||||
// config.TableName = "my_token_blacklist"
|
||||
//
|
||||
// # Disabling Database Features
|
||||
//
|
||||
// To use JWT without revocation support (no database):
|
||||
//
|
||||
// gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
|
||||
// AccessExpireAfter: 15,
|
||||
// RefreshExpireAfter: 1440,
|
||||
// FreshExpireAfter: 5,
|
||||
// TrustedHost: "example.com",
|
||||
// SecretKey: "your-secret-key",
|
||||
// DB: nil, // No database
|
||||
// })
|
||||
//
|
||||
// When DB is nil, revocation features are disabled and token validation
|
||||
// will not check the blacklist.
|
||||
package jwt
|
||||
136
jwt/generator.go
136
jwt/generator.go
@@ -1,62 +1,130 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TokenGenerator struct {
|
||||
accessExpireAfter int64 // Access Token expiry time in minutes
|
||||
refreshExpireAfter int64 // Refresh Token expiry time in minutes
|
||||
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||
trustedHost string // Trusted hostname to use for the tokens
|
||||
secretKey string // Secret key to use for token hashing
|
||||
dbConn *sql.DB // Database handle for token blacklisting
|
||||
accessExpireAfter int64 // Access Token expiry time in minutes
|
||||
refreshExpireAfter int64 // Refresh Token expiry time in minutes
|
||||
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||
trustedHost string // Trusted hostname to use for the tokens
|
||||
secretKey string // Secret key to use for token hashing
|
||||
db *sql.DB // Database connection for token blacklisting
|
||||
tableConfig TableConfig // Table configuration
|
||||
tableManager *TableManager // Table lifecycle manager
|
||||
}
|
||||
|
||||
// GeneratorConfig holds configuration for creating a TokenGenerator.
|
||||
type GeneratorConfig struct {
|
||||
// AccessExpireAfter is the access token expiry time in minutes.
|
||||
AccessExpireAfter int64
|
||||
|
||||
// RefreshExpireAfter is the refresh token expiry time in minutes.
|
||||
RefreshExpireAfter int64
|
||||
|
||||
// FreshExpireAfter is the token freshness expiry time in minutes.
|
||||
FreshExpireAfter int64
|
||||
|
||||
// TrustedHost is the trusted hostname to use for the tokens.
|
||||
TrustedHost string
|
||||
|
||||
// SecretKey is the secret key to use for token hashing.
|
||||
SecretKey string
|
||||
|
||||
// DB is the database connection. Can be nil to disable token revocation.
|
||||
// When using ORMs like GORM or Bun, pass the underlying *sql.DB.
|
||||
DB *sql.DB
|
||||
|
||||
// DBType specifies the database type and version for proper table management.
|
||||
// Only required if DB is not nil.
|
||||
DBType DatabaseType
|
||||
|
||||
// TableConfig configures the blacklist table name and behavior.
|
||||
// Only required if DB is not nil.
|
||||
TableConfig TableConfig
|
||||
}
|
||||
|
||||
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||
// All expiry times should be provided in minutes.
|
||||
// trustedHost and secretKey strings must be provided.
|
||||
// dbConn can be nil, but doing this will disable token revocation
|
||||
func CreateGenerator(
|
||||
accessExpireAfter int64,
|
||||
refreshExpireAfter int64,
|
||||
freshExpireAfter int64,
|
||||
trustedHost string,
|
||||
secretKey string,
|
||||
dbConn *sql.DB,
|
||||
) (gen *TokenGenerator, err error) {
|
||||
if accessExpireAfter <= 0 {
|
||||
func CreateGenerator(config GeneratorConfig) (gen *TokenGenerator, err error) {
|
||||
if config.AccessExpireAfter <= 0 {
|
||||
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||
}
|
||||
if refreshExpireAfter <= 0 {
|
||||
if config.RefreshExpireAfter <= 0 {
|
||||
return nil, errors.New("refreshExpireAfter must be greater than 0")
|
||||
}
|
||||
if freshExpireAfter <= 0 {
|
||||
if config.FreshExpireAfter <= 0 {
|
||||
return nil, errors.New("freshExpireAfter must be greater than 0")
|
||||
}
|
||||
if trustedHost == "" {
|
||||
if config.TrustedHost == "" {
|
||||
return nil, errors.New("trustedHost cannot be an empty string")
|
||||
}
|
||||
if secretKey == "" {
|
||||
if config.SecretKey == "" {
|
||||
return nil, errors.New("secretKey cannot be an empty string")
|
||||
}
|
||||
|
||||
if dbConn != nil {
|
||||
err := dbConn.Ping()
|
||||
if err != nil {
|
||||
return nil, errors.New("Failed to ping database")
|
||||
var tableManager *TableManager
|
||||
if config.DB != nil {
|
||||
// Create table manager
|
||||
tableManager = NewTableManager(config.DB, config.DBType, config.TableConfig)
|
||||
|
||||
// Create table if AutoCreate is enabled
|
||||
if config.TableConfig.AutoCreate {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = tableManager.CreateTable(ctx)
|
||||
if err != nil {
|
||||
return nil, pkgerrors.Wrap(err, "failed to create blacklist table")
|
||||
}
|
||||
}
|
||||
|
||||
// Setup automatic cleanup if enabled
|
||||
if config.TableConfig.EnableAutoCleanup {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = tableManager.SetupAutoCleanup(ctx)
|
||||
if err != nil {
|
||||
return nil, pkgerrors.Wrap(err, "failed to setup automatic cleanup")
|
||||
}
|
||||
}
|
||||
// TODO: check if jwtblacklist table exists
|
||||
// TODO: create jwtblacklist table if not existing
|
||||
}
|
||||
|
||||
return &TokenGenerator{
|
||||
accessExpireAfter: accessExpireAfter,
|
||||
refreshExpireAfter: refreshExpireAfter,
|
||||
freshExpireAfter: freshExpireAfter,
|
||||
trustedHost: trustedHost,
|
||||
secretKey: secretKey,
|
||||
dbConn: dbConn,
|
||||
accessExpireAfter: config.AccessExpireAfter,
|
||||
refreshExpireAfter: config.RefreshExpireAfter,
|
||||
freshExpireAfter: config.FreshExpireAfter,
|
||||
trustedHost: config.TrustedHost,
|
||||
secretKey: config.SecretKey,
|
||||
db: config.DB,
|
||||
tableConfig: config.TableConfig,
|
||||
tableManager: tableManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Cleanup manually removes expired tokens from the blacklist table.
|
||||
// This method should be called periodically if automatic cleanup is not enabled,
|
||||
// or can be called on-demand regardless of automatic cleanup settings.
|
||||
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
|
||||
if gen.db == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
tableName := gen.tableConfig.TableName
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
query := "DELETE FROM " + tableName + " WHERE exp < ?"
|
||||
|
||||
_, err := gen.db.ExecContext(ctx, query, currentTime)
|
||||
if err != nil {
|
||||
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -8,14 +9,16 @@ import (
|
||||
)
|
||||
|
||||
func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"secret",
|
||||
nil,
|
||||
)
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "secret",
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
@@ -26,14 +29,54 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"secret",
|
||||
db,
|
||||
)
|
||||
config := DefaultTableConfig()
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "secret",
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Mock table doesn't exist
|
||||
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
|
||||
WithArgs("jwtblacklist").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}))
|
||||
|
||||
// Mock CREATE TABLE
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS jwtblacklist").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
// Mock cleanup function creation
|
||||
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "secret",
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
@@ -42,49 +85,113 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||
|
||||
func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
name string
|
||||
config GeneratorConfig
|
||||
}{
|
||||
{
|
||||
"access expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
|
||||
return err
|
||||
GeneratorConfig{
|
||||
AccessExpireAfter: 0,
|
||||
RefreshExpireAfter: 1,
|
||||
FreshExpireAfter: 1,
|
||||
TrustedHost: "h",
|
||||
SecretKey: "s",
|
||||
},
|
||||
},
|
||||
{
|
||||
"refresh expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
|
||||
return err
|
||||
GeneratorConfig{
|
||||
AccessExpireAfter: 1,
|
||||
RefreshExpireAfter: 0,
|
||||
FreshExpireAfter: 1,
|
||||
TrustedHost: "h",
|
||||
SecretKey: "s",
|
||||
},
|
||||
},
|
||||
{
|
||||
"fresh expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
|
||||
return err
|
||||
GeneratorConfig{
|
||||
AccessExpireAfter: 1,
|
||||
RefreshExpireAfter: 1,
|
||||
FreshExpireAfter: 0,
|
||||
TrustedHost: "h",
|
||||
SecretKey: "s",
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty trustedHost",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
|
||||
return err
|
||||
GeneratorConfig{
|
||||
AccessExpireAfter: 1,
|
||||
RefreshExpireAfter: 1,
|
||||
FreshExpireAfter: 1,
|
||||
TrustedHost: "",
|
||||
SecretKey: "s",
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty secretKey",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
|
||||
return err
|
||||
GeneratorConfig{
|
||||
AccessExpireAfter: 1,
|
||||
RefreshExpireAfter: 1,
|
||||
FreshExpireAfter: 1,
|
||||
TrustedHost: "h",
|
||||
SecretKey: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Error(t, tt.fn())
|
||||
_, err := CreateGenerator(tt.config)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup_NoDB(t *testing.T) {
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "secret",
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gen.Cleanup(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "No DB provided")
|
||||
}
|
||||
|
||||
func TestCleanup_Success(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
config := DefaultTableConfig()
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "secret",
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mock DELETE query
|
||||
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
|
||||
WillReturnResult(sqlmock.NewResult(0, 5))
|
||||
|
||||
err = gen.Cleanup(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
@@ -1,38 +1,56 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
// revoke is an internal method that adds a token to the blacklist database.
|
||||
// Once revoked, the token will fail validation checks even if it hasn't expired.
|
||||
// This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
if gen.dbConn == nil {
|
||||
if gen.db == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
tableName := gen.tableConfig.TableName
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := tx.Exec(query, jti, exp)
|
||||
sub := t.GetSUB()
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
|
||||
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
return errors.Wrap(err, "tx.ExecContext")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
// checkNotRevoked is an internal method that queries the blacklist to verify
|
||||
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
|
||||
// false if it has been revoked. This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
if gen.dbConn == nil {
|
||||
if gen.db == nil {
|
||||
return false, errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
tableName := gen.tableConfig.TableName
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := tx.Query(query, jti)
|
||||
|
||||
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
|
||||
rows, err := tx.QueryContext(context.Background(), query, jti.String())
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
return false, errors.Wrap(err, "tx.QueryContext")
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
|
||||
exists := rows.Next()
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, errors.Wrap(err, "rows iteration")
|
||||
}
|
||||
|
||||
return !exists, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,19 +11,44 @@ import (
|
||||
)
|
||||
|
||||
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
nil,
|
||||
)
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "supersecret",
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen
|
||||
}
|
||||
|
||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
config := DefaultTableConfig()
|
||||
config.AutoCreate = false
|
||||
config.EnableAutoCleanup = false
|
||||
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "supersecret",
|
||||
DB: db,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: config,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen, mock, func() { db.Close() }
|
||||
}
|
||||
|
||||
func TestNoDBFail(t *testing.T) {
|
||||
jti := uuid.New()
|
||||
exp := time.Now().Add(time.Hour).Unix()
|
||||
@@ -32,15 +56,19 @@ func TestNoDBFail(t *testing.T) {
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
SUB: 42,
|
||||
gen: &TokenGenerator{},
|
||||
}
|
||||
|
||||
// Create a nil transaction (can't revoke without DB)
|
||||
var tx *sql.Tx = nil
|
||||
|
||||
// Revoke should fail due to no DB
|
||||
err := token.Revoke(&sql.Tx{})
|
||||
err := token.Revoke(tx)
|
||||
require.Error(t, err)
|
||||
|
||||
// CheckNotRevoked should fail
|
||||
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||
_, err = token.CheckNotRevoked(tx)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -50,24 +78,26 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
|
||||
jti := uuid.New()
|
||||
exp := time.Now().Add(time.Hour).Unix()
|
||||
sub := 42
|
||||
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
SUB: sub,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
// Revoke expectations
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||
WithArgs(jti, exp).
|
||||
WithArgs(jti.String(), exp, sub).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||
WithArgs(jti).
|
||||
WithArgs(jti.String()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
tx, err := gen.db.Begin()
|
||||
defer tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
212
jwt/tablemanager.go
Normal file
212
jwt/tablemanager.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// TableManager handles table creation, existence checks, and cleanup configuration.
|
||||
type TableManager struct {
|
||||
dbType DatabaseType
|
||||
tableConfig TableConfig
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTableManager creates a new TableManager instance.
|
||||
func NewTableManager(db *sql.DB, dbType DatabaseType, config TableConfig) *TableManager {
|
||||
return &TableManager{
|
||||
dbType: dbType,
|
||||
tableConfig: config,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates the blacklist table if it doesn't exist.
|
||||
func (tm *TableManager) CreateTable(ctx context.Context) error {
|
||||
exists, err := tm.tableExists(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to check if table exists")
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil // Table already exists
|
||||
}
|
||||
|
||||
createSQL, err := tm.getCreateTableSQL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tm.db.ExecContext(ctx, createSQL)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create table %s", tm.tableConfig.TableName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tableExists checks if the blacklist table exists in the database.
|
||||
func (tm *TableManager) tableExists(ctx context.Context) (bool, error) {
|
||||
tableName := tm.tableConfig.TableName
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
query = `
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
query = `
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = ?
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case DatabaseSQLite:
|
||||
query = `
|
||||
SELECT 1 FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name = ?
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
default:
|
||||
return false, errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
|
||||
rows, err := tm.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check table existence")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return rows.Next(), nil
|
||||
}
|
||||
|
||||
// getCreateTableSQL returns the CREATE TABLE statement for the given database type.
|
||||
func (tm *TableManager) getCreateTableSQL() (string, error) {
|
||||
tableName := tm.tableConfig.TableName
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti UUID PRIMARY KEY,
|
||||
exp BIGINT NOT NULL,
|
||||
sub INTEGER NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
||||
`, tableName, tableName, tableName, tableName, tableName), nil
|
||||
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti CHAR(36) PRIMARY KEY,
|
||||
exp BIGINT NOT NULL,
|
||||
sub INT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_exp (exp),
|
||||
INDEX idx_sub (sub)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
`, tableName), nil
|
||||
|
||||
case DatabaseSQLite:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti TEXT PRIMARY KEY,
|
||||
exp INTEGER NOT NULL,
|
||||
sub INTEGER NOT NULL,
|
||||
created_at INTEGER DEFAULT (strftime('%%s', 'now'))
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
||||
`, tableName, tableName, tableName, tableName, tableName), nil
|
||||
|
||||
default:
|
||||
return "", errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupAutoCleanup configures database-native automatic cleanup of expired tokens.
|
||||
func (tm *TableManager) SetupAutoCleanup(ctx context.Context) error {
|
||||
if !tm.tableConfig.EnableAutoCleanup {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
return tm.setupPostgreSQLCleanup(ctx)
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
return tm.setupMySQLCleanup(ctx)
|
||||
case DatabaseSQLite:
|
||||
// SQLite doesn't support automatic cleanup
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// setupPostgreSQLCleanup creates a cleanup function for PostgreSQL.
|
||||
// Note: This creates a function but does not schedule it. You need to use pg_cron
|
||||
// or an external scheduler to call this function periodically.
|
||||
func (tm *TableManager) setupPostgreSQLCleanup(ctx context.Context) error {
|
||||
tableName := tm.tableConfig.TableName
|
||||
functionName := fmt.Sprintf("cleanup_%s", tableName)
|
||||
|
||||
createFunctionSQL := fmt.Sprintf(`
|
||||
CREATE OR REPLACE FUNCTION %s()
|
||||
RETURNS void AS $$
|
||||
BEGIN
|
||||
DELETE FROM %s WHERE exp < EXTRACT(EPOCH FROM NOW());
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
`, functionName, tableName)
|
||||
|
||||
_, err := tm.db.ExecContext(ctx, createFunctionSQL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create cleanup function")
|
||||
}
|
||||
|
||||
// Note: Actual scheduling requires pg_cron extension or external tools
|
||||
// Users should call this function periodically using:
|
||||
// SELECT cleanup_jwtblacklist();
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupMySQLCleanup creates a MySQL event for automatic cleanup.
|
||||
// Note: Requires event_scheduler to be enabled in MySQL/MariaDB configuration.
|
||||
func (tm *TableManager) setupMySQLCleanup(ctx context.Context) error {
|
||||
tableName := tm.tableConfig.TableName
|
||||
eventName := fmt.Sprintf("cleanup_%s_event", tableName)
|
||||
interval := tm.tableConfig.CleanupInterval
|
||||
|
||||
// Drop existing event if it exists
|
||||
dropEventSQL := fmt.Sprintf("DROP EVENT IF EXISTS %s", eventName)
|
||||
_, err := tm.db.ExecContext(ctx, dropEventSQL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to drop existing event")
|
||||
}
|
||||
|
||||
// Create new event
|
||||
createEventSQL := fmt.Sprintf(`
|
||||
CREATE EVENT %s
|
||||
ON SCHEDULE EVERY %d HOUR
|
||||
DO
|
||||
DELETE FROM %s WHERE exp < UNIX_TIMESTAMP()
|
||||
`, eventName, interval, tableName)
|
||||
|
||||
_, err = tm.db.ExecContext(ctx, createEventSQL)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create cleanup event (ensure event_scheduler is enabled)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
221
jwt/tablemanager_test.go
Normal file
221
jwt/tablemanager_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTableManager(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := DefaultTableConfig()
|
||||
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
require.NotNil(t, tm)
|
||||
}
|
||||
|
||||
func TestGetCreateTableSQL_PostgreSQL(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
sql, err := tm.getCreateTableSQL()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
|
||||
require.Contains(t, sql, "jti UUID PRIMARY KEY")
|
||||
require.Contains(t, sql, "exp BIGINT NOT NULL")
|
||||
require.Contains(t, sql, "sub INTEGER NOT NULL")
|
||||
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_jwtblacklist_exp")
|
||||
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_jwtblacklist_sub")
|
||||
}
|
||||
|
||||
func TestGetCreateTableSQL_MySQL(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabaseMySQL, Version: "8.0"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
sql, err := tm.getCreateTableSQL()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
|
||||
require.Contains(t, sql, "jti CHAR(36) PRIMARY KEY")
|
||||
require.Contains(t, sql, "exp BIGINT NOT NULL")
|
||||
require.Contains(t, sql, "sub INT NOT NULL")
|
||||
require.Contains(t, sql, "INDEX idx_exp")
|
||||
require.Contains(t, sql, "ENGINE=InnoDB")
|
||||
}
|
||||
|
||||
func TestGetCreateTableSQL_SQLite(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabaseSQLite, Version: "3.42"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
sql, err := tm.getCreateTableSQL()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
|
||||
require.Contains(t, sql, "jti TEXT PRIMARY KEY")
|
||||
require.Contains(t, sql, "exp INTEGER NOT NULL")
|
||||
require.Contains(t, sql, "sub INTEGER NOT NULL")
|
||||
}
|
||||
|
||||
func TestGetCreateTableSQL_CustomTableName(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := TableConfig{
|
||||
TableName: "custom_blacklist",
|
||||
AutoCreate: true,
|
||||
EnableAutoCleanup: false,
|
||||
CleanupInterval: 24,
|
||||
}
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
sql, err := tm.getCreateTableSQL()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS custom_blacklist")
|
||||
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_custom_blacklist_exp")
|
||||
}
|
||||
|
||||
func TestGetCreateTableSQL_UnsupportedDB(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: "unsupported", Version: "1.0"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
sql, err := tm.getCreateTableSQL()
|
||||
require.Error(t, err)
|
||||
require.Empty(t, sql)
|
||||
require.Contains(t, err.Error(), "unsupported database type")
|
||||
}
|
||||
|
||||
func TestTableExists_PostgreSQL(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
// Test table exists
|
||||
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
|
||||
WithArgs("jwtblacklist").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
|
||||
exists, err := tm.tableExists(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
|
||||
// Test table doesn't exist
|
||||
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
|
||||
WithArgs("jwtblacklist").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}))
|
||||
|
||||
exists, err = tm.tableExists(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, exists)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestCreateTable_AlreadyExists(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
// Mock table exists check
|
||||
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
|
||||
WithArgs("jwtblacklist").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
|
||||
err = tm.CreateTable(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestCreateTable_Success(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
// Mock table doesn't exist
|
||||
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
|
||||
WithArgs("jwtblacklist").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}))
|
||||
|
||||
// Mock CREATE TABLE
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS jwtblacklist").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
err = tm.CreateTable(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestSetupAutoCleanup_Disabled(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
|
||||
config := TableConfig{
|
||||
TableName: "jwtblacklist",
|
||||
AutoCreate: true,
|
||||
EnableAutoCleanup: false,
|
||||
CleanupInterval: 24,
|
||||
}
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
err = tm.SetupAutoCleanup(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestSetupAutoCleanup_SQLite(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
dbType := DatabaseType{Type: DatabaseSQLite, Version: "3.42"}
|
||||
config := DefaultTableConfig()
|
||||
tm := NewTableManager(db, dbType, config)
|
||||
|
||||
// SQLite doesn't support auto-cleanup, should return nil
|
||||
err = tm.SetupAutoCleanup(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -8,7 +8,21 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Generates an access token for the provided subject
|
||||
// NewAccess generates a new JWT access token for the specified subject (user).
|
||||
//
|
||||
// Parameters:
|
||||
// - subjectID: The user ID or subject identifier to associate with the token
|
||||
// - fresh: If true, marks the token as "fresh" for sensitive operations.
|
||||
// Fresh tokens are typically required for actions like changing passwords
|
||||
// or email addresses. The token remains fresh until FreshExpireAfter minutes.
|
||||
// - rememberMe: If true, the token is persistent (TTL="exp") and will be stored
|
||||
// with an expiration date. If false, it's session-only (TTL="session") and
|
||||
// expires when the browser closes.
|
||||
//
|
||||
// Returns:
|
||||
// - tokenString: The signed JWT token string
|
||||
// - expiresIn: Unix timestamp when the token expires
|
||||
// - err: Any error encountered during token generation
|
||||
func (gen *TokenGenerator) NewAccess(
|
||||
subjectID int,
|
||||
fresh bool,
|
||||
@@ -47,7 +61,19 @@ func (gen *TokenGenerator) NewAccess(
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// Generates a refresh token for the provided user
|
||||
// NewRefresh generates a new JWT refresh token for the specified subject (user).
|
||||
// Refresh tokens are used to obtain new access tokens without re-authentication.
|
||||
//
|
||||
// Parameters:
|
||||
// - subjectID: The user ID or subject identifier to associate with the token
|
||||
// - rememberMe: If true, the token is persistent (TTL="exp") and will be stored
|
||||
// with an expiration date. If false, it's session-only (TTL="session") and
|
||||
// expires when the browser closes.
|
||||
//
|
||||
// Returns:
|
||||
// - tokenStr: The signed JWT token string
|
||||
// - exp: Unix timestamp when the token expires
|
||||
// - err: Any error encountered during token generation
|
||||
func (gen *TokenGenerator) NewRefresh(
|
||||
subjectID int,
|
||||
rememberMe bool,
|
||||
|
||||
@@ -7,14 +7,16 @@ import (
|
||||
)
|
||||
|
||||
func newTestGenerator(t *testing.T) *TokenGenerator {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
nil,
|
||||
)
|
||||
gen, err := CreateGenerator(GeneratorConfig{
|
||||
AccessExpireAfter: 15,
|
||||
RefreshExpireAfter: 60,
|
||||
FreshExpireAfter: 5,
|
||||
TrustedHost: "example.com",
|
||||
SecretKey: "supersecret",
|
||||
DB: nil,
|
||||
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
|
||||
TableConfig: DefaultTableConfig(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return gen
|
||||
}
|
||||
|
||||
@@ -6,15 +6,34 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Token is the common interface implemented by both AccessToken and RefreshToken.
|
||||
// It provides methods to access token claims and manage token revocation.
|
||||
type Token interface {
|
||||
// GetJTI returns the unique token identifier (JTI claim)
|
||||
GetJTI() uuid.UUID
|
||||
|
||||
// GetEXP returns the expiration timestamp (EXP claim)
|
||||
GetEXP() int64
|
||||
|
||||
// GetSUB returns the subject/user ID (SUB claim)
|
||||
GetSUB() int
|
||||
|
||||
// GetScope returns the token scope ("access" or "refresh")
|
||||
GetScope() string
|
||||
|
||||
// Revoke adds this token to the blacklist, preventing future use.
|
||||
// Must be called within a database transaction context.
|
||||
Revoke(*sql.Tx) error
|
||||
|
||||
// CheckNotRevoked verifies that this token has not been blacklisted.
|
||||
// Returns true if the token is valid, false if revoked.
|
||||
// Must be called within a database transaction context.
|
||||
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
// AccessToken represents a JWT access token with all its claims.
|
||||
// Access tokens are short-lived and used for authenticating API requests.
|
||||
// They can be marked as "fresh" for sensitive operations like password changes.
|
||||
type AccessToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
@@ -27,7 +46,9 @@ type AccessToken struct {
|
||||
gen *TokenGenerator
|
||||
}
|
||||
|
||||
// Refresh token
|
||||
// RefreshToken represents a JWT refresh token with all its claims.
|
||||
// Refresh tokens are longer-lived and used to obtain new access tokens
|
||||
// without requiring the user to re-authenticate.
|
||||
type RefreshToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
@@ -51,6 +72,12 @@ func (a AccessToken) GetEXP() int64 {
|
||||
func (r RefreshToken) GetEXP() int64 {
|
||||
return r.EXP
|
||||
}
|
||||
func (a AccessToken) GetSUB() int {
|
||||
return a.SUB
|
||||
}
|
||||
func (r RefreshToken) GetSUB() int {
|
||||
return r.SUB
|
||||
}
|
||||
func (a AccessToken) GetScope() string {
|
||||
return a.Scope
|
||||
}
|
||||
|
||||
@@ -6,9 +6,26 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Parse an access token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
// ValidateAccess parses and validates a JWT access token string.
|
||||
//
|
||||
// This method performs comprehensive validation including:
|
||||
// - Signature verification using the secret key
|
||||
// - Expiration time checking (token must not be expired)
|
||||
// - Issuer verification (must match trusted host)
|
||||
// - Scope verification (must be "access" token)
|
||||
// - Revocation status check (if database is configured)
|
||||
//
|
||||
// The validation must be performed within a database transaction context to ensure
|
||||
// consistency when checking the blacklist. If no database is configured, the
|
||||
// revocation check is skipped.
|
||||
//
|
||||
// Parameters:
|
||||
// - tx: Database transaction for checking token revocation status
|
||||
// - tokenString: The JWT token string to validate
|
||||
//
|
||||
// Returns:
|
||||
// - *AccessToken: The validated token with all claims, or nil if validation fails
|
||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||
func (gen *TokenGenerator) ValidateAccess(
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
@@ -69,18 +86,35 @@ func (gen *TokenGenerator) ValidateAccess(
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
if err != nil && gen.db != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.dbConn != nil {
|
||||
if !valid && gen.db != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
// ValidateRefresh parses and validates a JWT refresh token string.
|
||||
//
|
||||
// This method performs comprehensive validation including:
|
||||
// - Signature verification using the secret key
|
||||
// - Expiration time checking (token must not be expired)
|
||||
// - Issuer verification (must match trusted host)
|
||||
// - Scope verification (must be "refresh" token)
|
||||
// - Revocation status check (if database is configured)
|
||||
//
|
||||
// The validation must be performed within a database transaction context to ensure
|
||||
// consistency when checking the blacklist. If no database is configured, the
|
||||
// revocation check is skipped.
|
||||
//
|
||||
// Parameters:
|
||||
// - tx: Database transaction for checking token revocation status
|
||||
// - tokenString: The JWT token string to validate
|
||||
//
|
||||
// Returns:
|
||||
// - *RefreshToken: The validated token with all claims, or nil if validation fails
|
||||
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
|
||||
func (gen *TokenGenerator) ValidateRefresh(
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
@@ -136,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
if err != nil && gen.db != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.dbConn != nil {
|
||||
if !valid && gen.db != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
@@ -9,23 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
db,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen, mock, func() { db.Close() }
|
||||
}
|
||||
|
||||
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||
@@ -44,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
|
||||
// We don't know the JTI beforehand; match any arg
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
tx, err := gen.db.Begin()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -61,7 +43,10 @@ func TestValidateAccess_NoDB(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||
// Use nil transaction for no-db case
|
||||
var tx *sql.Tx = nil
|
||||
|
||||
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "access", token.Scope)
|
||||
@@ -76,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
|
||||
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
tx, err := gen.db.Begin()
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -93,7 +78,10 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||
// Use nil transaction for no-db case
|
||||
var tx *sql.Tx = nil
|
||||
|
||||
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "refresh", token.Scope)
|
||||
@@ -102,7 +90,10 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
_, err := gen.ValidateAccess(nil, "")
|
||||
// Use nil transaction
|
||||
var tx *sql.Tx = nil
|
||||
|
||||
_, err := gen.ValidateAccess(tx, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -113,6 +104,9 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||
// Use nil transaction
|
||||
var tx *sql.Tx = nil
|
||||
|
||||
_, err = gen.ValidateRefresh(tx, tokenStr)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user