Simplified the database layer by removing custom interface wrappers and using the standard library *sql.DB and *sql.Tx types directly.

Changes:
- Removed DBConnection and DBTransaction interfaces from database.go
- Removed NewDBConnection() wrapper function
- Updated TokenGenerator to use *sql.DB instead of DBConnection
- Updated all validation and revocation methods to accept *sql.Tx
- Updated TableManager to work with *sql.DB directly
- Updated all tests to use db.Begin() instead of custom wrappers
- Fixed GeneratorConfig.DB field (was DBConn)
- Updated documentation in doc.go with correct API usage

Benefits:
- Simpler API with fewer abstractions
- Works directly with the database/sql standard library
- Compatible with GORM (via gormDB.DB()) and Bun (share the same *sql.DB)
- Easier to understand and maintain
- No unnecessary wrapper layers

Breaking changes:
- GeneratorConfig.DBConn renamed to GeneratorConfig.DB
- Removed NewDBConnection() function - pass *sql.DB directly
- ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction
- Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
213 lines
5.8 KiB
Go
213 lines
5.8 KiB
Go
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// TableManager handles table creation, existence checks, and cleanup configuration.
|
|
type TableManager struct {
|
|
dbType DatabaseType
|
|
tableConfig TableConfig
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewTableManager creates a new TableManager instance.
|
|
func NewTableManager(db *sql.DB, dbType DatabaseType, config TableConfig) *TableManager {
|
|
return &TableManager{
|
|
dbType: dbType,
|
|
tableConfig: config,
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// CreateTable creates the blacklist table if it doesn't exist.
|
|
func (tm *TableManager) CreateTable(ctx context.Context) error {
|
|
exists, err := tm.tableExists(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to check if table exists")
|
|
}
|
|
|
|
if exists {
|
|
return nil // Table already exists
|
|
}
|
|
|
|
createSQL, err := tm.getCreateTableSQL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tm.db.ExecContext(ctx, createSQL)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "failed to create table %s", tm.tableConfig.TableName)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// tableExists checks if the blacklist table exists in the database.
|
|
func (tm *TableManager) tableExists(ctx context.Context) (bool, error) {
|
|
tableName := tm.tableConfig.TableName
|
|
var query string
|
|
var args []interface{}
|
|
|
|
switch tm.dbType.Type {
|
|
case DatabasePostgreSQL:
|
|
query = `
|
|
SELECT 1 FROM information_schema.tables
|
|
WHERE table_schema = 'public'
|
|
AND table_name = $1
|
|
`
|
|
args = []interface{}{tableName}
|
|
case DatabaseMySQL, DatabaseMariaDB:
|
|
query = `
|
|
SELECT 1 FROM information_schema.tables
|
|
WHERE table_schema = DATABASE()
|
|
AND table_name = ?
|
|
`
|
|
args = []interface{}{tableName}
|
|
case DatabaseSQLite:
|
|
query = `
|
|
SELECT 1 FROM sqlite_master
|
|
WHERE type = 'table'
|
|
AND name = ?
|
|
`
|
|
args = []interface{}{tableName}
|
|
default:
|
|
return false, errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
|
}
|
|
|
|
rows, err := tm.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "failed to check table existence")
|
|
}
|
|
defer rows.Close()
|
|
|
|
return rows.Next(), nil
|
|
}
|
|
|
|
// getCreateTableSQL returns the CREATE TABLE statement for the given database type.
|
|
func (tm *TableManager) getCreateTableSQL() (string, error) {
|
|
tableName := tm.tableConfig.TableName
|
|
|
|
switch tm.dbType.Type {
|
|
case DatabasePostgreSQL:
|
|
return fmt.Sprintf(`
|
|
CREATE TABLE IF NOT EXISTS %s (
|
|
jti UUID PRIMARY KEY,
|
|
exp BIGINT NOT NULL,
|
|
sub INTEGER NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
|
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
|
`, tableName, tableName, tableName, tableName, tableName), nil
|
|
|
|
case DatabaseMySQL, DatabaseMariaDB:
|
|
return fmt.Sprintf(`
|
|
CREATE TABLE IF NOT EXISTS %s (
|
|
jti CHAR(36) PRIMARY KEY,
|
|
exp BIGINT NOT NULL,
|
|
sub INT NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
INDEX idx_exp (exp),
|
|
INDEX idx_sub (sub)
|
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
`, tableName), nil
|
|
|
|
case DatabaseSQLite:
|
|
return fmt.Sprintf(`
|
|
CREATE TABLE IF NOT EXISTS %s (
|
|
jti TEXT PRIMARY KEY,
|
|
exp INTEGER NOT NULL,
|
|
sub INTEGER NOT NULL,
|
|
created_at INTEGER DEFAULT (strftime('%%s', 'now'))
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
|
|
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
|
|
`, tableName, tableName, tableName, tableName, tableName), nil
|
|
|
|
default:
|
|
return "", errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
|
}
|
|
}
|
|
|
|
// SetupAutoCleanup configures database-native automatic cleanup of expired tokens.
|
|
func (tm *TableManager) SetupAutoCleanup(ctx context.Context) error {
|
|
if !tm.tableConfig.EnableAutoCleanup {
|
|
return nil
|
|
}
|
|
|
|
switch tm.dbType.Type {
|
|
case DatabasePostgreSQL:
|
|
return tm.setupPostgreSQLCleanup(ctx)
|
|
case DatabaseMySQL, DatabaseMariaDB:
|
|
return tm.setupMySQLCleanup(ctx)
|
|
case DatabaseSQLite:
|
|
// SQLite doesn't support automatic cleanup
|
|
return nil
|
|
default:
|
|
return errors.Errorf("unsupported database type: %s", tm.dbType.Type)
|
|
}
|
|
}
|
|
|
|
// setupPostgreSQLCleanup creates a cleanup function for PostgreSQL.
|
|
// Note: This creates a function but does not schedule it. You need to use pg_cron
|
|
// or an external scheduler to call this function periodically.
|
|
func (tm *TableManager) setupPostgreSQLCleanup(ctx context.Context) error {
|
|
tableName := tm.tableConfig.TableName
|
|
functionName := fmt.Sprintf("cleanup_%s", tableName)
|
|
|
|
createFunctionSQL := fmt.Sprintf(`
|
|
CREATE OR REPLACE FUNCTION %s()
|
|
RETURNS void AS $$
|
|
BEGIN
|
|
DELETE FROM %s WHERE exp < EXTRACT(EPOCH FROM NOW());
|
|
END;
|
|
$$ LANGUAGE plpgsql;
|
|
`, functionName, tableName)
|
|
|
|
_, err := tm.db.ExecContext(ctx, createFunctionSQL)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to create cleanup function")
|
|
}
|
|
|
|
// Note: Actual scheduling requires pg_cron extension or external tools
|
|
// Users should call this function periodically using:
|
|
// SELECT cleanup_jwtblacklist();
|
|
return nil
|
|
}
|
|
|
|
// setupMySQLCleanup creates a MySQL event for automatic cleanup.
|
|
// Note: Requires event_scheduler to be enabled in MySQL/MariaDB configuration.
|
|
func (tm *TableManager) setupMySQLCleanup(ctx context.Context) error {
|
|
tableName := tm.tableConfig.TableName
|
|
eventName := fmt.Sprintf("cleanup_%s_event", tableName)
|
|
interval := tm.tableConfig.CleanupInterval
|
|
|
|
// Drop existing event if it exists
|
|
dropEventSQL := fmt.Sprintf("DROP EVENT IF EXISTS %s", eventName)
|
|
_, err := tm.db.ExecContext(ctx, dropEventSQL)
|
|
if err != nil {
|
|
return errors.Wrap(err, "failed to drop existing event")
|
|
}
|
|
|
|
// Create new event
|
|
createEventSQL := fmt.Sprintf(`
|
|
CREATE EVENT %s
|
|
ON SCHEDULE EVERY %d HOUR
|
|
DO
|
|
DELETE FROM %s WHERE exp < UNIX_TIMESTAMP()
|
|
`, eventName, interval, tableName)
|
|
|
|
_, err = tm.db.ExecContext(ctx, createEventSQL)
|
|
if err != nil {
|
|
return errors.Wrapf(err, "failed to create cleanup event (ensure event_scheduler is enabled)")
|
|
}
|
|
|
|
return nil
|
|
}
|