diff --git a/server/config.go b/config/config.go similarity index 96% rename from server/config.go rename to config/config.go index 8c52c62..1209fb6 100644 --- a/server/config.go +++ b/config/config.go @@ -1,9 +1,10 @@ -package server +package config import ( "errors" "fmt" "os" + "projectreshoot/logging" "github.com/joho/godotenv" @@ -14,6 +15,7 @@ type Config struct { Host string // Host to listen on Port string // Port to listen on TrustedHost string // Domain/Hostname to accept as trusted + SSL bool // Flag for SSL Mode TursoDBName string // DB Name for Turso DB/Branch TursoToken string // Bearer token for Turso DB/Branch SecretKey string // Secret key for signing tokens @@ -78,6 +80,7 @@ func GetConfig(args map[string]string) (*Config, error) { Host: host, Port: port, TrustedHost: os.Getenv("TRUSTED_HOST"), + SSL: GetEnvBool("SSL_MODE", false), TursoDBName: os.Getenv("TURSO_DB_NAME"), TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), SecretKey: os.Getenv("SECRET_KEY"), diff --git a/config/environment.go b/config/environment.go new file mode 100644 index 0000000..a00cdc4 --- /dev/null +++ b/config/environment.go @@ -0,0 +1,61 @@ +package config + +import ( + "os" + "strconv" + "strings" +) + +// Get an environment variable, specifying a default value if its not set +func GetEnvDefault(key string, defaultValue string) string { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + return val +} + +// Get an environment variable as an int64, specifying a default value if its +// not set or can't be parsed properly into an int64 +func GetEnvInt64(key string, defaultValue int64) int64 { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + + intVal, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return defaultValue + } + return intVal + +} + +// Get an environment variable as a boolean, specifying a default value if its +// not set or can't be parsed properly into a bool +func GetEnvBool(key string, defaultValue bool) bool { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + truthy := map[string]bool{ + "true": true, "t": true, "yes": true, "y": true, "on": true, "1": true, + "enable": true, "enabled": true, "active": true, "affirmative": true, + } + + falsy := map[string]bool{ + "false": false, "f": false, "no": false, "n": false, "off": false, "0": false, + "disable": false, "disabled": false, "inactive": false, "negative": false, + } + + normalized := strings.TrimSpace(strings.ToLower(val)) + + if val, ok := truthy[normalized]; ok { + return val + } + if val, ok := falsy[normalized]; ok { + return val + } + + return defaultValue +} diff --git a/cookies/delete.go b/cookies/delete.go new file mode 100644 index 0000000..d847b65 --- /dev/null +++ b/cookies/delete.go @@ -0,0 +1,18 @@ +package cookies + +import ( + "net/http" + "time" +) + +// Tell the browser to delete the cookie matching the name provided +// Path must match the original set cookie for it to delete +func DeleteCookie(w http.ResponseWriter, name string, path string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Path: path, + Expires: time.Unix(0, 0), // Expire in the past + MaxAge: -1, // Immediately expire + }) +} diff --git a/cookies/pagefrom.go b/cookies/pagefrom.go index ace8040..45e4cd4 100644 --- a/cookies/pagefrom.go +++ b/cookies/pagefrom.go @@ -3,7 +3,6 @@ package cookies import ( "net/http" "net/url" - "time" ) // Check the value of "pagefrom" cookie, delete the cookie, and return the value @@ -13,8 +12,7 @@ func CheckPageFrom(w http.ResponseWriter, r *http.Request) string { return "/" } pageFrom := pageFromCookie.Value - deleteCookie := &http.Cookie{Name: "pagefrom", Value: "", Expires: time.Unix(0, 0)} - http.SetCookie(w, deleteCookie) + DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path) return pageFrom } @@ -35,6 +33,6 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) { } else { pageFrom = parsedURL.Path } - pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom} + pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom, Path: "/"} http.SetCookie(w, pageFromCookie) } diff --git a/cookies/tokens.go b/cookies/tokens.go new file mode 100644 index 0000000..e006756 --- /dev/null +++ b/cookies/tokens.go @@ -0,0 +1,48 @@ +package cookies + +import ( + "net/http" + "projectreshoot/config" + "time" +) + +// Get the value of the access and refresh tokens +func GetTokens( + w http.ResponseWriter, + r *http.Request, +) (acc string, ref string) { + accCookie, accErr := r.Cookie("access") + refCookie, refErr := r.Cookie("refresh") + var ( + accStr string = "" + refStr string = "" + ) + if accErr == nil { + accStr = accCookie.Value + } + if refErr == nil { + refStr = refCookie.Value + } + return accStr, refStr +} + +// Set a token with the provided details +func SetToken( + w http.ResponseWriter, + r *http.Request, + config *config.Config, + token string, + scope string, + exp int64, +) { + tokenCookie := &http.Cookie{ + Name: scope, + Value: token, + Path: "/", + Expires: time.Unix(exp, 0), + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: config.SSL, + } + http.SetCookie(w, tokenCookie) +} diff --git a/jwt/create.go b/jwt/create.go index 44aab2c..617abce 100644 --- a/jwt/create.go +++ b/jwt/create.go @@ -3,8 +3,8 @@ package jwt import ( "time" + "projectreshoot/config" "projectreshoot/db" - "projectreshoot/server" "github.com/golang-jwt/jwt" "github.com/google/uuid" @@ -13,11 +13,11 @@ import ( // Generates an access token for the provided user func GenerateAccessToken( - config *server.Config, + config *config.Config, user *db.User, fresh bool, rememberMe bool, -) (string, error) { +) (tokenStr string, exp int64, err error) { issuedAt := time.Now().Unix() expiresAt := issuedAt + (config.AccessTokenExpiry * 60) var freshExpiresAt int64 @@ -37,6 +37,7 @@ func GenerateAccessToken( "iss": config.TrustedHost, "scope": "access", "ttl": ttl, + "jti": uuid.New(), "iat": issuedAt, "exp": expiresAt, "fresh": freshExpiresAt, @@ -45,17 +46,17 @@ func GenerateAccessToken( signedToken, err := token.SignedString([]byte(config.SecretKey)) if err != nil { - return "", errors.Wrap(err, "token.SignedString") + return "", 0, errors.Wrap(err, "token.SignedString") } - return signedToken, nil + return signedToken, expiresAt, nil } // Generates a refresh token for the provided user func GenerateRefreshToken( - config *server.Config, + config *config.Config, user *db.User, rememberMe bool, -) (string, error) { +) (tokenStr string, exp int64, err error) { issuedAt := time.Now().Unix() expiresAt := issuedAt + (config.RefreshTokenExpiry * 60) var ttl string @@ -77,7 +78,7 @@ func GenerateRefreshToken( signedToken, err := token.SignedString([]byte(config.SecretKey)) if err != nil { - return "", errors.Wrap(err, "token.SignedString") + return "", 0, errors.Wrap(err, "token.SignedString") } - return signedToken, nil + return signedToken, expiresAt, nil } diff --git a/jwt/parse.go b/jwt/parse.go index 680d0e1..40b07d8 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -2,9 +2,10 @@ package jwt import ( "fmt" - "projectreshoot/server" "time" + "projectreshoot/config" + "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/pkg/errors" @@ -13,7 +14,7 @@ import ( // Parse an access token and return a struct with all the claims. Does validation on // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. -func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, error) { +func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) { claims, err := parseToken(config.SecretKey, tokenString) if err != nil { return AccessToken{}, errors.Wrap(err, "parseToken") @@ -49,6 +50,10 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e if err != nil { return AccessToken{}, errors.Wrap(err, "getFreshTime") } + jti, err := getTokenJTI(claims["jti"]) + if err != nil { + return AccessToken{}, errors.Wrap(err, "getTokenJTI") + } token := AccessToken{ ISS: issuer, @@ -57,6 +62,8 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e IAT: issuedAt, SUB: subject, Fresh: fresh, + JTI: jti, + Scope: scope, } return token, nil @@ -65,7 +72,7 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e // Parse a refresh token and return a struct with all the claims. Does validation on // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. -func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, error) { +func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) { claims, err := parseToken(config.SecretKey, tokenString) if err != nil { return RefreshToken{}, errors.Wrap(err, "parseToken") @@ -103,12 +110,13 @@ func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, } token := RefreshToken{ - ISS: issuer, - TTL: ttl, - EXP: expiry, - IAT: issuedAt, - SUB: subject, - JTI: jti, + ISS: issuer, + TTL: ttl, + EXP: expiry, + IAT: issuedAt, + SUB: subject, + JTI: jti, + Scope: scope, } return token, nil diff --git a/jwt/revoke.go b/jwt/revoke.go new file mode 100644 index 0000000..9b22e08 --- /dev/null +++ b/jwt/revoke.go @@ -0,0 +1,31 @@ +package jwt + +import ( + "database/sql" + + "github.com/pkg/errors" +) + +// Revoke a token by adding it to the database +func RevokeToken(conn *sql.DB, t Token) error { + jti := t.GetJTI() + exp := t.GetEXP() + query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` + _, err := conn.Exec(query, jti, exp) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Check if a token has been revoked +func CheckRevoked(conn *sql.DB, t Token) (bool, error) { + jti := t.GetJTI() + query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` + rows, err := conn.Query(query, jti) + if err != nil { + return false, errors.Wrap(err, "conn.Exec") + } + revoked := rows.Next() + return revoked, nil +} diff --git a/jwt/tokens.go b/jwt/tokens.go index 707be85..b9fa0d3 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -2,22 +2,50 @@ package jwt import "github.com/google/uuid" +type Token interface { + GetJTI() uuid.UUID + GetEXP() int64 + GetScope() string +} + // Access token type AccessToken struct { - ISS string // Issuer, generally TrustedHost - IAT int64 // Time issued at - EXP int64 // Time expiring at - TTL string // Time-to-live: "session" or "exp". Used with 'remember me' - SUB int // Subject (user) ID - Fresh int64 // Time freshness expiring at + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens + Fresh int64 // Time freshness expiring at + Scope string // Should be "access" } // Refresh token type RefreshToken struct { - ISS string // Issuer, generally TrustedHost - IAT int64 // Time issued at - EXP int64 // Time expiring at - TTL string // Time-to-live: "session" or "exp". Used with 'remember me' - SUB int // Subject (user) ID - JTI uuid.UUID // UUID-4 used for identifying blacklisted refresh tokens + ISS string // Issuer, generally TrustedHost + IAT int64 // Time issued at + EXP int64 // Time expiring at + TTL string // Time-to-live: "session" or "exp". Used with 'remember me' + SUB int // Subject (user) ID + JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens + Scope string // Should be "refresh" +} + +func (a AccessToken) GetJTI() uuid.UUID { + return a.JTI +} +func (r RefreshToken) GetJTI() uuid.UUID { + return r.JTI +} +func (a AccessToken) GetEXP() int64 { + return a.EXP +} +func (r RefreshToken) GetEXP() int64 { + return r.EXP +} +func (a AccessToken) GetScope() string { + return a.Scope +} +func (r RefreshToken) GetScope() string { + return r.Scope } diff --git a/main.go b/main.go index 8639cba..a174c4a 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "projectreshoot/config" "projectreshoot/db" "projectreshoot/logging" "projectreshoot/server" @@ -26,7 +27,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() - config, err := server.GetConfig(args) + config, err := config.GetConfig(args) if err != nil { return errors.Wrap(err, "server.GetConfig") } diff --git a/server/environment.go b/server/environment.go deleted file mode 100644 index f932b53..0000000 --- a/server/environment.go +++ /dev/null @@ -1,31 +0,0 @@ -package server - -import ( - "os" - "strconv" -) - -// Get an environment variable, specifying a default value if its not set -func GetEnvDefault(key string, defaultValue string) string { - val, exists := os.LookupEnv(key) - if !exists { - return defaultValue - } - return val -} - -// Get an environment variable as an int64, specifying a default value if its -// not set or can't be parsed properly into an int64 -func GetEnvInt64(key string, defaultValue int64) int64 { - val, exists := os.LookupEnv(key) - if !exists { - return defaultValue - } - - intVal, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return defaultValue - } - return intVal - -} diff --git a/server/routes.go b/server/routes.go index 446c06f..702bb9e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -4,6 +4,7 @@ import ( "database/sql" "net/http" + "projectreshoot/config" "projectreshoot/handlers" "projectreshoot/view/page" @@ -14,7 +15,7 @@ import ( func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, - config *Config, + config *config.Config, conn *sql.DB, ) { // Static files diff --git a/server/server.go b/server/server.go index c38cd53..d0a1fe2 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "database/sql" "net/http" + "projectreshoot/config" "projectreshoot/middleware" "github.com/rs/zerolog" @@ -11,7 +12,7 @@ import ( // Returns a new http.Handler with all the routes and middleware added func NewServer( - config *Config, + config *config.Config, logger *zerolog.Logger, conn *sql.DB, ) http.Handler {