Moved config and finished JWT module

This commit is contained in:
2025-02-10 22:10:03 +11:00
parent 04049bb73a
commit e73805a02d
13 changed files with 237 additions and 69 deletions

View File

@@ -1,9 +1,10 @@
package server package config
import ( import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"projectreshoot/logging" "projectreshoot/logging"
"github.com/joho/godotenv" "github.com/joho/godotenv"
@@ -14,6 +15,7 @@ type Config struct {
Host string // Host to listen on Host string // Host to listen on
Port string // Port to listen on Port string // Port to listen on
TrustedHost string // Domain/Hostname to accept as trusted TrustedHost string // Domain/Hostname to accept as trusted
SSL bool // Flag for SSL Mode
TursoDBName string // DB Name for Turso DB/Branch TursoDBName string // DB Name for Turso DB/Branch
TursoToken string // Bearer token for Turso DB/Branch TursoToken string // Bearer token for Turso DB/Branch
SecretKey string // Secret key for signing tokens SecretKey string // Secret key for signing tokens
@@ -78,6 +80,7 @@ func GetConfig(args map[string]string) (*Config, error) {
Host: host, Host: host,
Port: port, Port: port,
TrustedHost: os.Getenv("TRUSTED_HOST"), TrustedHost: os.Getenv("TRUSTED_HOST"),
SSL: GetEnvBool("SSL_MODE", false),
TursoDBName: os.Getenv("TURSO_DB_NAME"), TursoDBName: os.Getenv("TURSO_DB_NAME"),
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
SecretKey: os.Getenv("SECRET_KEY"), SecretKey: os.Getenv("SECRET_KEY"),

61
config/environment.go Normal file
View File

@@ -0,0 +1,61 @@
package config
import (
"os"
"strconv"
"strings"
)
// Get an environment variable, specifying a default value if its not set
func GetEnvDefault(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func GetEnvInt64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}
// Get an environment variable as a boolean, specifying a default value if its
// not set or can't be parsed properly into a bool
func GetEnvBool(key string, defaultValue bool) bool {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
truthy := map[string]bool{
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
"enable": true, "enabled": true, "active": true, "affirmative": true,
}
falsy := map[string]bool{
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
"disable": false, "disabled": false, "inactive": false, "negative": false,
}
normalized := strings.TrimSpace(strings.ToLower(val))
if val, ok := truthy[normalized]; ok {
return val
}
if val, ok := falsy[normalized]; ok {
return val
}
return defaultValue
}

18
cookies/delete.go Normal file
View File

@@ -0,0 +1,18 @@
package cookies
import (
"net/http"
"time"
)
// Tell the browser to delete the cookie matching the name provided
// Path must match the original set cookie for it to delete
func DeleteCookie(w http.ResponseWriter, name string, path string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: path,
Expires: time.Unix(0, 0), // Expire in the past
MaxAge: -1, // Immediately expire
})
}

View File

@@ -3,7 +3,6 @@ package cookies
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"time"
) )
// Check the value of "pagefrom" cookie, delete the cookie, and return the value // Check the value of "pagefrom" cookie, delete the cookie, and return the value
@@ -13,8 +12,7 @@ func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
return "/" return "/"
} }
pageFrom := pageFromCookie.Value pageFrom := pageFromCookie.Value
deleteCookie := &http.Cookie{Name: "pagefrom", Value: "", Expires: time.Unix(0, 0)} DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
http.SetCookie(w, deleteCookie)
return pageFrom return pageFrom
} }
@@ -35,6 +33,6 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
} else { } else {
pageFrom = parsedURL.Path pageFrom = parsedURL.Path
} }
pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom} pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom, Path: "/"}
http.SetCookie(w, pageFromCookie) http.SetCookie(w, pageFromCookie)
} }

48
cookies/tokens.go Normal file
View File

@@ -0,0 +1,48 @@
package cookies
import (
"net/http"
"projectreshoot/config"
"time"
)
// Get the value of the access and refresh tokens
func GetTokens(
w http.ResponseWriter,
r *http.Request,
) (acc string, ref string) {
accCookie, accErr := r.Cookie("access")
refCookie, refErr := r.Cookie("refresh")
var (
accStr string = ""
refStr string = ""
)
if accErr == nil {
accStr = accCookie.Value
}
if refErr == nil {
refStr = refCookie.Value
}
return accStr, refStr
}
// Set a token with the provided details
func SetToken(
w http.ResponseWriter,
r *http.Request,
config *config.Config,
token string,
scope string,
exp int64,
) {
tokenCookie := &http.Cookie{
Name: scope,
Value: token,
Path: "/",
Expires: time.Unix(exp, 0),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: config.SSL,
}
http.SetCookie(w, tokenCookie)
}

View File

@@ -3,8 +3,8 @@ package jwt
import ( import (
"time" "time"
"projectreshoot/config"
"projectreshoot/db" "projectreshoot/db"
"projectreshoot/server"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/google/uuid" "github.com/google/uuid"
@@ -13,11 +13,11 @@ import (
// Generates an access token for the provided user // Generates an access token for the provided user
func GenerateAccessToken( func GenerateAccessToken(
config *server.Config, config *config.Config,
user *db.User, user *db.User,
fresh bool, fresh bool,
rememberMe bool, rememberMe bool,
) (string, error) { ) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix() issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.AccessTokenExpiry * 60) expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
var freshExpiresAt int64 var freshExpiresAt int64
@@ -37,6 +37,7 @@ func GenerateAccessToken(
"iss": config.TrustedHost, "iss": config.TrustedHost,
"scope": "access", "scope": "access",
"ttl": ttl, "ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt, "iat": issuedAt,
"exp": expiresAt, "exp": expiresAt,
"fresh": freshExpiresAt, "fresh": freshExpiresAt,
@@ -45,17 +46,17 @@ func GenerateAccessToken(
signedToken, err := token.SignedString([]byte(config.SecretKey)) signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil { if err != nil {
return "", errors.Wrap(err, "token.SignedString") return "", 0, errors.Wrap(err, "token.SignedString")
} }
return signedToken, nil return signedToken, expiresAt, nil
} }
// Generates a refresh token for the provided user // Generates a refresh token for the provided user
func GenerateRefreshToken( func GenerateRefreshToken(
config *server.Config, config *config.Config,
user *db.User, user *db.User,
rememberMe bool, rememberMe bool,
) (string, error) { ) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix() issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60) expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
var ttl string var ttl string
@@ -77,7 +78,7 @@ func GenerateRefreshToken(
signedToken, err := token.SignedString([]byte(config.SecretKey)) signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil { if err != nil {
return "", errors.Wrap(err, "token.SignedString") return "", 0, errors.Wrap(err, "token.SignedString")
} }
return signedToken, nil return signedToken, expiresAt, nil
} }

View File

@@ -2,9 +2,10 @@ package jwt
import ( import (
"fmt" "fmt"
"projectreshoot/server"
"time" "time"
"projectreshoot/config"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -13,7 +14,7 @@ import (
// Parse an access token and return a struct with all the claims. Does validation on // Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, error) { func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) {
claims, err := parseToken(config.SecretKey, tokenString) claims, err := parseToken(config.SecretKey, tokenString)
if err != nil { if err != nil {
return AccessToken{}, errors.Wrap(err, "parseToken") return AccessToken{}, errors.Wrap(err, "parseToken")
@@ -49,6 +50,10 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
if err != nil { if err != nil {
return AccessToken{}, errors.Wrap(err, "getFreshTime") return AccessToken{}, errors.Wrap(err, "getFreshTime")
} }
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return AccessToken{}, errors.Wrap(err, "getTokenJTI")
}
token := AccessToken{ token := AccessToken{
ISS: issuer, ISS: issuer,
@@ -57,6 +62,8 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
IAT: issuedAt, IAT: issuedAt,
SUB: subject, SUB: subject,
Fresh: fresh, Fresh: fresh,
JTI: jti,
Scope: scope,
} }
return token, nil return token, nil
@@ -65,7 +72,7 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
// Parse a refresh token and return a struct with all the claims. Does validation on // Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, error) { func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) {
claims, err := parseToken(config.SecretKey, tokenString) claims, err := parseToken(config.SecretKey, tokenString)
if err != nil { if err != nil {
return RefreshToken{}, errors.Wrap(err, "parseToken") return RefreshToken{}, errors.Wrap(err, "parseToken")
@@ -103,12 +110,13 @@ func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken,
} }
token := RefreshToken{ token := RefreshToken{
ISS: issuer, ISS: issuer,
TTL: ttl, TTL: ttl,
EXP: expiry, EXP: expiry,
IAT: issuedAt, IAT: issuedAt,
SUB: subject, SUB: subject,
JTI: jti, JTI: jti,
Scope: scope,
} }
return token, nil return token, nil

31
jwt/revoke.go Normal file
View File

@@ -0,0 +1,31 @@
package jwt
import (
"database/sql"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func RevokeToken(conn *sql.DB, t Token) error {
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := conn.Exec(query, jti, exp)
if err != nil {
return errors.Wrap(err, "conn.Exec")
}
return nil
}
// Check if a token has been revoked
func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti)
if err != nil {
return false, errors.Wrap(err, "conn.Exec")
}
revoked := rows.Next()
return revoked, nil
}

View File

@@ -2,22 +2,50 @@ package jwt
import "github.com/google/uuid" import "github.com/google/uuid"
type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
}
// Access token // Access token
type AccessToken struct { type AccessToken struct {
ISS string // Issuer, generally TrustedHost ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at IAT int64 // Time issued at
EXP int64 // Time expiring at EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me' TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID SUB int // Subject (user) ID
Fresh int64 // Time freshness expiring at JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at
Scope string // Should be "access"
} }
// Refresh token // Refresh token
type RefreshToken struct { type RefreshToken struct {
ISS string // Issuer, generally TrustedHost ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at IAT int64 // Time issued at
EXP int64 // Time expiring at EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me' TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted refresh tokens JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh"
}
func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI
}
func (r RefreshToken) GetJTI() uuid.UUID {
return r.JTI
}
func (a AccessToken) GetEXP() int64 {
return a.EXP
}
func (r RefreshToken) GetEXP() int64 {
return r.EXP
}
func (a AccessToken) GetScope() string {
return a.Scope
}
func (r RefreshToken) GetScope() string {
return r.Scope
} }

View File

@@ -14,6 +14,7 @@ import (
"sync" "sync"
"time" "time"
"projectreshoot/config"
"projectreshoot/db" "projectreshoot/db"
"projectreshoot/logging" "projectreshoot/logging"
"projectreshoot/server" "projectreshoot/server"
@@ -26,7 +27,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel() defer cancel()
config, err := server.GetConfig(args) config, err := config.GetConfig(args)
if err != nil { if err != nil {
return errors.Wrap(err, "server.GetConfig") return errors.Wrap(err, "server.GetConfig")
} }

View File

@@ -1,31 +0,0 @@
package server
import (
"os"
"strconv"
)
// Get an environment variable, specifying a default value if its not set
func GetEnvDefault(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func GetEnvInt64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"net/http" "net/http"
"projectreshoot/config"
"projectreshoot/handlers" "projectreshoot/handlers"
"projectreshoot/view/page" "projectreshoot/view/page"
@@ -14,7 +15,7 @@ import (
func addRoutes( func addRoutes(
mux *http.ServeMux, mux *http.ServeMux,
logger *zerolog.Logger, logger *zerolog.Logger,
config *Config, config *config.Config,
conn *sql.DB, conn *sql.DB,
) { ) {
// Static files // Static files

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"net/http" "net/http"
"projectreshoot/config"
"projectreshoot/middleware" "projectreshoot/middleware"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -11,7 +12,7 @@ import (
// Returns a new http.Handler with all the routes and middleware added // Returns a new http.Handler with all the routes and middleware added
func NewServer( func NewServer(
config *Config, config *config.Config,
logger *zerolog.Logger, logger *zerolog.Logger,
conn *sql.DB, conn *sql.DB,
) http.Handler { ) http.Handler {