Files
golib/jwt/tablemanager.go
Haelnorr 1b25e2f0a5 Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers
and using standard library *sql.DB and *sql.Tx types directly.

Changes:
- Removed DBConnection and DBTransaction interfaces from database.go
- Removed NewDBConnection() wrapper function
- Updated TokenGenerator to use *sql.DB instead of DBConnection
- Updated all validation and revocation methods to accept *sql.Tx
- Updated TableManager to work with *sql.DB directly
- Updated all tests to use db.Begin() instead of custom wrappers
- Fixed GeneratorConfig.DB field (was DBConn)
- Updated documentation in doc.go with correct API usage

Benefits:
- Simpler API with fewer abstractions
- Works directly with database/sql standard library
- Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB)
- Easier to understand and maintain
- No unnecessary wrapper layers

Breaking changes:
- GeneratorConfig.DBConn renamed to GeneratorConfig.DB
- Removed NewDBConnection() function - pass *sql.DB directly
- ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction
- Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-11 17:39:30 +11:00

213 lines
5.8 KiB
Go

package jwt
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
)
// TableManager handles table creation, existence checks, and cleanup configuration
// for the token blacklist table.
type TableManager struct {
// dbType identifies the database dialect; its Type field selects which SQL variant is used.
dbType DatabaseType
// tableConfig supplies the table name and the auto-cleanup settings.
tableConfig TableConfig
// db is the connection handle on which every statement is executed.
db *sql.DB
}
// NewTableManager builds a TableManager that runs its statements on db,
// using dbType to choose the SQL dialect and config for the table settings.
func NewTableManager(db *sql.DB, dbType DatabaseType, config TableConfig) *TableManager {
	manager := new(TableManager)
	manager.db = db
	manager.dbType = dbType
	manager.tableConfig = config
	return manager
}
// CreateTable ensures the blacklist table is present.
//
// It first checks whether the table already exists and returns nil if so.
// Otherwise it obtains the dialect-specific CREATE TABLE statement and
// executes it. Errors from the existence check and from the execution are
// wrapped with context; an unsupported dialect error is returned as-is.
func (tm *TableManager) CreateTable(ctx context.Context) error {
	found, err := tm.tableExists(ctx)
	switch {
	case err != nil:
		return errors.Wrap(err, "failed to check if table exists")
	case found:
		// Nothing to do: the table is already there.
		return nil
	}

	ddl, err := tm.getCreateTableSQL()
	if err != nil {
		return err
	}

	if _, err := tm.db.ExecContext(ctx, ddl); err != nil {
		return errors.Wrapf(err, "failed to create table %s", tm.tableConfig.TableName)
	}
	return nil
}
// tableExists reports whether the configured blacklist table exists.
//
// The lookup is dialect-specific: PostgreSQL and MySQL/MariaDB consult
// information_schema.tables, SQLite consults sqlite_master. The table name is
// always passed as a bound parameter. An unsupported dialect, a failed query,
// or an error encountered while reading the result yields (false, err).
//
// NOTE(review): the PostgreSQL lookup only considers the 'public' schema.
func (tm *TableManager) tableExists(ctx context.Context) (bool, error) {
	var query string
	switch tm.dbType.Type {
	case DatabasePostgreSQL:
		query = `
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = $1
`
	case DatabaseMySQL, DatabaseMariaDB:
		query = `
SELECT 1 FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = ?
`
	case DatabaseSQLite:
		query = `
SELECT 1 FROM sqlite_master
WHERE type = 'table'
AND name = ?
`
	default:
		return false, errors.Errorf("unsupported database type: %s", tm.dbType.Type)
	}

	// Every dialect takes the table name as its single bound parameter.
	rows, err := tm.db.QueryContext(ctx, query, tm.tableConfig.TableName)
	if err != nil {
		return false, errors.Wrap(err, "failed to check table existence")
	}
	defer rows.Close()

	if rows.Next() {
		return true, nil
	}
	// Next returns false both when there are no rows and when iteration
	// failed; consult Err so a failure is not mistaken for a missing table.
	if err := rows.Err(); err != nil {
		return false, errors.Wrap(err, "failed to read table existence result")
	}
	return false, nil
}
// getCreateTableSQL builds the CREATE TABLE statement (plus index definitions)
// for the configured dialect, with the configured table name substituted in.
// It returns an error when the dialect is not one of the supported types.
func (tm *TableManager) getCreateTableSQL() (string, error) {
	var (
		ddl          string // dialect-specific template
		placeholders int    // how many times the table name appears in ddl
	)

	switch tm.dbType.Type {
	case DatabasePostgreSQL:
		placeholders = 5
		ddl = `
CREATE TABLE IF NOT EXISTS %s (
jti UUID PRIMARY KEY,
exp BIGINT NOT NULL,
sub INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
`
	case DatabaseMySQL, DatabaseMariaDB:
		placeholders = 1
		ddl = `
CREATE TABLE IF NOT EXISTS %s (
jti CHAR(36) PRIMARY KEY,
exp BIGINT NOT NULL,
sub INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_exp (exp),
INDEX idx_sub (sub)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
`
	case DatabaseSQLite:
		// The doubled percent sign keeps a literal %s for strftime in the output.
		placeholders = 5
		ddl = `
CREATE TABLE IF NOT EXISTS %s (
jti TEXT PRIMARY KEY,
exp INTEGER NOT NULL,
sub INTEGER NOT NULL,
created_at INTEGER DEFAULT (strftime('%%s', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
`
	default:
		return "", errors.Errorf("unsupported database type: %s", tm.dbType.Type)
	}

	// Every placeholder in the template receives the same table name.
	names := make([]interface{}, placeholders)
	for i := range names {
		names[i] = tm.tableConfig.TableName
	}
	return fmt.Sprintf(ddl, names...), nil
}
// SetupAutoCleanup installs database-side cleanup of expired tokens.
//
// It is a no-op when EnableAutoCleanup is false. Otherwise it delegates to the
// PostgreSQL or MySQL/MariaDB setup routine; for SQLite nothing is configured
// and nil is returned. Any other dialect produces an error.
func (tm *TableManager) SetupAutoCleanup(ctx context.Context) error {
	if !tm.tableConfig.EnableAutoCleanup {
		return nil
	}

	switch kind := tm.dbType.Type; kind {
	case DatabaseSQLite:
		// No database-native cleanup mechanism is set up for SQLite.
		return nil
	case DatabaseMySQL, DatabaseMariaDB:
		return tm.setupMySQLCleanup(ctx)
	case DatabasePostgreSQL:
		return tm.setupPostgreSQLCleanup(ctx)
	default:
		return errors.Errorf("unsupported database type: %s", kind)
	}
}
// setupPostgreSQLCleanup installs (or replaces) a PL/pgSQL function named
// cleanup_<table name> that deletes every row whose exp is earlier than the
// current epoch time.
//
// The function is only created, never scheduled: it must be invoked
// periodically by pg_cron or an external scheduler, e.g.
//
//	SELECT cleanup_<table name>();
func (tm *TableManager) setupPostgreSQLCleanup(ctx context.Context) error {
	const functionTemplate = `
CREATE OR REPLACE FUNCTION %s()
RETURNS void AS $$
BEGIN
DELETE FROM %s WHERE exp < EXTRACT(EPOCH FROM NOW());
END;
$$ LANGUAGE plpgsql;
`
	table := tm.tableConfig.TableName
	statement := fmt.Sprintf(functionTemplate, fmt.Sprintf("cleanup_%s", table), table)

	if _, err := tm.db.ExecContext(ctx, statement); err != nil {
		return errors.Wrap(err, "failed to create cleanup function")
	}
	return nil
}
// setupMySQLCleanup (re)creates a MySQL/MariaDB event named
// cleanup_<table name>_event that deletes rows whose exp is earlier than
// UNIX_TIMESTAMP(), running every CleanupInterval hours.
//
// The interval is validated before anything is changed: a non-positive value
// would make CREATE EVENT fail, and by then the previous event would already
// have been dropped, leaving no cleanup in place.
//
// Note: Requires event_scheduler to be enabled in MySQL/MariaDB configuration.
func (tm *TableManager) setupMySQLCleanup(ctx context.Context) error {
	tableName := tm.tableConfig.TableName
	interval := tm.tableConfig.CleanupInterval
	if interval <= 0 {
		return errors.Errorf("invalid cleanup interval %d: must be a positive number of hours", interval)
	}
	eventName := fmt.Sprintf("cleanup_%s_event", tableName)

	// Drop any existing event first so the schedule reflects the current configuration.
	dropEventSQL := fmt.Sprintf("DROP EVENT IF EXISTS %s", eventName)
	if _, err := tm.db.ExecContext(ctx, dropEventSQL); err != nil {
		return errors.Wrap(err, "failed to drop existing event")
	}

	createEventSQL := fmt.Sprintf(`
CREATE EVENT %s
ON SCHEDULE EVERY %d HOUR
DO
DELETE FROM %s WHERE exp < UNIX_TIMESTAMP()
`, eventName, interval, tableName)
	if _, err := tm.db.ExecContext(ctx, createEventSQL); err != nil {
		return errors.Wrap(err, "failed to create cleanup event (ensure event_scheduler is enabled)")
	}
	return nil
}