Update authentication, reauth, logout to use new transactions

This commit is contained in:
2025-02-17 18:58:34 +11:00
parent 417daf0028
commit 2c61cec55c
17 changed files with 306 additions and 121 deletions

View File

@@ -1,11 +1,12 @@
package jwt
import (
"database/sql"
"context"
"fmt"
"time"
"projectreshoot/config"
"projectreshoot/db"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
@@ -17,7 +18,8 @@ import (
// has the correct scope.
func ParseAccessToken(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
@@ -74,7 +76,7 @@ func ParseAccessToken(
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}
@@ -89,7 +91,8 @@ func ParseAccessToken(
// has the correct scope.
func ParseRefreshToken(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
@@ -141,7 +144,7 @@ func ParseRefreshToken(
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}

View File

@@ -1,32 +1,34 @@
package jwt
import (
"database/sql"
"context"
"projectreshoot/db"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func RevokeToken(conn *sql.DB, t Token) error {
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := conn.Exec(query, jti, exp)
_, err := tx.Exec(ctx, query, jti, exp)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti)
defer rows.Close()
rows, err := tx.Query(ctx, query, jti)
if err != nil {
return false, errors.Wrap(err, "conn.Exec")
return false, errors.Wrap(err, "tx.Query")
}
// NOTE: rows.Close()
defer rows.Close()
revoked := rows.Next()
return !revoked, nil
}

View File

@@ -1,7 +1,7 @@
package jwt
import (
"database/sql"
"context"
"projectreshoot/db"
"github.com/google/uuid"
@@ -12,7 +12,7 @@ type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
GetUser(conn *sql.DB) (*db.User, error)
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
}
// Access token
@@ -38,15 +38,15 @@ type RefreshToken struct {
Scope string // Should be "refresh"
}
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, a.SUB)
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, a.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return user, nil
}
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB)
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}