Moved config and finished JWT module

This commit is contained in:
2025-02-10 22:10:03 +11:00
parent 04049bb73a
commit e73805a02d
13 changed files with 237 additions and 69 deletions

View File

@@ -1,9 +1,10 @@
package server
package config
import (
"errors"
"fmt"
"os"
"projectreshoot/logging"
"github.com/joho/godotenv"
@@ -14,6 +15,7 @@ type Config struct {
Host string // Host to listen on
Port string // Port to listen on
TrustedHost string // Domain/Hostname to accept as trusted
SSL bool // Flag for SSL Mode
TursoDBName string // DB Name for Turso DB/Branch
TursoToken string // Bearer token for Turso DB/Branch
SecretKey string // Secret key for signing tokens
@@ -78,6 +80,7 @@ func GetConfig(args map[string]string) (*Config, error) {
Host: host,
Port: port,
TrustedHost: os.Getenv("TRUSTED_HOST"),
SSL: GetEnvBool("SSL_MODE", false),
TursoDBName: os.Getenv("TURSO_DB_NAME"),
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
SecretKey: os.Getenv("SECRET_KEY"),

61
config/environment.go Normal file
View File

@@ -0,0 +1,61 @@
package config
import (
"os"
"strconv"
"strings"
)
// Get an environment variable, specifying a default value if its not set
func GetEnvDefault(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func GetEnvInt64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}
// Get an environment variable as a boolean, specifying a default value if its
// not set or can't be parsed properly into a bool
func GetEnvBool(key string, defaultValue bool) bool {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
truthy := map[string]bool{
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
"enable": true, "enabled": true, "active": true, "affirmative": true,
}
falsy := map[string]bool{
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
"disable": false, "disabled": false, "inactive": false, "negative": false,
}
normalized := strings.TrimSpace(strings.ToLower(val))
if val, ok := truthy[normalized]; ok {
return val
}
if val, ok := falsy[normalized]; ok {
return val
}
return defaultValue
}

18
cookies/delete.go Normal file
View File

@@ -0,0 +1,18 @@
package cookies
import (
"net/http"
"time"
)
// Tell the browser to delete the cookie matching the name provided
// Path must match the original set cookie for it to delete
func DeleteCookie(w http.ResponseWriter, name string, path string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: path,
Expires: time.Unix(0, 0), // Expire in the past
MaxAge: -1, // Immediately expire
})
}

View File

@@ -3,7 +3,6 @@ package cookies
import (
"net/http"
"net/url"
"time"
)
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
@@ -13,8 +12,7 @@ func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
return "/"
}
pageFrom := pageFromCookie.Value
deleteCookie := &http.Cookie{Name: "pagefrom", Value: "", Expires: time.Unix(0, 0)}
http.SetCookie(w, deleteCookie)
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
return pageFrom
}
@@ -35,6 +33,6 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
} else {
pageFrom = parsedURL.Path
}
pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom}
pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom, Path: "/"}
http.SetCookie(w, pageFromCookie)
}

48
cookies/tokens.go Normal file
View File

@@ -0,0 +1,48 @@
package cookies
import (
"net/http"
"projectreshoot/config"
"time"
)
// Get the value of the access and refresh tokens
func GetTokens(
w http.ResponseWriter,
r *http.Request,
) (acc string, ref string) {
accCookie, accErr := r.Cookie("access")
refCookie, refErr := r.Cookie("refresh")
var (
accStr string = ""
refStr string = ""
)
if accErr == nil {
accStr = accCookie.Value
}
if refErr == nil {
refStr = refCookie.Value
}
return accStr, refStr
}
// Set a token with the provided details
func SetToken(
w http.ResponseWriter,
r *http.Request,
config *config.Config,
token string,
scope string,
exp int64,
) {
tokenCookie := &http.Cookie{
Name: scope,
Value: token,
Path: "/",
Expires: time.Unix(exp, 0),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: config.SSL,
}
http.SetCookie(w, tokenCookie)
}

View File

@@ -3,8 +3,8 @@ package jwt
import (
"time"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/server"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
@@ -13,11 +13,11 @@ import (
// Generates an access token for the provided user
func GenerateAccessToken(
config *server.Config,
config *config.Config,
user *db.User,
fresh bool,
rememberMe bool,
) (string, error) {
) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
var freshExpiresAt int64
@@ -37,6 +37,7 @@ func GenerateAccessToken(
"iss": config.TrustedHost,
"scope": "access",
"ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt,
"exp": expiresAt,
"fresh": freshExpiresAt,
@@ -45,17 +46,17 @@ func GenerateAccessToken(
signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil {
return "", errors.Wrap(err, "token.SignedString")
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, nil
return signedToken, expiresAt, nil
}
// Generates a refresh token for the provided user
func GenerateRefreshToken(
config *server.Config,
config *config.Config,
user *db.User,
rememberMe bool,
) (string, error) {
) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
var ttl string
@@ -77,7 +78,7 @@ func GenerateRefreshToken(
signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil {
return "", errors.Wrap(err, "token.SignedString")
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, nil
return signedToken, expiresAt, nil
}

View File

@@ -2,9 +2,10 @@ package jwt
import (
"fmt"
"projectreshoot/server"
"time"
"projectreshoot/config"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pkg/errors"
@@ -13,7 +14,7 @@ import (
// Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, error) {
func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) {
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return AccessToken{}, errors.Wrap(err, "parseToken")
@@ -49,6 +50,10 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
if err != nil {
return AccessToken{}, errors.Wrap(err, "getFreshTime")
}
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return AccessToken{}, errors.Wrap(err, "getTokenJTI")
}
token := AccessToken{
ISS: issuer,
@@ -57,6 +62,8 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
IAT: issuedAt,
SUB: subject,
Fresh: fresh,
JTI: jti,
Scope: scope,
}
return token, nil
@@ -65,7 +72,7 @@ func ParseAccessToken(config *server.Config, tokenString string) (AccessToken, e
// Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken, error) {
func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) {
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return RefreshToken{}, errors.Wrap(err, "parseToken")
@@ -109,6 +116,7 @@ func ParseRefreshToken(config *server.Config, tokenString string) (RefreshToken,
IAT: issuedAt,
SUB: subject,
JTI: jti,
Scope: scope,
}
return token, nil

31
jwt/revoke.go Normal file
View File

@@ -0,0 +1,31 @@
package jwt
import (
"database/sql"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func RevokeToken(conn *sql.DB, t Token) error {
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := conn.Exec(query, jti, exp)
if err != nil {
return errors.Wrap(err, "conn.Exec")
}
return nil
}
// Check if a token has been revoked
func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti)
if err != nil {
return false, errors.Wrap(err, "conn.Exec")
}
revoked := rows.Next()
return revoked, nil
}

View File

@@ -2,6 +2,12 @@ package jwt
import "github.com/google/uuid"
type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
}
// Access token
type AccessToken struct {
ISS string // Issuer, generally TrustedHost
@@ -9,7 +15,9 @@ type AccessToken struct {
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at
Scope string // Should be "access"
}
// Refresh token
@@ -19,5 +27,25 @@ type RefreshToken struct {
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted refresh tokens
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh"
}
func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI
}
func (r RefreshToken) GetJTI() uuid.UUID {
return r.JTI
}
func (a AccessToken) GetEXP() int64 {
return a.EXP
}
func (r RefreshToken) GetEXP() int64 {
return r.EXP
}
func (a AccessToken) GetScope() string {
return a.Scope
}
func (r RefreshToken) GetScope() string {
return r.Scope
}

View File

@@ -14,6 +14,7 @@ import (
"sync"
"time"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/logging"
"projectreshoot/server"
@@ -26,7 +27,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := server.GetConfig(args)
config, err := config.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}

View File

@@ -1,31 +0,0 @@
package server
import (
"os"
"strconv"
)
// Get an environment variable, specifying a default value if its not set
func GetEnvDefault(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func GetEnvInt64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/handlers"
"projectreshoot/view/page"
@@ -14,7 +15,7 @@ import (
func addRoutes(
mux *http.ServeMux,
logger *zerolog.Logger,
config *Config,
config *config.Config,
conn *sql.DB,
) {
// Static files

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/middleware"
"github.com/rs/zerolog"
@@ -11,7 +12,7 @@ import (
// Returns a new http.Handler with all the routes and middleware added
func NewServer(
config *Config,
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
) http.Handler {