refactor to improve database operability
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -11,8 +9,8 @@ import (
|
||||
// revoke is an internal method that adds a token to the blacklist database.
|
||||
// Once revoked, the token will fail validation checks even if it hasn't expired.
|
||||
// This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
if gen.db == nil {
|
||||
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
|
||||
if gen.beginTx == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
@@ -22,7 +20,7 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
sub := t.GetSUB()
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
|
||||
_, err := tx.ExecContext(context.Background(), query, jti.String(), exp, sub)
|
||||
_, err := tx.Exec(query, jti.String(), exp, sub)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.ExecContext")
|
||||
}
|
||||
@@ -32,8 +30,8 @@ func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
// checkNotRevoked is an internal method that queries the blacklist to verify
|
||||
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
|
||||
// false if it has been revoked. This operation must be performed within a database transaction.
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
if gen.db == nil {
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
|
||||
if gen.beginTx == nil {
|
||||
return false, errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
|
||||
@@ -41,7 +39,7 @@ func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
|
||||
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
|
||||
rows, err := tx.QueryContext(context.Background(), query, jti.String())
|
||||
rows, err := tx.Query(query, jti.String())
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.QueryContext")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user