Added authentication middleware

This commit is contained in:
2025-02-11 23:46:25 +11:00
parent 97aabcf06f
commit 732f8510ae
12 changed files with 208 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
package jwt
import (
"database/sql"
"fmt"
"time"
@@ -14,7 +15,11 @@ import (
// Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) {
func ParseAccessToken(
config *config.Config,
conn *sql.DB,
tokenString string,
) (AccessToken, error) {
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return AccessToken{}, errors.Wrap(err, "parseToken")
@@ -66,13 +71,21 @@ func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, e
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
if err != nil || !valid {
return AccessToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
}
return token, nil
}
// Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) {
func ParseRefreshToken(
config *config.Config,
conn *sql.DB,
tokenString string,
) (RefreshToken, error) {
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return RefreshToken{}, errors.Wrap(err, "parseToken")
@@ -119,6 +132,13 @@ func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken,
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
if err != nil {
return RefreshToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
}
if !valid {
return RefreshToken{}, errors.New("Token has been revoked")
}
return token, nil
}

View File

@@ -18,8 +18,8 @@ func RevokeToken(conn *sql.DB, t Token) error {
return nil
}
// Check if a token has been revoked
func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
// Check if a token has been revoked. Returns true if not revoked.
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti)
@@ -27,5 +27,5 @@ func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
return false, errors.Wrap(err, "conn.Exec")
}
revoked := rows.Next()
return revoked, nil
return !revoked, nil
}

View File

@@ -1,11 +1,18 @@
package jwt
import "github.com/google/uuid"
import (
"database/sql"
"projectreshoot/db"
"github.com/google/uuid"
"github.com/pkg/errors"
)
type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
GetUser(conn *sql.DB) (*db.User, error)
}
// Access token
@@ -31,6 +38,21 @@ type RefreshToken struct {
Scope string // Should be "refresh"
}
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, a.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
}
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
}
func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI
}