diff --git a/jwt/create.go b/jwt/create.go index 7ba7688..44aab2c 100644 --- a/jwt/create.go +++ b/jwt/create.go @@ -16,6 +16,7 @@ func GenerateAccessToken( config *server.Config, user *db.User, fresh bool, + rememberMe bool, ) (string, error) { issuedAt := time.Now().Unix() expiresAt := issuedAt + (config.AccessTokenExpiry * 60) @@ -25,15 +26,21 @@ func GenerateAccessToken( } else { freshExpiresAt = issuedAt } + var ttl string + if rememberMe { + ttl = "exp" + } else { + ttl = "session" + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "iss": config.TrustedHost, "scope": "access", + "ttl": ttl, "iat": issuedAt, "exp": expiresAt, "fresh": freshExpiresAt, "sub": user.ID, - "roles": []string{"user", "admin"}, // TODO: add user roles }) signedToken, err := token.SignedString([]byte(config.SecretKey)) @@ -47,13 +54,21 @@ func GenerateAccessToken( func GenerateRefreshToken( config *server.Config, user *db.User, + rememberMe bool, ) (string, error) { issuedAt := time.Now().Unix() expiresAt := issuedAt + (config.RefreshTokenExpiry * 60) + var ttl string + if rememberMe { + ttl = "exp" + } else { + ttl = "session" + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "iss": config.TrustedHost, "scope": "refresh", + "ttl": ttl, "jti": uuid.New(), "iat": issuedAt, "exp": expiresAt, diff --git a/jwt/parse.go b/jwt/parse.go index 7ee0ea5..680d0e1 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -26,6 +26,10 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e if err != nil { return AccessToken{}, errors.Wrap(err, "checkTokenIssuer") } + ttl, err := getTokenTTL(claims["ttl"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getTokenTTL") + } scope, err := getTokenScope(claims["scope"]) if err != nil { return AccessToken{}, errors.Wrap(err, "getTokenScope") @@ -45,19 +49,14 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e if err != nil { return AccessToken{}, errors.Wrap(err, "getFreshTime") } - roles, err := getTokenRoles(claims["roles"]) - if err != nil { - return AccessToken{}, errors.Wrap(err, "getTokenRoles") - } token := AccessToken{ ISS: issuer, - Scope: scope, + TTL: ttl, EXP: expiry, IAT: issuedAt, SUB: subject, Fresh: fresh, - Roles: roles, } return token, nil @@ -79,6 +78,10 @@ func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, if err != nil { return RefreshToken{}, errors.Wrap(err, "checkTokenIssuer") } + ttl, err := getTokenTTL(claims["ttl"]) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "getTokenTTL") + } scope, err := getTokenScope(claims["scope"]) if err != nil { return RefreshToken{}, errors.Wrap(err, "getTokenScope") @@ -100,12 +103,12 @@ func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, } token := RefreshToken{ - ISS: issuer, - Scope: scope, - EXP: expiry, - IAT: issuedAt, - SUB: subject, - JTI: jti, + ISS: issuer, + TTL: ttl, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + JTI: jti, } return token, nil @@ -151,7 +154,6 @@ func checkTokenExpired(expiry interface{}) (int64, error) { // Check if a token has a valid issuer. Returns the issuer if valid func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) { - // Check issuer issuerVal, ok := issuer.(string) if !ok { return "", errors.New("Missing or invalid 'iss' claim") @@ -171,6 +173,18 @@ func getTokenScope(scope interface{}) (string, error) { return scopeStr, nil } +// Get the TTL of the token, either "session" or "exp" +func getTokenTTL(ttl interface{}) (string, error) { + ttlStr, ok := ttl.(string) + if !ok { + return "", errors.New("Missing or invalid 'ttl' claim") + } + if ttlStr != "exp" && ttlStr != "session" { + return "", errors.New("TTL value is not recognised") + } + return ttlStr, nil +} + // Get the time the token was issued at func getIssuedTime(issued interface{}) (int64, error) { // Same float64 -> int64 trick as expiry @@ -200,23 +214,6 @@ func getTokenSubject(sub interface{}) (int, error) { return int(subject), nil } -// Get the roles of the token subject -func getTokenRoles(roles interface{}) ([]string, error) { - rolesIfSlice, ok := roles.([]interface{}) - if !ok { - return nil, errors.New("Missing or invalid 'roles' claim") - } - rolesSlice := []string{} - for _, role := range rolesIfSlice { - if str, ok := role.(string); ok { - rolesSlice = append(rolesSlice, str) - } else { - return nil, errors.New("Malformed 'roles' claim") - } - } - return rolesSlice, nil -} - // Get the JTI of the token func getTokenJTI(jti interface{}) (uuid.UUID, error) { jtiStr, ok := jti.(string) diff --git a/jwt/tokens.go b/jwt/tokens.go index 9674cad..707be85 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -4,21 +4,20 @@ import "github.com/google/uuid" // Access token type AccessToken struct { - ISS string - Scope string - IAT int64 - EXP int64 - SUB int - Fresh int64 - Roles []string + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + Fresh int64 // Time freshness expiring at } // Refresh token type RefreshToken struct { - ISS string - Scope string - IAT int64 - EXP int64 - SUB int - JTI uuid.UUID + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + JTI uuid.UUID // UUID-4 used for identifying blacklisted refresh tokens }