fixed transaction issues
This commit is contained in:
@@ -1,47 +1,31 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func revoke(ctx context.Context, t Token) error {
|
||||
db := t.getDB()
|
||||
if db == nil {
|
||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
if gen.dbConn == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.BeginTx")
|
||||
}
|
||||
defer tx.Rollback()
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err = tx.Exec(query, jti, exp)
|
||||
_, err := tx.Exec(query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Commit")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
|
||||
db := t.getDB()
|
||||
if db == nil {
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
if gen.dbConn == nil {
|
||||
return false, errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.BeginTx")
|
||||
}
|
||||
defer tx.Rollback()
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := tx.Query(query, jti)
|
||||
@@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Commit")
|
||||
}
|
||||
return !revoked, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) {
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
gen: &TokenGenerator{},
|
||||
}
|
||||
|
||||
// Revoke should fail due to no DB
|
||||
err := token.Revoke(context.Background())
|
||||
err := token.Revoke(&sql.Tx{})
|
||||
require.Error(t, err)
|
||||
|
||||
// CheckNotRevoked should fail
|
||||
_, err = token.CheckNotRevoked(context.Background())
|
||||
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
db: gen.dbConn,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
// Revoke expectations
|
||||
@@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||
WithArgs(jti, exp).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
err := token.Revoke(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// CheckNotRevoked expectations (now revoked)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||
WithArgs(jti).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
valid, err := token.CheckNotRevoked(context.Background())
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
defer tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = token.Revoke(tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, valid)
|
||||
|
||||
require.NoError(t, tx.Commit())
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -11,8 +10,8 @@ type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
getDB() *sql.DB
|
||||
Revoke(context.Context) error
|
||||
Revoke(*sql.Tx) error
|
||||
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
@@ -25,7 +24,7 @@ type AccessToken struct {
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Fresh int64 // Time freshness expiring at
|
||||
Scope string // Should be "access"
|
||||
db *sql.DB
|
||||
gen *TokenGenerator
|
||||
}
|
||||
|
||||
// Refresh token
|
||||
@@ -37,7 +36,7 @@ type RefreshToken struct {
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Scope string // Should be "refresh"
|
||||
db *sql.DB
|
||||
gen *TokenGenerator
|
||||
}
|
||||
|
||||
func (a AccessToken) GetJTI() uuid.UUID {
|
||||
@@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string {
|
||||
func (r RefreshToken) GetScope() string {
|
||||
return r.Scope
|
||||
}
|
||||
func (a AccessToken) getDB() *sql.DB {
|
||||
return a.db
|
||||
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||
return a.gen.revoke(tx, a)
|
||||
}
|
||||
func (r RefreshToken) getDB() *sql.DB {
|
||||
return r.db
|
||||
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||
return r.gen.revoke(tx, r)
|
||||
}
|
||||
func (a AccessToken) Revoke(ctx context.Context) error {
|
||||
return revoke(ctx, a)
|
||||
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
return a.gen.checkNotRevoked(tx, a)
|
||||
}
|
||||
func (r RefreshToken) Revoke(ctx context.Context) error {
|
||||
return revoke(ctx, r)
|
||||
}
|
||||
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
|
||||
return checkNotRevoked(ctx, a)
|
||||
}
|
||||
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
|
||||
return checkNotRevoked(ctx, r)
|
||||
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
return r.gen.checkNotRevoked(tx, r)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func (gen *TokenGenerator) ValidateAccess(
|
||||
ctx context.Context,
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess(
|
||||
Fresh: fresh,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
db: gen.dbConn,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(ctx)
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess(
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func (gen *TokenGenerator) ValidateRefresh(
|
||||
ctx context.Context,
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh(
|
||||
SUB: subject,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
db: gen.dbConn,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(ctx)
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
@@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) {
|
||||
// We don't know the JTI beforehand; match any arg
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
token, err := gen.ValidateAccess(context.Background(), tokenStr)
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "access", token.Scope)
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func TestValidateAccess_NoDB(t *testing.T) {
|
||||
@@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateAccess(context.Background(), tokenStr)
|
||||
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "access", token.Scope)
|
||||
@@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) {
|
||||
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "refresh", token.Scope)
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
@@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
|
||||
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "refresh", token.Scope)
|
||||
@@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
_, err := gen.ValidateAccess(context.Background(), "")
|
||||
_, err := gen.ValidateAccess(nil, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
|
||||
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = gen.ValidateRefresh(context.Background(), tokenStr)
|
||||
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user