Compare commits

..

4 Commits

Author SHA1 Message Date
ade3fa0454 imported env module 2026-01-02 19:03:07 +11:00
516be905a9 imported cookies module 2026-01-02 18:25:38 +11:00
6e632267ea added cookie control to jwt 2026-01-02 18:15:49 +11:00
05aad5f11b fixed transaction issues 2026-01-01 22:44:39 +11:00
15 changed files with 323 additions and 69 deletions

19
cookies/delete.go Normal file
View File

@@ -0,0 +1,19 @@
package cookies
import (
"net/http"
"time"
)
// Tell the browser to delete the cookie matching the name provided
// Path must match the original set cookie for it to delete
func DeleteCookie(w http.ResponseWriter, name string, path string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: path,
Expires: time.Unix(0, 0), // Expire in the past
MaxAge: -1, // Immediately expire
HttpOnly: true,
})
}

3
cookies/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/cookies
go 1.25.5

36
cookies/pagefrom.go Normal file
View File

@@ -0,0 +1,36 @@
package cookies
import (
"net/http"
"net/url"
)
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
pageFromCookie, err := r.Cookie("pagefrom")
if err != nil {
return "/"
}
pageFrom := pageFromCookie.Value
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
return pageFrom
}
// Check the referer of the request, and if it matches the trustedHost, set
// the "pagefrom" cookie as the Path of the referer
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
referer := r.Referer()
parsedURL, err := url.Parse(referer)
if err != nil {
return
}
var pageFrom string
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
pageFrom = "/"
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
return
} else {
pageFrom = parsedURL.Path
}
SetCookie(w, "pagefrom", "/", pageFrom, 0)
}

23
cookies/set.go Normal file
View File

@@ -0,0 +1,23 @@
package cookies
import (
"net/http"
)
// Set a cookie with the given name, path and value. maxAge directly relates
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
func SetCookie(
w http.ResponseWriter,
name string,
path string,
value string,
maxAge int,
) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: path,
HttpOnly: true,
MaxAge: maxAge,
})
}

35
env/boolean.go vendored Normal file
View File

@@ -0,0 +1,35 @@
package env
import (
"os"
"strings"
)
// Get an environment variable as a boolean, specifying a default value if its
// not set or can't be parsed properly into a bool
func Bool(key string, defaultValue bool) bool {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
truthy := map[string]bool{
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
"enable": true, "enabled": true, "active": true, "affirmative": true,
}
falsy := map[string]bool{
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
"disable": false, "disabled": false, "inactive": false, "negative": false,
}
normalized := strings.TrimSpace(strings.ToLower(val))
if val, ok := truthy[normalized]; ok {
return val
}
if val, ok := falsy[normalized]; ok {
return val
}
return defaultValue
}

23
env/duration.go vendored Normal file
View File

@@ -0,0 +1,23 @@
package env
import (
"os"
"strconv"
"time"
)
// Get an environment variable as a time.Duration, specifying a default value if its
// not set or can't be parsed properly
func Duration(key string, defaultValue time.Duration) time.Duration {
val, exists := os.LookupEnv(key)
if !exists {
return time.Duration(defaultValue)
}
intVal, err := strconv.Atoi(val)
if err != nil {
return time.Duration(defaultValue)
}
return time.Duration(intVal)
}

3
env/go.mod vendored Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/env
go 1.25.5

37
env/int.go vendored Normal file
View File

@@ -0,0 +1,37 @@
package env
import (
"os"
"strconv"
)
// Get an environment variable as an int, specifying a default value if its
// not set or can't be parsed properly into an int
func Int(key string, defaultValue int) int {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.Atoi(val)
if err != nil {
return defaultValue
}
return intVal
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func Int64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

14
env/string.go vendored Normal file
View File

@@ -0,0 +1,14 @@
package env
import (
"os"
)
// Get an environment variable, specifying a default value if its not set
func String(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}

73
jwt/cookies.go Normal file
View File

@@ -0,0 +1,73 @@
package jwt
import (
"github.com/pkg/errors"
"net/http"
"time"
)
// Get the value of the access and refresh tokens
func GetTokenCookies(
r *http.Request,
) (acc string, ref string) {
accCookie, accErr := r.Cookie("access")
refCookie, refErr := r.Cookie("refresh")
var (
accStr string = ""
refStr string = ""
)
if accErr == nil {
accStr = accCookie.Value
}
if refErr == nil {
refStr = refCookie.Value
}
return accStr, refStr
}
// Set a token with the provided details
func setToken(
w http.ResponseWriter,
token string,
scope string,
exp int64,
rememberme bool,
useSSL bool,
) {
tokenCookie := &http.Cookie{
Name: scope,
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: useSSL,
}
if rememberme {
tokenCookie.Expires = time.Unix(exp, 0)
}
http.SetCookie(w, tokenCookie)
}
// Generate new tokens for the subject and set them as cookies
func SetTokenCookies(
w http.ResponseWriter,
r *http.Request,
tokenGen *TokenGenerator,
subject int,
fresh bool,
rememberMe bool,
useSSL bool,
) error {
at, atexp, err := tokenGen.NewAccess(subject, fresh, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateAccessToken")
}
rt, rtexp, err := tokenGen.NewRefresh(subject, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateRefreshToken")
}
// Don't set the cookies until we know no errors occured
setToken(w, at, "access", atexp, rememberMe, useSSL)
setToken(w, rt, "refresh", rtexp, rememberMe, useSSL)
return nil
}

View File

@@ -1,47 +1,31 @@
package jwt package jwt
import ( import (
"context" "database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Revoke a token by adding it to the database // Revoke a token by adding it to the database
func revoke(ctx context.Context, t Token) error { func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
db := t.getDB() if gen.dbConn == nil {
if db == nil {
return errors.New("No DB provided, unable to use this function") return errors.New("No DB provided, unable to use this function")
} }
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI() jti := t.GetJTI()
exp := t.GetEXP() exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err = tx.Exec(query, jti, exp) _, err := tx.Exec(query, jti, exp)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.Exec") return errors.Wrap(err, "tx.Exec")
} }
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "tx.Commit")
}
return nil return nil
} }
// Check if a token has been revoked. Returns true if not revoked. // Check if a token has been revoked. Returns true if not revoked.
func checkNotRevoked(ctx context.Context, t Token) (bool, error) { func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
db := t.getDB() if gen.dbConn == nil {
if db == nil {
return false, errors.New("No DB provided, unable to use this function") return false, errors.New("No DB provided, unable to use this function")
} }
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI() jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(query, jti) rows, err := tx.Query(query, jti)
@@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
} }
defer rows.Close() defer rows.Close()
revoked := rows.Next() revoked := rows.Next()
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
return !revoked, nil return !revoked, nil
} }

View File

@@ -2,6 +2,7 @@ package jwt
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) {
token := AccessToken{ token := AccessToken{
JTI: jti, JTI: jti,
EXP: exp, EXP: exp,
gen: &TokenGenerator{},
} }
// Revoke should fail due to no DB // Revoke should fail due to no DB
err := token.Revoke(context.Background()) err := token.Revoke(&sql.Tx{})
require.Error(t, err) require.Error(t, err)
// CheckNotRevoked should fail // CheckNotRevoked should fail
_, err = token.CheckNotRevoked(context.Background()) _, err = token.CheckNotRevoked(&sql.Tx{})
require.Error(t, err) require.Error(t, err)
} }
@@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
token := AccessToken{ token := AccessToken{
JTI: jti, JTI: jti,
EXP: exp, EXP: exp,
db: gen.dbConn, gen: gen,
} }
// Revoke expectations // Revoke expectations
@@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
mock.ExpectExec(`INSERT INTO jwtblacklist`). mock.ExpectExec(`INSERT INTO jwtblacklist`).
WithArgs(jti, exp). WithArgs(jti, exp).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := token.Revoke(context.Background())
require.NoError(t, err)
// CheckNotRevoked expectations (now revoked)
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`). mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti). WithArgs(jti).
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit() mock.ExpectCommit()
valid, err := token.CheckNotRevoked(context.Background()) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
defer tx.Rollback()
require.NoError(t, err)
err = token.Revoke(tx)
require.NoError(t, err)
valid, err := token.CheckNotRevoked(tx)
require.NoError(t, err) require.NoError(t, err)
require.False(t, valid) require.False(t, valid)
require.NoError(t, tx.Commit())
require.NoError(t, mock.ExpectationsWereMet()) require.NoError(t, mock.ExpectationsWereMet())
} }

View File

@@ -1,7 +1,6 @@
package jwt package jwt
import ( import (
"context"
"database/sql" "database/sql"
"github.com/google/uuid" "github.com/google/uuid"
@@ -11,8 +10,8 @@ type Token interface {
GetJTI() uuid.UUID GetJTI() uuid.UUID
GetEXP() int64 GetEXP() int64
GetScope() string GetScope() string
getDB() *sql.DB Revoke(*sql.Tx) error
Revoke(context.Context) error CheckNotRevoked(*sql.Tx) (bool, error)
} }
// Access token // Access token
@@ -25,7 +24,7 @@ type AccessToken struct {
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at Fresh int64 // Time freshness expiring at
Scope string // Should be "access" Scope string // Should be "access"
db *sql.DB gen *TokenGenerator
} }
// Refresh token // Refresh token
@@ -37,7 +36,7 @@ type RefreshToken struct {
SUB int // Subject (user) ID SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh" Scope string // Should be "refresh"
db *sql.DB gen *TokenGenerator
} }
func (a AccessToken) GetJTI() uuid.UUID { func (a AccessToken) GetJTI() uuid.UUID {
@@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string {
func (r RefreshToken) GetScope() string { func (r RefreshToken) GetScope() string {
return r.Scope return r.Scope
} }
func (a AccessToken) getDB() *sql.DB { func (a AccessToken) Revoke(tx *sql.Tx) error {
return a.db return a.gen.revoke(tx, a)
} }
func (r RefreshToken) getDB() *sql.DB { func (r RefreshToken) Revoke(tx *sql.Tx) error {
return r.db return r.gen.revoke(tx, r)
} }
func (a AccessToken) Revoke(ctx context.Context) error { func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return revoke(ctx, a) return a.gen.checkNotRevoked(tx, a)
} }
func (r RefreshToken) Revoke(ctx context.Context) error { func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return revoke(ctx, r) return r.gen.checkNotRevoked(tx, r)
}
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, a)
}
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, r)
} }

View File

@@ -1,7 +1,8 @@
package jwt package jwt
import ( import (
"context" "database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -9,7 +10,7 @@ import (
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func (gen *TokenGenerator) ValidateAccess( func (gen *TokenGenerator) ValidateAccess(
ctx context.Context, tx *sql.Tx,
tokenString string, tokenString string,
) (*AccessToken, error) { ) (*AccessToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess(
Fresh: fresh, Fresh: fresh,
JTI: jti, JTI: jti,
Scope: scope, Scope: scope,
db: gen.dbConn, gen: gen,
} }
valid, err := token.CheckNotRevoked(ctx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil { if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }
@@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess(
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func (gen *TokenGenerator) ValidateRefresh( func (gen *TokenGenerator) ValidateRefresh(
ctx context.Context, tx *sql.Tx,
tokenString string, tokenString string,
) (*RefreshToken, error) { ) (*RefreshToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh(
SUB: subject, SUB: subject,
JTI: jti, JTI: jti,
Scope: scope, Scope: scope,
db: gen.dbConn, gen: gen,
} }
valid, err := token.CheckNotRevoked(ctx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil { if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }

View File

@@ -2,6 +2,7 @@ package jwt
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
@@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg // We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateAccess(context.Background(), tokenStr) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateAccess(tx, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope) require.Equal(t, "access", token.Scope)
tx.Commit()
} }
func TestValidateAccess_NoDB(t *testing.T) { func TestValidateAccess_NoDB(t *testing.T) {
@@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewAccess(42, true, false) tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err) require.NoError(t, err)
token, err := gen.ValidateAccess(context.Background(), tokenStr) token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope) require.Equal(t, "access", token.Scope)
@@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateRefresh(context.Background(), tokenStr) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateRefresh(tx, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope) require.Equal(t, "refresh", token.Scope)
tx.Commit()
} }
func TestValidateRefresh_NoDB(t *testing.T) { func TestValidateRefresh_NoDB(t *testing.T) {
@@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewRefresh(42, false) tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err) require.NoError(t, err)
token, err := gen.ValidateRefresh(context.Background(), tokenStr) token, err := gen.ValidateRefresh(nil, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope) require.Equal(t, "refresh", token.Scope)
@@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
func TestValidateAccess_EmptyToken(t *testing.T) { func TestValidateAccess_EmptyToken(t *testing.T) {
gen := newTestGenerator(t) gen := newTestGenerator(t)
_, err := gen.ValidateAccess(context.Background(), "") _, err := gen.ValidateAccess(nil, "")
require.Error(t, err) require.Error(t, err)
} }
@@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
tokenStr, _, err := gen.NewAccess(1, false, false) tokenStr, _, err := gen.NewAccess(1, false, false)
require.NoError(t, err) require.NoError(t, err)
_, err = gen.ValidateRefresh(context.Background(), tokenStr) _, err = gen.ValidateRefresh(nil, tokenStr)
require.Error(t, err) require.Error(t, err)
} }