refactor: changed file structure
This commit is contained in:
84
pkg/jwt/create.go
Normal file
84
pkg/jwt/create.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/pkg/config"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Generates an access token for the provided user
|
||||
func GenerateAccessToken(
|
||||
config *config.Config,
|
||||
user *models.User,
|
||||
fresh bool,
|
||||
rememberMe bool,
|
||||
) (tokenStr string, exp int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
|
||||
var freshExpiresAt int64
|
||||
if fresh {
|
||||
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
|
||||
} else {
|
||||
freshExpiresAt = issuedAt
|
||||
}
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": config.TrustedHost,
|
||||
"scope": "access",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"fresh": freshExpiresAt,
|
||||
"sub": user.ID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// Generates a refresh token for the provided user
|
||||
func GenerateRefreshToken(
|
||||
config *config.Config,
|
||||
user *models.User,
|
||||
rememberMe bool,
|
||||
) (tokenStr string, exp int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": config.TrustedHost,
|
||||
"scope": "refresh",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"sub": user.ID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
268
pkg/jwt/parse.go
Normal file
268
pkg/jwt/parse.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Parse an access token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func ParseAccessToken(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx db.SafeTX,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Access token string not provided")
|
||||
}
|
||||
claims, err := parseToken(config.SecretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "access" {
|
||||
return nil, errors.New("Token is not an Access token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
fresh, err := getFreshTime(claims["fresh"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getFreshTime")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &AccessToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
Fresh: fresh,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func ParseRefreshToken(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx db.SafeTX,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Refresh token string not provided")
|
||||
}
|
||||
claims, err := parseToken(config.SecretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "refresh" {
|
||||
return nil, errors.New("Token is not an Refresh token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &RefreshToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a token, validating its signing sigature and returning the claims
|
||||
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.Parse")
|
||||
}
|
||||
// Token decoded, parse the claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("Failed to parse claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Check if a token is expired. Returns the expiry if not expired
|
||||
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||
// Coerce the expiry to a float64 to avoid scientific notation
|
||||
expFloat, ok := expiry.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||
}
|
||||
// Convert to the int64 time we expect :)
|
||||
expiryTime := int64(expFloat)
|
||||
|
||||
// Check if its expired
|
||||
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||
if isExpired {
|
||||
return 0, errors.New("Token has expired")
|
||||
}
|
||||
return expiryTime, nil
|
||||
}
|
||||
|
||||
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||
issuerVal, ok := issuer.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'iss' claim")
|
||||
}
|
||||
if issuer != trustedHost {
|
||||
return "", errors.New("Issuer does not matched trusted host")
|
||||
}
|
||||
return issuerVal, nil
|
||||
}
|
||||
|
||||
// Check the scope matches the expected scope. Returns scope if true
|
||||
func getTokenScope(scope interface{}) (string, error) {
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'scope' claim")
|
||||
}
|
||||
return scopeStr, nil
|
||||
}
|
||||
|
||||
// Get the TTL of the token, either "session" or "exp"
|
||||
func getTokenTTL(ttl interface{}) (string, error) {
|
||||
ttlStr, ok := ttl.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'ttl' claim")
|
||||
}
|
||||
if ttlStr != "exp" && ttlStr != "session" {
|
||||
return "", errors.New("TTL value is not recognised")
|
||||
}
|
||||
return ttlStr, nil
|
||||
}
|
||||
|
||||
// Get the time the token was issued at
|
||||
func getIssuedTime(issued interface{}) (int64, error) {
|
||||
// Same float64 -> int64 trick as expiry
|
||||
issuedFloat, ok := issued.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||
}
|
||||
issuedAt := int64(issuedFloat)
|
||||
return issuedAt, nil
|
||||
}
|
||||
|
||||
// Get the freshness expiry timestamp
|
||||
func getFreshTime(fresh interface{}) (int64, error) {
|
||||
freshUntil, ok := fresh.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||
}
|
||||
return int64(freshUntil), nil
|
||||
}
|
||||
|
||||
// Get the subject of the token
|
||||
func getTokenSubject(sub interface{}) (int, error) {
|
||||
subject, ok := sub.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||
}
|
||||
return int(subject), nil
|
||||
}
|
||||
|
||||
// Get the JTI of the token
|
||||
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||
jtiStr, ok := jti.(string)
|
||||
if !ok {
|
||||
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||
}
|
||||
jtiUUID, err := uuid.Parse(jtiStr)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||
}
|
||||
return jtiUUID, nil
|
||||
}
|
||||
33
pkg/jwt/revoke.go
Normal file
33
pkg/jwt/revoke.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func RevokeToken(ctx context.Context, tx *db.SafeWTX, t Token) error {
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := tx.Exec(ctx, query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func CheckTokenNotRevoked(ctx context.Context, tx db.SafeTX, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := tx.Query(ctx, query, jti)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
}
|
||||
74
pkg/jwt/tokens.go
Normal file
74
pkg/jwt/tokens.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
type AccessToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Fresh int64 // Time freshness expiring at
|
||||
Scope string // Should be "access"
|
||||
}
|
||||
|
||||
// Refresh token
|
||||
type RefreshToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Scope string // Should be "refresh"
|
||||
}
|
||||
|
||||
func (a AccessToken) GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error) {
|
||||
user, err := models.GetUserFromID(ctx, tx, a.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
func (r RefreshToken) GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error) {
|
||||
user, err := models.GetUserFromID(ctx, tx, r.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a AccessToken) GetJTI() uuid.UUID {
|
||||
return a.JTI
|
||||
}
|
||||
func (r RefreshToken) GetJTI() uuid.UUID {
|
||||
return r.JTI
|
||||
}
|
||||
func (a AccessToken) GetEXP() int64 {
|
||||
return a.EXP
|
||||
}
|
||||
func (r RefreshToken) GetEXP() int64 {
|
||||
return r.EXP
|
||||
}
|
||||
func (a AccessToken) GetScope() string {
|
||||
return a.Scope
|
||||
}
|
||||
func (r RefreshToken) GetScope() string {
|
||||
return r.Scope
|
||||
}
|
||||
Reference in New Issue
Block a user