Added parsing of tokens
This commit is contained in:
1
go.mod
1
go.mod
@@ -5,6 +5,7 @@ go 1.23.5
|
|||||||
require (
|
require (
|
||||||
github.com/a-h/templ v0.3.833
|
github.com/a-h/templ v0.3.833
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -10,6 +10,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
|
|||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ import (
|
|||||||
"projectreshoot/server"
|
"projectreshoot/server"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Generates an access token for the provided user, using the variables set
|
// Generates an access token for the provided user
|
||||||
// in the config object
|
|
||||||
func GenerateAccessToken(
|
func GenerateAccessToken(
|
||||||
config *server.Config,
|
config *server.Config,
|
||||||
user *db.User,
|
user *db.User,
|
||||||
@@ -28,11 +28,36 @@ func GenerateAccessToken(
|
|||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
jwt.MapClaims{
|
jwt.MapClaims{
|
||||||
"iss": config.TrustedHost,
|
"iss": config.TrustedHost,
|
||||||
"sub": user.ID,
|
"scope": "access",
|
||||||
"aud": config.TrustedHost,
|
|
||||||
"iat": issuedAt,
|
"iat": issuedAt,
|
||||||
"exp": expiresAt,
|
"exp": expiresAt,
|
||||||
"fresh": freshExpiresAt,
|
"fresh": freshExpiresAt,
|
||||||
|
"sub": user.ID,
|
||||||
|
"roles": []string{"user", "admin"}, // TODO: add user roles
|
||||||
|
})
|
||||||
|
|
||||||
|
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "token.SignedString")
|
||||||
|
}
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates a refresh token for the provided user
|
||||||
|
func GenerateRefreshToken(
|
||||||
|
config *server.Config,
|
||||||
|
user *db.User,
|
||||||
|
) (string, error) {
|
||||||
|
issuedAt := time.Now().Unix()
|
||||||
|
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
|
jwt.MapClaims{
|
||||||
|
"iss": config.TrustedHost,
|
||||||
|
"scope": "refresh",
|
||||||
|
"jti": uuid.New(),
|
||||||
|
"iat": issuedAt,
|
||||||
|
"exp": expiresAt,
|
||||||
|
"sub": user.ID,
|
||||||
})
|
})
|
||||||
|
|
||||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||||
231
jwt/parse.go
Normal file
231
jwt/parse.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"projectreshoot/server"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse an access token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, error) {
|
||||||
|
claims, err := parseToken(config.SecretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "access" {
|
||||||
|
return AccessToken{}, errors.New("Token is not an Access token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
fresh, err := getFreshTime(claims["fresh"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "getFreshTime")
|
||||||
|
}
|
||||||
|
roles, err := getTokenRoles(claims["roles"])
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "getTokenRoles")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := AccessToken{
|
||||||
|
ISS: issuer,
|
||||||
|
Scope: scope,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
Fresh: fresh,
|
||||||
|
Roles: roles,
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, error) {
|
||||||
|
claims, err := parseToken(config.SecretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "refresh" {
|
||||||
|
return RefreshToken{}, errors.New("Token is not an Refresh token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
jti, err := getTokenJTI(claims["jti"])
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "getTokenJTI")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := RefreshToken{
|
||||||
|
ISS: issuer,
|
||||||
|
Scope: scope,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
JTI: jti,
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a token, validating its signing sigature and returning the claims
|
||||||
|
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(secretKey), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "jwt.Parse")
|
||||||
|
}
|
||||||
|
// Token decoded, parse the claims
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Failed to parse claims")
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token is expired. Returns the expiry if not expired
|
||||||
|
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||||
|
// Coerce the expiry to a float64 to avoid scientific notation
|
||||||
|
expFloat, ok := expiry.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||||
|
}
|
||||||
|
// Convert to the int64 time we expect :)
|
||||||
|
expiryTime := int64(expFloat)
|
||||||
|
|
||||||
|
// Check if its expired
|
||||||
|
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||||
|
if isExpired {
|
||||||
|
return 0, errors.New("Token has expired")
|
||||||
|
}
|
||||||
|
return expiryTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||||
|
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||||
|
// Check issuer
|
||||||
|
issuerVal, ok := issuer.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'iss' claim")
|
||||||
|
}
|
||||||
|
if issuer != trustedHost {
|
||||||
|
return "", errors.New("Issuer does not matched trusted host")
|
||||||
|
}
|
||||||
|
return issuerVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the scope matches the expected scope. Returns scope if true
|
||||||
|
func getTokenScope(scope interface{}) (string, error) {
|
||||||
|
scopeStr, ok := scope.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'scope' claim")
|
||||||
|
}
|
||||||
|
return scopeStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the time the token was issued at
|
||||||
|
func getIssuedTime(issued interface{}) (int64, error) {
|
||||||
|
// Same float64 -> int64 trick as expiry
|
||||||
|
issuedFloat, ok := issued.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||||
|
}
|
||||||
|
issuedAt := int64(issuedFloat)
|
||||||
|
return issuedAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the freshness expiry timestamp
|
||||||
|
func getFreshTime(fresh interface{}) (int64, error) {
|
||||||
|
freshUntil, ok := fresh.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||||
|
}
|
||||||
|
return int64(freshUntil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the subject of the token
|
||||||
|
func getTokenSubject(sub interface{}) (int, error) {
|
||||||
|
subject, ok := sub.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||||
|
}
|
||||||
|
return int(subject), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the roles of the token subject
|
||||||
|
func getTokenRoles(roles interface{}) ([]string, error) {
|
||||||
|
rolesIfSlice, ok := roles.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Missing or invalid 'roles' claim")
|
||||||
|
}
|
||||||
|
rolesSlice := []string{}
|
||||||
|
for _, role := range rolesIfSlice {
|
||||||
|
if str, ok := role.(string); ok {
|
||||||
|
rolesSlice = append(rolesSlice, str)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("Malformed 'roles' claim")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rolesSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the JTI of the token
|
||||||
|
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||||
|
jtiStr, ok := jti.(string)
|
||||||
|
if !ok {
|
||||||
|
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||||
|
}
|
||||||
|
jtiUUID, err := uuid.Parse(jtiStr)
|
||||||
|
if err != nil {
|
||||||
|
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||||
|
}
|
||||||
|
return jtiUUID, nil
|
||||||
|
}
|
||||||
24
jwt/tokens.go
Normal file
24
jwt/tokens.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
// Access token
|
||||||
|
type AccessToken struct {
|
||||||
|
ISS string
|
||||||
|
Scope string
|
||||||
|
IAT int64
|
||||||
|
EXP int64
|
||||||
|
SUB int
|
||||||
|
Fresh int64
|
||||||
|
Roles []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh token
|
||||||
|
type RefreshToken struct {
|
||||||
|
ISS string
|
||||||
|
Scope string
|
||||||
|
IAT int64
|
||||||
|
EXP int64
|
||||||
|
SUB int
|
||||||
|
JTI uuid.UUID
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user