diff --git a/go.mod b/go.mod index 7e6b658..380a9c2 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.5 require ( github.com/a-h/templ v0.3.833 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 diff --git a/go.sum b/go.sum index aa1501a..1ee3159 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= diff --git a/jwt/createtoken.go b/jwt/create.go similarity index 51% rename from jwt/createtoken.go rename to jwt/create.go index ed43146..7ba7688 100644 --- a/jwt/createtoken.go +++ b/jwt/create.go @@ -7,11 +7,11 @@ import ( "projectreshoot/server" "github.com/golang-jwt/jwt" + "github.com/google/uuid" "github.com/pkg/errors" ) -// Generates an access token for the provided user, using the variables set -// in the config object +// Generates an access token for the provided user func GenerateAccessToken( config *server.Config, user *db.User, @@ -28,11 +28,36 @@ func GenerateAccessToken( token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "iss": config.TrustedHost, - "sub": user.ID, - "aud": config.TrustedHost, + "scope": "access", "iat": issuedAt, "exp": expiresAt, "fresh": freshExpiresAt, + "sub": user.ID, + "roles": []string{"user", "admin"}, // TODO: add user roles + }) + + signedToken, err := token.SignedString([]byte(config.SecretKey)) + if err != nil { + return "", errors.Wrap(err, "token.SignedString") + } + return signedToken, nil +} + +// Generates a refresh token for the provided user +func GenerateRefreshToken( + config *server.Config, + user *db.User, +) (string, error) { + issuedAt := time.Now().Unix() + expiresAt := issuedAt + (config.RefreshTokenExpiry * 60) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + jwt.MapClaims{ + "iss": config.TrustedHost, + "scope": "refresh", + "jti": uuid.New(), + "iat": issuedAt, + "exp": expiresAt, + "sub": user.ID, }) signedToken, err := token.SignedString([]byte(config.SecretKey)) diff --git a/jwt/parse.go b/jwt/parse.go new file mode 100644 index 0000000..7ee0ea5 --- /dev/null +++ b/jwt/parse.go @@ -0,0 +1,231 @@ +package jwt + +import ( + "fmt" + "projectreshoot/server" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +// Parse an access token and return a struct with all the claims. Does validation on +// all the claims, including checking if it is expired, has a valid issuer, and +// has the correct scope. +func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, error) { + claims, err := parseToken(config.SecretKey, tokenString) + if err != nil { + return AccessToken{}, errors.Wrap(err, "parseToken") + } + expiry, err := checkTokenExpired(claims["exp"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "checkTokenExpired") + } + issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "checkTokenIssuer") + } + scope, err := getTokenScope(claims["scope"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getTokenScope") + } + if scope != "access" { + return AccessToken{}, errors.New("Token is not an Access token") + } + issuedAt, err := getIssuedTime(claims["iat"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getIssuedTime") + } + subject, err := getTokenSubject(claims["sub"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getTokenSubject") + } + fresh, err := getFreshTime(claims["fresh"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getFreshTime") + } + roles, err := getTokenRoles(claims["roles"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getTokenRoles") + } + + token := AccessToken{ + ISS: issuer, + Scope: scope, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + Fresh: fresh, + Roles: roles, + } + + return token, nil +} + +// Parse a refresh token and return a struct with all the claims. Does validation on +// all the claims, including checking if it is expired, has a valid issuer, and +// has the correct scope. +func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, error) { + claims, err := parseToken(config.SecretKey, tokenString) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "parseToken") + } + expiry, err := checkTokenExpired(claims["exp"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "checkTokenExpired") + } + issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "checkTokenIssuer") + } + scope, err := getTokenScope(claims["scope"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "getTokenScope") + } + if scope != "refresh" { + return RefreshToken{}, errors.New("Token is not an Refresh token") + } + issuedAt, err := getIssuedTime(claims["iat"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "getIssuedTime") + } + subject, err := getTokenSubject(claims["sub"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "getTokenSubject") + } + jti, err := getTokenJTI(claims["jti"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "getTokenJTI") + } + + token := RefreshToken{ + ISS: issuer, + Scope: scope, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + JTI: jti, + } + + return token, nil +} + +// Parse a token, validating its signing sigature and returning the claims +func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secretKey), nil + }) + if err != nil { + return nil, errors.Wrap(err, "jwt.Parse") + } + // Token decoded, parse the claims + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("Failed to parse claims") + } + return claims, nil +} + +// Check if a token is expired. Returns the expiry if not expired +func checkTokenExpired(expiry interface{}) (int64, error) { + // Coerce the expiry to a float64 to avoid scientific notation + expFloat, ok := expiry.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'exp' claim") + } + // Convert to the int64 time we expect :) + expiryTime := int64(expFloat) + + // Check if its expired + isExpired := time.Now().After(time.Unix(expiryTime, 0)) + if isExpired { + return 0, errors.New("Token has expired") + } + return expiryTime, nil +} + +// Check if a token has a valid issuer. Returns the issuer if valid +func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) { + // Check issuer + issuerVal, ok := issuer.(string) + if !ok { + return "", errors.New("Missing or invalid 'iss' claim") + } + if issuer != trustedHost { + return "", errors.New("Issuer does not matched trusted host") + } + return issuerVal, nil +} + +// Check the scope matches the expected scope. Returns scope if true +func getTokenScope(scope interface{}) (string, error) { + scopeStr, ok := scope.(string) + if !ok { + return "", errors.New("Missing or invalid 'scope' claim") + } + return scopeStr, nil +} + +// Get the time the token was issued at +func getIssuedTime(issued interface{}) (int64, error) { + // Same float64 -> int64 trick as expiry + issuedFloat, ok := issued.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'iat' claim") + } + issuedAt := int64(issuedFloat) + return issuedAt, nil +} + +// Get the freshness expiry timestamp +func getFreshTime(fresh interface{}) (int64, error) { + freshUntil, ok := fresh.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'fresh' claim") + } + return int64(freshUntil), nil +} + +// Get the subject of the token +func getTokenSubject(sub interface{}) (int, error) { + subject, ok := sub.(float64) + if !ok { + return 0, errors.New("Missing or invalid 'sub' claim") + } + return int(subject), nil +} + +// Get the roles of the token subject +func getTokenRoles(roles interface{}) ([]string, error) { + rolesIfSlice, ok := roles.([]interface{}) + if !ok { + return nil, errors.New("Missing or invalid 'roles' claim") + } + rolesSlice := []string{} + for _, role := range rolesIfSlice { + if str, ok := role.(string); ok { + rolesSlice = append(rolesSlice, str) + } else { + return nil, errors.New("Malformed 'roles' claim") + } + } + return rolesSlice, nil +} + +// Get the JTI of the token +func getTokenJTI(jti interface{}) (uuid.UUID, error) { + jtiStr, ok := jti.(string) + if !ok { + return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim") + } + jtiUUID, err := uuid.Parse(jtiStr) + if err != nil { + return uuid.UUID{}, errors.New("JTI is not a valid UUID") + } + return jtiUUID, nil +} diff --git a/jwt/tokens.go b/jwt/tokens.go new file mode 100644 index 0000000..9674cad --- /dev/null +++ b/jwt/tokens.go @@ -0,0 +1,24 @@ +package jwt + +import "github.com/google/uuid" + +// Access token +type AccessToken struct { + ISS string + Scope string + IAT int64 + EXP int64 + SUB int + Fresh int64 + Roles []string +} + +// Refresh token +type RefreshToken struct { + ISS string + Scope string + IAT int64 + EXP int64 + SUB int + JTI uuid.UUID +}