fixed transaction issues
This commit is contained in:
@@ -1,47 +1,31 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Revoke a token by adding it to the database
|
// Revoke a token by adding it to the database
|
||||||
func revoke(ctx context.Context, t Token) error {
|
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||||
db := t.getDB()
|
if gen.dbConn == nil {
|
||||||
if db == nil {
|
|
||||||
return errors.New("No DB provided, unable to use this function")
|
return errors.New("No DB provided, unable to use this function")
|
||||||
}
|
}
|
||||||
tx, err := db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.BeginTx")
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
exp := t.GetEXP()
|
exp := t.GetEXP()
|
||||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||||
_, err = tx.Exec(query, jti, exp)
|
_, err := tx.Exec(query, jti, exp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.Exec")
|
return errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Commit")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a token has been revoked. Returns true if not revoked.
|
// Check if a token has been revoked. Returns true if not revoked.
|
||||||
func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
|
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||||
db := t.getDB()
|
if gen.dbConn == nil {
|
||||||
if db == nil {
|
|
||||||
return false, errors.New("No DB provided, unable to use this function")
|
return false, errors.New("No DB provided, unable to use this function")
|
||||||
}
|
}
|
||||||
tx, err := db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "db.BeginTx")
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||||
rows, err := tx.Query(query, jti)
|
rows, err := tx.Query(query, jti)
|
||||||
@@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
|
|||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
revoked := rows.Next()
|
revoked := rows.Next()
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.Commit")
|
|
||||||
}
|
|
||||||
return !revoked, nil
|
return !revoked, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package jwt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) {
|
|||||||
token := AccessToken{
|
token := AccessToken{
|
||||||
JTI: jti,
|
JTI: jti,
|
||||||
EXP: exp,
|
EXP: exp,
|
||||||
|
gen: &TokenGenerator{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke should fail due to no DB
|
// Revoke should fail due to no DB
|
||||||
err := token.Revoke(context.Background())
|
err := token.Revoke(&sql.Tx{})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
// CheckNotRevoked should fail
|
// CheckNotRevoked should fail
|
||||||
_, err = token.CheckNotRevoked(context.Background())
|
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
|||||||
token := AccessToken{
|
token := AccessToken{
|
||||||
JTI: jti,
|
JTI: jti,
|
||||||
EXP: exp,
|
EXP: exp,
|
||||||
db: gen.dbConn,
|
gen: gen,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke expectations
|
// Revoke expectations
|
||||||
@@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
|||||||
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||||
WithArgs(jti, exp).
|
WithArgs(jti, exp).
|
||||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
|
||||||
|
|
||||||
err := token.Revoke(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// CheckNotRevoked expectations (now revoked)
|
|
||||||
mock.ExpectBegin()
|
|
||||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||||
WithArgs(jti).
|
WithArgs(jti).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||||
mock.ExpectCommit()
|
mock.ExpectCommit()
|
||||||
|
|
||||||
valid, err := token.CheckNotRevoked(context.Background())
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
defer tx.Rollback()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = token.Revoke(tx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.False(t, valid)
|
require.False(t, valid)
|
||||||
|
|
||||||
|
require.NoError(t, tx.Commit())
|
||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -11,8 +10,8 @@ type Token interface {
|
|||||||
GetJTI() uuid.UUID
|
GetJTI() uuid.UUID
|
||||||
GetEXP() int64
|
GetEXP() int64
|
||||||
GetScope() string
|
GetScope() string
|
||||||
getDB() *sql.DB
|
Revoke(*sql.Tx) error
|
||||||
Revoke(context.Context) error
|
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Access token
|
// Access token
|
||||||
@@ -25,7 +24,7 @@ type AccessToken struct {
|
|||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
Fresh int64 // Time freshness expiring at
|
Fresh int64 // Time freshness expiring at
|
||||||
Scope string // Should be "access"
|
Scope string // Should be "access"
|
||||||
db *sql.DB
|
gen *TokenGenerator
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh token
|
// Refresh token
|
||||||
@@ -37,7 +36,7 @@ type RefreshToken struct {
|
|||||||
SUB int // Subject (user) ID
|
SUB int // Subject (user) ID
|
||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
Scope string // Should be "refresh"
|
Scope string // Should be "refresh"
|
||||||
db *sql.DB
|
gen *TokenGenerator
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AccessToken) GetJTI() uuid.UUID {
|
func (a AccessToken) GetJTI() uuid.UUID {
|
||||||
@@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string {
|
|||||||
func (r RefreshToken) GetScope() string {
|
func (r RefreshToken) GetScope() string {
|
||||||
return r.Scope
|
return r.Scope
|
||||||
}
|
}
|
||||||
func (a AccessToken) getDB() *sql.DB {
|
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||||
return a.db
|
return a.gen.revoke(tx, a)
|
||||||
}
|
}
|
||||||
func (r RefreshToken) getDB() *sql.DB {
|
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||||
return r.db
|
return r.gen.revoke(tx, r)
|
||||||
}
|
}
|
||||||
func (a AccessToken) Revoke(ctx context.Context) error {
|
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
return revoke(ctx, a)
|
return a.gen.checkNotRevoked(tx, a)
|
||||||
}
|
}
|
||||||
func (r RefreshToken) Revoke(ctx context.Context) error {
|
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
return revoke(ctx, r)
|
return r.gen.checkNotRevoked(tx, r)
|
||||||
}
|
|
||||||
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
|
|
||||||
return checkNotRevoked(ctx, a)
|
|
||||||
}
|
|
||||||
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
|
|
||||||
return checkNotRevoked(ctx, r)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -9,7 +10,7 @@ import (
|
|||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func (gen *TokenGenerator) ValidateAccess(
|
func (gen *TokenGenerator) ValidateAccess(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*AccessToken, error) {
|
) (*AccessToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess(
|
|||||||
Fresh: fresh,
|
Fresh: fresh,
|
||||||
JTI: jti,
|
JTI: jti,
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
db: gen.dbConn,
|
gen: gen,
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := token.CheckNotRevoked(ctx)
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
if err != nil && gen.dbConn != nil {
|
if err != nil && gen.dbConn != nil {
|
||||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
}
|
}
|
||||||
@@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess(
|
|||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func (gen *TokenGenerator) ValidateRefresh(
|
func (gen *TokenGenerator) ValidateRefresh(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*RefreshToken, error) {
|
) (*RefreshToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh(
|
|||||||
SUB: subject,
|
SUB: subject,
|
||||||
JTI: jti,
|
JTI: jti,
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
db: gen.dbConn,
|
gen: gen,
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := token.CheckNotRevoked(ctx)
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
if err != nil && gen.dbConn != nil {
|
if err != nil && gen.dbConn != nil {
|
||||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package jwt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
@@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) {
|
|||||||
// We don't know the JTI beforehand; match any arg
|
// We don't know the JTI beforehand; match any arg
|
||||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
token, err := gen.ValidateAccess(context.Background(), tokenStr)
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 42, token.SUB)
|
require.Equal(t, 42, token.SUB)
|
||||||
require.Equal(t, "access", token.Scope)
|
require.Equal(t, "access", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateAccess_NoDB(t *testing.T) {
|
func TestValidateAccess_NoDB(t *testing.T) {
|
||||||
@@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
|
|||||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
token, err := gen.ValidateAccess(context.Background(), tokenStr)
|
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 42, token.SUB)
|
require.Equal(t, 42, token.SUB)
|
||||||
require.Equal(t, "access", token.Scope)
|
require.Equal(t, "access", token.Scope)
|
||||||
@@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) {
|
|||||||
|
|
||||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 42, token.SUB)
|
require.Equal(t, 42, token.SUB)
|
||||||
require.Equal(t, "refresh", token.Scope)
|
require.Equal(t, "refresh", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateRefresh_NoDB(t *testing.T) {
|
func TestValidateRefresh_NoDB(t *testing.T) {
|
||||||
@@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
|||||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
|
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 42, token.SUB)
|
require.Equal(t, 42, token.SUB)
|
||||||
require.Equal(t, "refresh", token.Scope)
|
require.Equal(t, "refresh", token.Scope)
|
||||||
@@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
|
|||||||
func TestValidateAccess_EmptyToken(t *testing.T) {
|
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||||
gen := newTestGenerator(t)
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
_, err := gen.ValidateAccess(context.Background(), "")
|
_, err := gen.ValidateAccess(nil, "")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
|
|||||||
tokenStr, _, err := gen.NewAccess(1, false, false)
|
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = gen.ValidateRefresh(context.Background(), tokenStr)
|
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user