Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers and using standard library *sql.DB and *sql.Tx types directly. Changes: - Removed DBConnection and DBTransaction interfaces from database.go - Removed NewDBConnection() wrapper function - Updated TokenGenerator to use *sql.DB instead of DBConnection - Updated all validation and revocation methods to accept *sql.Tx - Updated TableManager to work with *sql.DB directly - Updated all tests to use db.Begin() instead of custom wrappers - Fixed GeneratorConfig.DB field (was DBConn) - Updated documentation in doc.go with correct API usage Benefits: - Simpler API with fewer abstractions - Works directly with database/sql standard library - Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB) - Easier to understand and maintain - No unnecessary wrapper layers Breaking changes: - GeneratorConfig.DBConn renamed to GeneratorConfig.DB - Removed NewDBConnection() function - pass *sql.DB directly - ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction - Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
212
jwt/tablemanager.go
Normal file
212
jwt/tablemanager.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// TableManager handles table creation, existence checks, and cleanup configuration.
|
||||
type TableManager struct {
|
||||
dbType DatabaseType
|
||||
tableConfig TableConfig
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTableManager creates a new TableManager instance.
|
||||
func NewTableManager(db *sql.DB, dbType DatabaseType, config TableConfig) *TableManager {
|
||||
return &TableManager{
|
||||
dbType: dbType,
|
||||
tableConfig: config,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTable creates the blacklist table if it doesn't exist.
|
||||
func (tm *TableManager) CreateTable(ctx context.Context) error {
|
||||
exists, err := tm.tableExists(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to check if table exists")
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil // Table already exists
|
||||
}
|
||||
|
||||
createSQL, err := tm.getCreateTableSQL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tm.db.ExecContext(ctx, createSQL)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create table %s", tm.tableConfig.TableName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tableExists checks if the blacklist table exists in the database.
|
||||
func (tm *TableManager) tableExists(ctx context.Context) (bool, error) {
|
||||
tableName := tm.tableConfig.TableName
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
query = `
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
query = `
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = ?
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
case DatabaseSQLite:
|
||||
query = `
|
||||
SELECT 1 FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name = ?
|
||||
`
|
||||
args = []interface{}{tableName}
|
||||
default:
|
||||
return false, errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
|
||||
rows, err := tm.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "failed to check table existence")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return rows.Next(), nil
|
||||
}
|
||||
|
||||
// getCreateTableSQL returns the CREATE TABLE statement for the given database type.
|
||||
func (tm *TableManager) getCreateTableSQL() (string, error) {
|
||||
tableName := tm.tableConfig.TableName
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti UUID PRIMARY KEY,
|
||||
exp BIGINT NOT NULL,
|
||||
sub INTEGER NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
||||
`, tableName, tableName, tableName, tableName, tableName), nil
|
||||
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti CHAR(36) PRIMARY KEY,
|
||||
exp BIGINT NOT NULL,
|
||||
sub INT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_exp (exp),
|
||||
INDEX idx_sub (sub)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||
`, tableName), nil
|
||||
|
||||
case DatabaseSQLite:
|
||||
return fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
jti TEXT PRIMARY KEY,
|
||||
exp INTEGER NOT NULL,
|
||||
sub INTEGER NOT NULL,
|
||||
created_at INTEGER DEFAULT (strftime('%%s', 'now'))
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
||||
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
||||
`, tableName, tableName, tableName, tableName, tableName), nil
|
||||
|
||||
default:
|
||||
return "", errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupAutoCleanup configures database-native automatic cleanup of expired tokens.
|
||||
func (tm *TableManager) SetupAutoCleanup(ctx context.Context) error {
|
||||
if !tm.tableConfig.EnableAutoCleanup {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch tm.dbType.Type {
|
||||
case DatabasePostgreSQL:
|
||||
return tm.setupPostgreSQLCleanup(ctx)
|
||||
case DatabaseMySQL, DatabaseMariaDB:
|
||||
return tm.setupMySQLCleanup(ctx)
|
||||
case DatabaseSQLite:
|
||||
// SQLite doesn't support automatic cleanup
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// setupPostgreSQLCleanup creates a cleanup function for PostgreSQL.
|
||||
// Note: This creates a function but does not schedule it. You need to use pg_cron
|
||||
// or an external scheduler to call this function periodically.
|
||||
func (tm *TableManager) setupPostgreSQLCleanup(ctx context.Context) error {
|
||||
tableName := tm.tableConfig.TableName
|
||||
functionName := fmt.Sprintf("cleanup_%s", tableName)
|
||||
|
||||
createFunctionSQL := fmt.Sprintf(`
|
||||
CREATE OR REPLACE FUNCTION %s()
|
||||
RETURNS void AS $$
|
||||
BEGIN
|
||||
DELETE FROM %s WHERE exp < EXTRACT(EPOCH FROM NOW());
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
`, functionName, tableName)
|
||||
|
||||
_, err := tm.db.ExecContext(ctx, createFunctionSQL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create cleanup function")
|
||||
}
|
||||
|
||||
// Note: Actual scheduling requires pg_cron extension or external tools
|
||||
// Users should call this function periodically using:
|
||||
// SELECT cleanup_jwtblacklist();
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupMySQLCleanup creates a MySQL event for automatic cleanup.
|
||||
// Note: Requires event_scheduler to be enabled in MySQL/MariaDB configuration.
|
||||
func (tm *TableManager) setupMySQLCleanup(ctx context.Context) error {
|
||||
tableName := tm.tableConfig.TableName
|
||||
eventName := fmt.Sprintf("cleanup_%s_event", tableName)
|
||||
interval := tm.tableConfig.CleanupInterval
|
||||
|
||||
// Drop existing event if it exists
|
||||
dropEventSQL := fmt.Sprintf("DROP EVENT IF EXISTS %s", eventName)
|
||||
_, err := tm.db.ExecContext(ctx, dropEventSQL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to drop existing event")
|
||||
}
|
||||
|
||||
// Create new event
|
||||
createEventSQL := fmt.Sprintf(`
|
||||
CREATE EVENT %s
|
||||
ON SCHEDULE EVERY %d HOUR
|
||||
DO
|
||||
DELETE FROM %s WHERE exp < UNIX_TIMESTAMP()
|
||||
`, eventName, interval, tableName)
|
||||
|
||||
_, err = tm.db.ExecContext(ctx, createEventSQL)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to create cleanup event (ensure event_scheduler is enabled)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user