refactor to improve database operability

This commit is contained in:
2026-01-11 22:21:44 +11:00
parent 1b25e2f0a5
commit ae4094d426
13 changed files with 136 additions and 57 deletions

View File

@@ -1,8 +1,6 @@
package jwt
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
@@ -11,8 +9,8 @@ import (
// revoke is an internal method that adds a token to the blacklist database.
// Once revoked, the token will fail validation checks even if it hasn't expired.
// This operation must be performed within a database transaction.
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
if gen.db == nil {
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function")
}
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
sub := t.GetSUB()
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
_, err := tx.Exec(query, jti.String(), exp, sub)
if err != nil {
return errors.Wrap(err, "tx.ExecContext")
}
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
// checkNotRevoked is an internal method that queries the blacklist to verify
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
// false if it has been revoked. This operation must be performed within a database transaction.
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
if gen.db == nil {
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
if gen.beginTx == nil {
return false, errors.New("No DB provided, unable to use this function")
}
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
jti := t.GetJTI()
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
rows, err := tx.QueryContext(context.Background(), query, jti.String())
rows, err := tx.Query(query, jti.String())
if err != nil {
return false, errors.Wrap(err, "tx.QueryContext")
}