From 05aad5f11b36c7e2972fcce70bfb5fb0ae26114c Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 1 Jan 2026 22:44:39 +1100 Subject: [PATCH] fixed transaction issues --- jwt/revoke.go | 32 ++++++-------------------------- jwt/revoke_test.go | 25 ++++++++++++++----------- jwt/tokens.go | 31 ++++++++++++------------------- jwt/validate.go | 15 ++++++++------- jwt/validate_test.go | 23 +++++++++++++++++------ 5 files changed, 57 insertions(+), 69 deletions(-) diff --git a/jwt/revoke.go b/jwt/revoke.go index 424c8fb..3f7db31 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -1,47 +1,31 @@ package jwt import ( - "context" + "database/sql" "github.com/pkg/errors" ) // Revoke a token by adding it to the database -func revoke(ctx context.Context, t Token) error { - db := t.getDB() - if db == nil { +func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error { + if gen.dbConn == nil { return errors.New("No DB provided, unable to use this function") } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return errors.Wrap(err, "db.BeginTx") - } - defer tx.Rollback() jti := t.GetJTI() exp := t.GetEXP() query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` - _, err = tx.Exec(query, jti, exp) + _, err := tx.Exec(query, jti, exp) if err != nil { return errors.Wrap(err, "tx.Exec") } - err = tx.Commit() - if err != nil { - return errors.Wrap(err, "tx.Commit") - } return nil } // Check if a token has been revoked. Returns true if not revoked. -func checkNotRevoked(ctx context.Context, t Token) (bool, error) { - db := t.getDB() - if db == nil { +func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) { + if gen.dbConn == nil { return false, errors.New("No DB provided, unable to use this function") } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return false, errors.Wrap(err, "db.BeginTx") - } - defer tx.Rollback() jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` rows, err := tx.Query(query, jti) @@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) { } defer rows.Close() revoked := rows.Next() - err = tx.Commit() - if err != nil { - return false, errors.Wrap(err, "tx.Commit") - } return !revoked, nil } diff --git a/jwt/revoke_test.go b/jwt/revoke_test.go index 62372e3..be15dc0 100644 --- a/jwt/revoke_test.go +++ b/jwt/revoke_test.go @@ -2,6 +2,7 @@ package jwt import ( "context" + "database/sql" "testing" "time" @@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) { token := AccessToken{ JTI: jti, EXP: exp, + gen: &TokenGenerator{}, } // Revoke should fail due to no DB - err := token.Revoke(context.Background()) + err := token.Revoke(&sql.Tx{}) require.Error(t, err) // CheckNotRevoked should fail - _, err = token.CheckNotRevoked(context.Background()) + _, err = token.CheckNotRevoked(&sql.Tx{}) require.Error(t, err) } @@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) { token := AccessToken{ JTI: jti, EXP: exp, - db: gen.dbConn, + gen: gen, } // Revoke expectations @@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) { mock.ExpectExec(`INSERT INTO jwtblacklist`). WithArgs(jti, exp). WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := token.Revoke(context.Background()) - require.NoError(t, err) - - // CheckNotRevoked expectations (now revoked) - mock.ExpectBegin() mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`). WithArgs(jti). WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) mock.ExpectCommit() - valid, err := token.CheckNotRevoked(context.Background()) + tx, err := gen.dbConn.BeginTx(context.Background(), nil) + defer tx.Rollback() + require.NoError(t, err) + + err = token.Revoke(tx) + require.NoError(t, err) + + valid, err := token.CheckNotRevoked(tx) require.NoError(t, err) require.False(t, valid) + require.NoError(t, tx.Commit()) require.NoError(t, mock.ExpectationsWereMet()) } diff --git a/jwt/tokens.go b/jwt/tokens.go index 8754ccf..fbc1cf7 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,7 +1,6 @@ package jwt import ( - "context" "database/sql" "github.com/google/uuid" @@ -11,8 +10,8 @@ type Token interface { GetJTI() uuid.UUID GetEXP() int64 GetScope() string - getDB() *sql.DB - Revoke(context.Context) error + Revoke(*sql.Tx) error + CheckNotRevoked(*sql.Tx) (bool, error) } // Access token @@ -25,7 +24,7 @@ type AccessToken struct { JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens Fresh int64 // Time freshness expiring at Scope string // Should be "access" - db *sql.DB + gen *TokenGenerator } // Refresh token @@ -37,7 +36,7 @@ type RefreshToken struct { SUB int // Subject (user) ID JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens Scope string // Should be "refresh" - db *sql.DB + gen *TokenGenerator } func (a AccessToken) GetJTI() uuid.UUID { @@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string { func (r RefreshToken) GetScope() string { return r.Scope } -func (a AccessToken) getDB() *sql.DB { - return a.db +func (a AccessToken) Revoke(tx *sql.Tx) error { + return a.gen.revoke(tx, a) } -func (r RefreshToken) getDB() *sql.DB { - return r.db +func (r RefreshToken) Revoke(tx *sql.Tx) error { + return r.gen.revoke(tx, r) } -func (a AccessToken) Revoke(ctx context.Context) error { - return revoke(ctx, a) +func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { + return a.gen.checkNotRevoked(tx, a) } -func (r RefreshToken) Revoke(ctx context.Context) error { - return revoke(ctx, r) -} -func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) { - return checkNotRevoked(ctx, a) -} -func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) { - return checkNotRevoked(ctx, r) +func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) { + return r.gen.checkNotRevoked(tx, r) } diff --git a/jwt/validate.go b/jwt/validate.go index c64a6f8..bb0965a 100644 --- a/jwt/validate.go +++ b/jwt/validate.go @@ -1,7 +1,8 @@ package jwt import ( - "context" + "database/sql" + "github.com/pkg/errors" ) @@ -9,7 +10,7 @@ import ( // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. func (gen *TokenGenerator) ValidateAccess( - ctx context.Context, + tx *sql.Tx, tokenString string, ) (*AccessToken, error) { if tokenString == "" { @@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess( Fresh: fresh, JTI: jti, Scope: scope, - db: gen.dbConn, + gen: gen, } - valid, err := token.CheckNotRevoked(ctx) + valid, err := token.CheckNotRevoked(tx) if err != nil && gen.dbConn != nil { return nil, errors.Wrap(err, "token.CheckNotRevoked") } @@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess( // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. func (gen *TokenGenerator) ValidateRefresh( - ctx context.Context, + tx *sql.Tx, tokenString string, ) (*RefreshToken, error) { if tokenString == "" { @@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh( SUB: subject, JTI: jti, Scope: scope, - db: gen.dbConn, + gen: gen, } - valid, err := token.CheckNotRevoked(ctx) + valid, err := token.CheckNotRevoked(tx) if err != nil && gen.dbConn != nil { return nil, errors.Wrap(err, "token.CheckNotRevoked") } diff --git a/jwt/validate_test.go b/jwt/validate_test.go index 79d7808..bdedc0e 100644 --- a/jwt/validate_test.go +++ b/jwt/validate_test.go @@ -2,6 +2,7 @@ package jwt import ( "context" + "database/sql" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) { // We don't know the JTI beforehand; match any arg expectNotRevoked(mock, sqlmock.AnyArg()) - token, err := gen.ValidateAccess(context.Background(), tokenStr) + tx, err := gen.dbConn.BeginTx(context.Background(), nil) + require.NoError(t, err) + defer tx.Rollback() + + token, err := gen.ValidateAccess(tx, tokenStr) require.NoError(t, err) require.Equal(t, 42, token.SUB) require.Equal(t, "access", token.Scope) + tx.Commit() } func TestValidateAccess_NoDB(t *testing.T) { @@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) { tokenStr, _, err := gen.NewAccess(42, true, false) require.NoError(t, err) - token, err := gen.ValidateAccess(context.Background(), tokenStr) + token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr) require.NoError(t, err) require.Equal(t, 42, token.SUB) require.Equal(t, "access", token.Scope) @@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) { expectNotRevoked(mock, sqlmock.AnyArg()) - token, err := gen.ValidateRefresh(context.Background(), tokenStr) + tx, err := gen.dbConn.BeginTx(context.Background(), nil) + require.NoError(t, err) + defer tx.Rollback() + + token, err := gen.ValidateRefresh(tx, tokenStr) require.NoError(t, err) require.Equal(t, 42, token.SUB) require.Equal(t, "refresh", token.Scope) + tx.Commit() } func TestValidateRefresh_NoDB(t *testing.T) { @@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) { tokenStr, _, err := gen.NewRefresh(42, false) require.NoError(t, err) - token, err := gen.ValidateRefresh(context.Background(), tokenStr) + token, err := gen.ValidateRefresh(nil, tokenStr) require.NoError(t, err) require.Equal(t, 42, token.SUB) require.Equal(t, "refresh", token.Scope) @@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) { func TestValidateAccess_EmptyToken(t *testing.T) { gen := newTestGenerator(t) - _, err := gen.ValidateAccess(context.Background(), "") + _, err := gen.ValidateAccess(nil, "") require.Error(t, err) } @@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) { tokenStr, _, err := gen.NewAccess(1, false, false) require.NoError(t, err) - _, err = gen.ValidateRefresh(context.Background(), tokenStr) + _, err = gen.ValidateRefresh(nil, tokenStr) require.Error(t, err) }