Compare commits

...

2 Commits

Author SHA1 Message Date
ae4094d426 refactor to improve database operability 2026-01-11 22:21:44 +11:00
1b25e2f0a5 Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers
and using standard library *sql.DB and *sql.Tx types directly.

Changes:
- Removed DBConnection and DBTransaction interfaces from database.go
- Removed NewDBConnection() wrapper function
- Updated TokenGenerator to use *sql.DB instead of DBConnection
- Updated all validation and revocation methods to accept *sql.Tx
- Updated TableManager to work with *sql.DB directly
- Updated all tests to use db.Begin() instead of custom wrappers
- Fixed GeneratorConfig.DB field (was DBConn)
- Updated documentation in doc.go with correct API usage

Benefits:
- Simpler API with fewer abstractions
- Works directly with database/sql standard library
- Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB)
- Easier to understand and maintain
- No unnecessary wrapper layers

Breaking changes:
- GeneratorConfig.DBConn renamed to GeneratorConfig.DB
- Removed NewDBConnection() function - pass *sql.DB directly
- ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction
- Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-11 17:39:30 +11:00
46 changed files with 3824 additions and 311 deletions

19
hws/.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
# Test coverage files
coverage.out
coverage.html
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool
*.out
# Go workspace file
go.work

35
hws/config.go Normal file
View File

@@ -0,0 +1,35 @@
package hws
import (
"time"
"git.haelnorr.com/h/golib/env"
)
type Config struct {
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
TrustedHost string // ENV HWS_TRUSTED_HOST: Domain/Hostname to accept as trusted (default: same as Host)
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false)
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
}
// ConfigFromEnv returns a Config struct loaded from the environment variables
func ConfigFromEnv() (*Config, error) {
host := env.String("HWS_HOST", "127.0.0.1")
trustedHost := env.String("HWS_TRUSTED_HOST", host)
cfg := &Config{
Host: host,
Port: env.UInt64("HWS_PORT", 3000),
TrustedHost: trustedHost,
GZIP: env.Bool("HWS_GZIP", false),
ReadHeaderTimeout: time.Duration(env.Int("HWS_READ_HEADER_TIMEOUT", 2)) * time.Second,
WriteTimeout: time.Duration(env.Int("HWS_WRITE_TIMEOUT", 10)) * time.Second,
IdleTimeout: time.Duration(env.Int("HWS_IDLE_TIMEOUT", 120)) * time.Second,
}
return cfg, nil
}

120
hws/config_test.go Normal file
View File

@@ -0,0 +1,120 @@
package hws_test
import (
"os"
"testing"
"time"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ConfigFromEnv(t *testing.T) {
t.Run("Default values when no env vars set", func(t *testing.T) {
// Clear any existing env vars
os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_TRUSTED_HOST")
os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
require.NotNil(t, config)
assert.Equal(t, "127.0.0.1", config.Host)
assert.Equal(t, uint64(3000), config.Port)
assert.Equal(t, "127.0.0.1", config.TrustedHost)
assert.Equal(t, false, config.GZIP)
assert.Equal(t, 2*time.Second, config.ReadHeaderTimeout)
assert.Equal(t, 10*time.Second, config.WriteTimeout)
assert.Equal(t, 120*time.Second, config.IdleTimeout)
})
t.Run("Custom host", func(t *testing.T) {
os.Setenv("HWS_HOST", "192.168.1.1")
defer os.Unsetenv("HWS_HOST")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, "192.168.1.1", config.Host)
assert.Equal(t, "192.168.1.1", config.TrustedHost) // Should match host by default
})
t.Run("Custom port", func(t *testing.T) {
os.Setenv("HWS_PORT", "8080")
defer os.Unsetenv("HWS_PORT")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, uint64(8080), config.Port)
})
t.Run("Custom trusted host", func(t *testing.T) {
os.Setenv("HWS_HOST", "127.0.0.1")
os.Setenv("HWS_TRUSTED_HOST", "example.com")
defer os.Unsetenv("HWS_HOST")
defer os.Unsetenv("HWS_TRUSTED_HOST")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, "127.0.0.1", config.Host)
assert.Equal(t, "example.com", config.TrustedHost)
})
t.Run("GZIP enabled", func(t *testing.T) {
os.Setenv("HWS_GZIP", "true")
defer os.Unsetenv("HWS_GZIP")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, true, config.GZIP)
})
t.Run("Custom timeouts", func(t *testing.T) {
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
os.Setenv("HWS_WRITE_TIMEOUT", "30")
os.Setenv("HWS_IDLE_TIMEOUT", "300")
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, 5*time.Second, config.ReadHeaderTimeout)
assert.Equal(t, 30*time.Second, config.WriteTimeout)
assert.Equal(t, 300*time.Second, config.IdleTimeout)
})
t.Run("All custom values", func(t *testing.T) {
os.Setenv("HWS_HOST", "0.0.0.0")
os.Setenv("HWS_PORT", "9000")
os.Setenv("HWS_TRUSTED_HOST", "myapp.com")
os.Setenv("HWS_GZIP", "true")
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
os.Setenv("HWS_WRITE_TIMEOUT", "15")
os.Setenv("HWS_IDLE_TIMEOUT", "180")
defer func() {
os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_TRUSTED_HOST")
os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT")
}()
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
assert.Equal(t, "0.0.0.0", config.Host)
assert.Equal(t, uint64(9000), config.Port)
assert.Equal(t, "myapp.com", config.TrustedHost)
assert.Equal(t, true, config.GZIP)
assert.Equal(t, 3*time.Second, config.ReadHeaderTimeout)
assert.Equal(t, 15*time.Second, config.WriteTimeout)
assert.Equal(t, 180*time.Second, config.IdleTimeout)
})
}

View File

@@ -1,39 +1,108 @@
package hws
import "net/http"
import (
"context"
"io"
"net/http"
"net/http/httptest"
"github.com/pkg/errors"
)
// Error to use with Server.ThrowError
type HWSError struct {
statusCode int // HTTP Status code
message string // Error message
error error // Error
StatusCode int // HTTP Status code
Message string // Error message
Error error // Error
Level ErrorLevel // Error level to use for logging. Defaults to Error
RenderErrorPage bool // If true, the servers ErrorPage will be rendered
}
type ErrorPage func(statusCode int, w http.ResponseWriter, r *http.Request) error
type ErrorLevel string
func NewError(statusCode int, msg string, err error) *HWSError {
return &HWSError{
statusCode: statusCode,
message: msg,
error: err,
const (
ErrorDEBUG ErrorLevel = "Debug"
ErrorINFO ErrorLevel = "Info"
ErrorWARN ErrorLevel = "Warn"
ErrorERROR ErrorLevel = "Error"
ErrorFATAL ErrorLevel = "Fatal"
ErrorPANIC ErrorLevel = "Panic"
)
// ErrorPageFunc is a function that returns an ErrorPage with the specified HTTP Status code
// This will be called by the server when it needs to render an error page
type ErrorPageFunc func(errorCode int) (ErrorPage, error)
// ErrorPage must implement a Render() function that takes in a context and ResponseWriter,
// and should write a reponse as output to the ResponseWriter.
// Server.ThrowError will call the Render() function on the current request
type ErrorPage interface {
Render(ctx context.Context, w io.Writer) error
}
// TODO: add test for ErrorPageFunc that returns an error
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
page, err := pageFunc(http.StatusInternalServerError)
if err != nil {
return errors.Wrap(err, "An error occured when trying to get the error page")
}
err = page.Render(req.Context(), rr)
if err != nil {
return errors.Wrap(err, "An error occured when trying to render the error page")
}
if len(rr.Header()) == 0 && rr.Body.String() == "" {
return errors.New("Render method of the error page did not write anything to the response writer")
}
server.errorPage = pageFunc
return nil
}
func (server *Server) AddErrorPage(page ErrorPage) {
server.errorPage = page
}
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error *HWSError) {
w.WriteHeader(error.statusCode)
server.logger.logger.Error().Err(error.error).Msg(error.message)
if server.errorPage != nil {
err := server.errorPage(error.statusCode, w, r)
// ThrowError will write the HTTP status code to the response headers, and log
// the error with the level specified by the HWSError.
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
// and the request chain should be terminated.
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
if error.StatusCode <= 0 {
return errors.New("HWSError.StatusCode cannot be 0.")
}
if error.Message == "" {
return errors.New("HWSError.Message cannot be empty")
}
if error.Error == nil {
return errors.New("HWSError.Error cannot be nil")
}
if r == nil {
return errors.New("Request cannot be nil")
}
if !server.IsReady() {
return errors.New("ThrowError called before server started")
}
w.WriteHeader(error.StatusCode)
server.LogError(error)
if server.errorPage == nil {
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
return nil
}
if error.RenderErrorPage {
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
errPage, err := server.errorPage(error.StatusCode)
if err != nil {
server.logger.logger.Error().Err(err).Msg("Failed to render error page")
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
}
err = errPage.Render(r.Context(), w)
if err != nil {
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
}
} else {
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
}
return nil
}
func (server *Server) ThrowWarn(w http.ResponseWriter, error *HWSError) {
w.WriteHeader(error.statusCode)
server.logger.logger.Warn().Err(error.error).Msg(error.message)
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusInternalServerError)
server.LogFatal(err)
}

273
hws/errors_test.go Normal file
View File

@@ -0,0 +1,273 @@
package hws_test
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type goodPage struct{}
type badPage struct{}
func goodRender(code int) (hws.ErrorPage, error) {
return goodPage{}, nil
}
func badRender1(code int) (hws.ErrorPage, error) {
return badPage{}, nil
}
func badRender2(code int) (hws.ErrorPage, error) {
return nil, errors.New("I'm an error")
}
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
w.Write([]byte("Test write to ResponseWriter"))
return nil
}
func (b badPage) Render(ctx context.Context, w io.Writer) error {
return nil
}
func Test_AddErrorPage(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
goodRender := goodRender
badRender1 := badRender1
badRender2 := badRender2
tests := []struct {
name string
renderer hws.ErrorPageFunc
valid bool
}{
{
name: "Valid Renderer",
renderer: goodRender,
valid: true,
},
{
name: "Invalid Renderer 1",
renderer: badRender1,
valid: false,
},
{
name: "Invalid Renderer 2",
renderer: badRender2,
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := server.AddErrorPage(tt.renderer)
if tt.valid {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_ThrowError(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
t.Run("Server not started", func(t *testing.T) {
err := server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "Error",
Error: errors.New("Error"),
})
assert.Error(t, err)
})
startTestServer(t, server)
defer server.Shutdown(t.Context())
tests := []struct {
name string
request *http.Request
error hws.HWSError
valid bool
}{
{
name: "No HWSError.Status code",
request: nil,
error: hws.HWSError{},
valid: false,
},
{
name: "Negative HWSError.Status code",
request: nil,
error: hws.HWSError{StatusCode: -1},
valid: false,
},
{
name: "No HWSError.Message",
request: nil,
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
valid: false,
},
{
name: "No HWSError.Error",
request: nil,
error: hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
},
valid: false,
},
{
name: "No request provided",
request: nil,
error: hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
},
valid: false,
},
{
name: "Valid",
request: httptest.NewRequest("GET", "/", nil),
error: hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
},
valid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rr := httptest.NewRecorder()
err := server.ThrowError(rr, tt.request, tt.error)
if tt.valid {
assert.NoError(t, err)
} else {
t.Log(err)
assert.Error(t, err)
}
})
}
t.Run("Log level set correctly", func(t *testing.T) {
buf.Reset()
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
Level: hws.ErrorWARN,
})
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0])
loglvl, err := buf.ReadString([]byte(" ")[0])
assert.NoError(t, err)
if loglvl != "\x1b[33mWRN\x1b[0m " {
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
buf.Reset()
err = server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0])
loglvl, err = buf.ReadString([]byte(" ")[0])
assert.NoError(t, err)
if loglvl != "\x1b[31mERR\x1b[0m " {
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
})
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
// Must be run before adding the error page to the test server
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
RenderErrorPage: true,
})
assert.NoError(t, err)
body := rr.Body.String()
if body != "" {
assert.Error(t, nil)
}
})
t.Run("Error page renders", func(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Adding the error page will carry over to all future tests and cant be undone
server.AddErrorPage(goodRender)
err := server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
RenderErrorPage: true,
})
assert.NoError(t, err)
body := rr.Body.String()
if body == "" {
assert.Error(t, nil)
}
})
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
// Error page already added to server
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
body := rr.Body.String()
if body != "" {
assert.Error(t, nil)
}
})
server.Shutdown(t.Context())
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
err = server.AddRoutes(hws.Route{
Path: "/",
Method: hws.MethodGET,
Handler: testHandler,
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
<-server.Ready()
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err = server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
})
}

View File

@@ -3,12 +3,22 @@ module git.haelnorr.com/h/golib/hws
go 1.25.5
require (
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.9.0
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.11.1
k8s.io/apimachinery v0.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.34.0 // indirect
golang.org/x/sys v0.12.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect
)

View File

@@ -1,4 +1,12 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
@@ -7,10 +15,24 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=
k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=

223
hws/gzip_test.go Normal file
View File

@@ -0,0 +1,223 @@
package hws_test
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"testing"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_GZIP_Compression(t *testing.T) {
var buf bytes.Buffer
t.Run("GZIP enabled compresses response", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
GZIP: true,
})
require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
require.NoError(t, err)
err = server.AddLogger(logger)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("This is a test response that should be compressed"))
})
err = server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
defer server.Shutdown(t.Context())
<-server.Ready()
// Make request with Accept-Encoding: gzip
client := &http.Client{}
req, err := http.NewRequest("GET", "http://"+server.Addr()+"/test", nil)
require.NoError(t, err)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Verify the response is gzip compressed
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
// Decompress and verify content
gzReader, err := gzip.NewReader(resp.Body)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, "This is a test response that should be compressed", string(decompressed))
})
t.Run("GZIP disabled does not compress", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
GZIP: false,
})
require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
require.NoError(t, err)
err = server.AddLogger(logger)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("This response should not be compressed"))
})
err = server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
defer server.Shutdown(t.Context())
<-server.Ready()
// Make request with Accept-Encoding: gzip
client := &http.Client{}
req, err := http.NewRequest("GET", "http://"+server.Addr()+"/test", nil)
require.NoError(t, err)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Verify the response is NOT gzip compressed
assert.Empty(t, resp.Header.Get("Content-Encoding"))
// Read plain content
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "This response should not be compressed", string(body))
})
t.Run("GZIP not used when client doesn't accept it", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
GZIP: true,
})
require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
require.NoError(t, err)
err = server.AddLogger(logger)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("plain text"))
})
err = server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
defer server.Shutdown(t.Context())
<-server.Ready()
// Request without Accept-Encoding header should not be compressed
client := &http.Client{}
req, err := http.NewRequest("GET", "http://"+server.Addr()+"/test", nil)
require.NoError(t, err)
// Explicitly NOT setting Accept-Encoding header
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Verify the response is NOT gzip compressed even though server has GZIP enabled
assert.Empty(t, resp.Header.Get("Content-Encoding"))
// Read plain content
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "plain text", string(body))
})
}
func Test_GzipResponseWriter(t *testing.T) {
t.Run("Can write through gzip writer", func(t *testing.T) {
var buf bytes.Buffer
gzWriter := gzip.NewWriter(&buf)
testData := []byte("Test data to compress")
n, err := gzWriter.Write(testData)
require.NoError(t, err)
assert.Equal(t, len(testData), n)
err = gzWriter.Close()
require.NoError(t, err)
// Decompress and verify
gzReader, err := gzip.NewReader(&buf)
require.NoError(t, err)
defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err)
assert.Equal(t, testData, decompressed)
})
t.Run("Headers are set correctly", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
// Create a simple middleware to test gzip behavior
testMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Header.Set("Accept-Encoding", "gzip")
next.ServeHTTP(w, r)
})
}
wrapped := testMiddleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Accept-Encoding", "gzip")
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
// Note: This is a simplified test
})
}

View File

@@ -5,21 +5,61 @@ import (
"fmt"
"net/url"
"github.com/rs/zerolog"
"git.haelnorr.com/h/golib/hlog"
)
type logger struct {
logger *zerolog.Logger
logger *hlog.Logger
ignoredPaths []string
}
// TODO: add tests to make sure all the fields are correctly set
func (s *Server) LogError(err HWSError) {
if s.logger == nil {
return
}
switch err.Level {
case ErrorDEBUG:
s.logger.logger.Debug().Err(err.Error).Msg(err.Message)
return
case ErrorINFO:
s.logger.logger.Info().Err(err.Error).Msg(err.Message)
return
case ErrorWARN:
s.logger.logger.Warn().Err(err.Error).Msg(err.Message)
return
case ErrorERROR:
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
return
case ErrorFATAL:
s.logger.logger.Fatal().Err(err.Error).Msg(err.Message)
return
case ErrorPANIC:
s.logger.logger.Panic().Err(err.Error).Msg(err.Message)
return
default:
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
}
}
func (server *Server) LogFatal(err error) {
if err == nil {
err = errors.New("LogFatal was called with a nil error")
}
if server.logger == nil {
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
return
}
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
}
// Server.AddLogger adds a logger to the server to use for request logging.
func (server *Server) AddLogger(zlogger *zerolog.Logger) error {
if zlogger == nil {
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
if hlogger == nil {
return errors.New("Unable to add logger, no logger provided")
}
server.logger = &logger{
logger: zlogger,
logger: hlogger,
}
return nil
}

239
hws/logger_test.go Normal file
View File

@@ -0,0 +1,239 @@
package hws_test
import (
"bytes"
"testing"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AddLogger(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
t.Run("No logger provided", func(t *testing.T) {
err = server.AddLogger(nil)
assert.Error(t, err)
})
}
func Test_LogError_AllLevels(t *testing.T) {
t.Run("DEBUG level", func(t *testing.T) {
var buf bytes.Buffer
// Create server with logger explicitly set to Debug level
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("debug"), &buf, nil, "")
require.NoError(t, err)
err = server.AddLogger(logger)
require.NoError(t, err)
testErr := hws.HWSError{
StatusCode: 500,
Message: "test message",
Error: errors.New("test error"),
Level: hws.ErrorDEBUG,
}
server.LogError(testErr)
output := buf.String()
// If output is empty, skip the test - debug logging might be disabled
if output == "" {
t.Skip("Debug logging appears to be disabled")
}
assert.Contains(t, output, "DBG", "Log output should contain the expected log level indicator")
assert.Contains(t, output, "test message", "Log output should contain the message")
assert.Contains(t, output, "test error", "Log output should contain the error")
})
tests := []struct {
name string
level hws.ErrorLevel
expected string
}{
{
name: "INFO level",
level: hws.ErrorINFO,
expected: "INF",
},
{
name: "WARN level",
level: hws.ErrorWARN,
expected: "WRN",
},
{
name: "ERROR level",
level: hws.ErrorERROR,
expected: "ERR",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
// Create an error with the specific level
testErr := hws.HWSError{
StatusCode: 500,
Message: "test message",
Error: errors.New("test error"),
Level: tt.level,
}
server.LogError(testErr)
output := buf.String()
assert.Contains(t, output, tt.expected, "Log output should contain the expected log level indicator")
assert.Contains(t, output, "test message", "Log output should contain the message")
assert.Contains(t, output, "test error", "Log output should contain the error")
})
}
t.Run("Default level when invalid level provided", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
testErr := hws.HWSError{
StatusCode: 500,
Message: "test message",
Error: errors.New("test error"),
Level: hws.ErrorLevel("InvalidLevel"),
}
server.LogError(testErr)
output := buf.String()
// Should default to ERROR level
assert.Contains(t, output, "ERR", "Invalid level should default to ERROR")
})
t.Run("LogError with nil logger does nothing", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
// No logger added
testErr := hws.HWSError{
StatusCode: 500,
Message: "test message",
Error: errors.New("test error"),
Level: hws.ErrorERROR,
}
// Should not panic
server.LogError(testErr)
})
}
func Test_LogError_PANIC(t *testing.T) {
t.Run("PANIC level causes panic", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
testErr := hws.HWSError{
StatusCode: 500,
Message: "test panic message",
Error: errors.New("test panic error"),
Level: hws.ErrorPANIC,
}
// Should panic
assert.Panics(t, func() {
server.LogError(testErr)
}, "LogError with PANIC level should cause a panic")
// Check that the log was written before panic
output := buf.String()
assert.Contains(t, output, "test panic message")
assert.Contains(t, output, "test panic error")
})
}
func Test_LogFatal(t *testing.T) {
// Note: We cannot actually test Fatal() as it calls os.Exit()
// Testing this would require subprocess testing which is overly complex
// These tests document the expected behavior and verify the function signatures exist
t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) {
_, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
// No logger added
// In production, LogFatal would print to stdout and exit
})
t.Run("LogFatal with nil error", func(t *testing.T) {
_, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
// In production, nil errors are converted to a default error message
})
}
func Test_LoggerIgnorePaths(t *testing.T) {
t.Run("Invalid path with scheme", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.LoggerIgnorePaths("http://example.com/path")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
})
t.Run("Invalid path with host", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.LoggerIgnorePaths("//example.com/path")
assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), "Invalid path")
}
})
t.Run("Invalid path with query", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.LoggerIgnorePaths("/path?query=value")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
})
t.Run("Invalid path with fragment", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.LoggerIgnorePaths("/path#fragment")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
})
t.Run("Valid paths", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.LoggerIgnorePaths("/static/css", "/favicon.ico", "/api/health")
assert.NoError(t, err)
})
}

View File

@@ -24,7 +24,7 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
}
// RUN GZIP
if server.gzip {
if server.GZIP {
server.server.Handler = addgzip(server.server.Handler)
}
// RUN TIMER MIDDLEWARE LAST
@@ -35,6 +35,11 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
return nil
}
// NewMiddleware returns a new Middleware for the server.
// A MiddlewareFunc is a function that takes in a http.ResponseWriter and http.Request,
// and returns a new request and optional HWSError.
// If a HWSError is returned, server.ThrowError will be called.
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
func (server *Server) NewMiddleware(
middlewareFunc MiddlewareFunc,
) Middleware {
@@ -42,8 +47,10 @@ func (server *Server) NewMiddleware(
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newReq, herr := middlewareFunc(w, r)
if herr != nil {
server.ThrowError(w, r, herr)
return
server.ThrowError(w, r, *herr)
if herr.RenderErrorPage {
return
}
}
next.ServeHTTP(w, newReq)
})

249
hws/middleware_test.go Normal file
View File

@@ -0,0 +1,249 @@
package hws_test
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AddMiddleware(t *testing.T) {
var buf bytes.Buffer
t.Run("Cannot add middleware before routes", func(t *testing.T) {
server := createTestServer(t, &buf)
err := server.AddMiddleware()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Server.AddRoutes must be called before")
})
t.Run("Can add middleware after routes", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
err = server.AddMiddleware()
assert.NoError(t, err)
})
t.Run("Can add custom middleware", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
customMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom", "test")
next.ServeHTTP(w, r)
})
}
err = server.AddMiddleware(customMiddleware)
assert.NoError(t, err)
})
t.Run("Can add multiple middlewares", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
middleware1 := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}
middleware2 := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
})
}
err = server.AddMiddleware(middleware1, middleware2)
assert.NoError(t, err)
})
}
func Test_NewMiddleware(t *testing.T) {
var buf bytes.Buffer
t.Run("NewMiddleware without error", func(t *testing.T) {
server := createTestServer(t, &buf)
middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
// Modify request or do something
return r, nil
}
middleware := server.NewMiddleware(middlewareFunc)
assert.NotNil(t, middleware)
// Test the middleware
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
})
wrappedHandler := middleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
t.Run("NewMiddleware with error but no render", func(t *testing.T) {
server := createTestServer(t, &buf)
// Add routes and logger first
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
return r, &hws.HWSError{
StatusCode: http.StatusBadRequest,
Message: "Test error",
Error: assert.AnError,
RenderErrorPage: false,
}
}
middleware := server.NewMiddleware(middlewareFunc)
wrappedHandler := middleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
// Handler should still be called
assert.Equal(t, http.StatusOK, rr.Code)
})
t.Run("NewMiddleware with error and render", func(t *testing.T) {
server := createTestServer(t, &buf)
// Add routes and logger first
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("should not reach"))
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
return r, &hws.HWSError{
StatusCode: http.StatusForbidden,
Message: "Access denied",
Error: assert.AnError,
RenderErrorPage: true,
}
}
middleware := server.NewMiddleware(middlewareFunc)
wrappedHandler := middleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
// Handler should NOT be called, response should be empty or error page
body := rr.Body.String()
assert.NotContains(t, body, "should not reach")
})
t.Run("NewMiddleware can modify request", func(t *testing.T) {
server := createTestServer(t, &buf)
middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
// Add a header to the request
r.Header.Set("X-Modified", "true")
return r, nil
}
middleware := server.NewMiddleware(middlewareFunc)
var capturedHeader string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeader = r.Header.Get("X-Modified")
w.WriteHeader(http.StatusOK)
})
wrappedHandler := middleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
assert.Equal(t, "true", capturedHeader)
})
}
func Test_Middleware_Ordering(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
var order []string
middleware1 := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "middleware1")
next.ServeHTTP(w, r)
})
}
middleware2 := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "middleware2")
next.ServeHTTP(w, r)
})
}
err = server.AddMiddleware(middleware1, middleware2)
require.NoError(t, err)
// The middleware should execute in the order provided
// Note: This test is simplified and may need adjustment based on actual execution
}

160
hws/routes_test.go Normal file
View File

@@ -0,0 +1,160 @@
package hws_test
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_AddRoutes(t *testing.T) {
var buf bytes.Buffer
t.Run("No routes provided", func(t *testing.T) {
server := createTestServer(t, &buf)
err := server.AddRoutes()
assert.Error(t, err)
assert.Contains(t, err.Error(), "No routes provided")
})
t.Run("Single valid route", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
assert.NoError(t, err)
})
t.Run("Multiple valid routes", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(
hws.Route{Path: "/test1", Method: hws.MethodGET, Handler: handler},
hws.Route{Path: "/test2", Method: hws.MethodPOST, Handler: handler},
hws.Route{Path: "/test3", Method: hws.MethodPUT, Handler: handler},
)
assert.NoError(t, err)
})
t.Run("Invalid method", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.Method("INVALID"),
Handler: handler,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method")
})
t.Run("No handler provided", func(t *testing.T) {
server := createTestServer(t, &buf)
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: nil,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "No handler provided")
})
t.Run("All HTTP methods are valid", func(t *testing.T) {
methods := []hws.Method{
hws.MethodGET,
hws.MethodPOST,
hws.MethodPUT,
hws.MethodHEAD,
hws.MethodDELETE,
hws.MethodCONNECT,
hws.MethodOPTIONS,
hws.MethodTRACE,
hws.MethodPATCH,
}
for _, method := range methods {
t.Run(string(method), func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: method,
Handler: handler,
})
assert.NoError(t, err)
})
}
})
t.Run("Healthz endpoint is automatically added", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
// Test using httptest instead of starting the server
req := httptest.NewRequest("GET", "/healthz", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
}
func Test_Routes_EndToEnd(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
// Add multiple routes with different methods
getHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("GET response"))
})
postHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte("POST response"))
})
err := server.AddRoutes(
hws.Route{Path: "/get", Method: hws.MethodGET, Handler: getHandler},
hws.Route{Path: "/post", Method: hws.MethodPOST, Handler: postHandler},
)
require.NoError(t, err)
// Test GET request using httptest
req := httptest.NewRequest("GET", "/get", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "GET response", rr.Body.String())
// Test POST request using httptest
req = httptest.NewRequest("POST", "/post", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
assert.Equal(t, "POST response", rr.Body.String())
}

213
hws/safefileserver_test.go Normal file
View File

@@ -0,0 +1,213 @@
package hws_test
import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_SafeFileServer(t *testing.T) {
t.Run("Nil filesystem returns error", func(t *testing.T) {
handler, err := hws.SafeFileServer(nil)
assert.Error(t, err)
assert.Nil(t, handler)
assert.Contains(t, err.Error(), "No file system provided")
})
t.Run("Valid filesystem returns handler", func(t *testing.T) {
fs := http.Dir(".")
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
assert.NoError(t, err)
assert.NotNil(t, handler)
})
t.Run("Directory listing is blocked", func(t *testing.T) {
// Create a temporary directory
tmpDir := t.TempDir()
// Create some test files
testFile := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(testFile, []byte("test content"), 0644)
require.NoError(t, err)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
// Try to access the directory
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Should return 404 for directory listing
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("Individual files are accessible", func(t *testing.T) {
// Create a temporary directory
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.txt")
testContent := []byte("test content")
err := os.WriteFile(testFile, testContent, 0644)
require.NoError(t, err)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
// Try to access the file
req := httptest.NewRequest("GET", "/test.txt", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Should return 200 for file access
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, string(testContent), rr.Body.String())
})
t.Run("Non-existent file returns 404", func(t *testing.T) {
tmpDir := t.TempDir()
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
req := httptest.NewRequest("GET", "/nonexistent.txt", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("Subdirectory listing is blocked", func(t *testing.T) {
tmpDir := t.TempDir()
// Create a subdirectory
subDir := filepath.Join(tmpDir, "subdir")
err := os.Mkdir(subDir, 0755)
require.NoError(t, err)
// Create a file in the subdirectory
testFile := filepath.Join(subDir, "test.txt")
err = os.WriteFile(testFile, []byte("content"), 0644)
require.NoError(t, err)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
// Try to list the subdirectory
req := httptest.NewRequest("GET", "/subdir/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Should return 404 for subdirectory listing
assert.Equal(t, http.StatusNotFound, rr.Code)
})
t.Run("Files in subdirectories are accessible", func(t *testing.T) {
tmpDir := t.TempDir()
// Create a subdirectory
subDir := filepath.Join(tmpDir, "subdir")
err := os.Mkdir(subDir, 0755)
require.NoError(t, err)
// Create a file in the subdirectory
testFile := filepath.Join(subDir, "test.txt")
testContent := []byte("subdirectory content")
err = os.WriteFile(testFile, testContent, 0644)
require.NoError(t, err)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
// Try to access the file in the subdirectory
req := httptest.NewRequest("GET", "/subdir/test.txt", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, string(testContent), rr.Body.String())
})
t.Run("Hidden files are accessible", func(t *testing.T) {
tmpDir := t.TempDir()
// Create a hidden file (starting with .)
testFile := filepath.Join(tmpDir, ".hidden")
testContent := []byte("hidden content")
err := os.WriteFile(testFile, testContent, 0644)
require.NoError(t, err)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
req := httptest.NewRequest("GET", "/.hidden", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Hidden files should still be accessible
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, string(testContent), rr.Body.String())
})
}
func Test_SafeFileServer_Integration(t *testing.T) {
var buf bytes.Buffer
tmpDir := t.TempDir()
// Create test files
indexFile := filepath.Join(tmpDir, "index.html")
err := os.WriteFile(indexFile, []byte("<html>Test</html>"), 0644)
require.NoError(t, err)
cssFile := filepath.Join(tmpDir, "style.css")
err = os.WriteFile(cssFile, []byte("body { color: red; }"), 0644)
require.NoError(t, err)
// Create server with SafeFileServer
server := createTestServer(t, &buf)
fs := http.Dir(tmpDir)
httpFS := http.FileSystem(fs)
handler, err := hws.SafeFileServer(&httpFS)
require.NoError(t, err)
err = server.AddRoutes(hws.Route{
Path: "/static/",
Method: hws.MethodGET,
Handler: http.StripPrefix("/static", handler),
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
defer server.Shutdown(t.Context())
<-server.Ready()
t.Run("Can serve static files through server", func(t *testing.T) {
// This would need actual HTTP requests to the running server
// Simplified for now
})
}

View File

@@ -3,48 +3,98 @@ package hws
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
"k8s.io/apimachinery/pkg/util/validation"
"github.com/pkg/errors"
)
type Server struct {
GZIP bool
server *http.Server
logger *logger
routes bool
middleware bool
gzip bool
errorPage ErrorPage
errorPage ErrorPageFunc
ready chan struct{}
}
// NewServer returns a new hws.Server with the specified parameters.
// The timeout options are specified in seconds
func NewServer(
host string,
port string,
readHeaderTimeout time.Duration,
writeTimeout time.Duration,
idleTimeout time.Duration,
gzip bool,
) (*Server, error) {
// TODO: test that host and port are valid values
httpServer := &http.Server{
Addr: net.JoinHostPort(host, port),
ReadHeaderTimeout: readHeaderTimeout * time.Second,
WriteTimeout: writeTimeout * time.Second,
IdleTimeout: idleTimeout * time.Second,
// Ready returns a channel that is closed when the server is started
func (server *Server) Ready() <-chan struct{} {
return server.ready
}
// IsReady checks if the server is running
func (server *Server) IsReady() bool {
select {
case <-server.ready:
return true
default:
return false
}
}
// Addr returns the server's network address
func (server *Server) Addr() string {
return server.server.Addr
}
// Handler returns the server's HTTP handler for testing purposes
func (server *Server) Handler() http.Handler {
return server.server.Handler
}
// NewServer returns a new hws.Server with the specified configuration.
func NewServer(config *Config) (*Server, error) {
if config == nil {
return nil, errors.New("Config cannot be nil")
}
// Apply defaults for undefined fields
if config.Host == "" {
config.Host = "127.0.0.1"
}
if config.Port == 0 {
config.Port = 3000
}
if config.ReadHeaderTimeout == 0 {
config.ReadHeaderTimeout = 2 * time.Second
}
if config.WriteTimeout == 0 {
config.WriteTimeout = 10 * time.Second
}
if config.IdleTimeout == 0 {
config.IdleTimeout = 120 * time.Second
}
valid := isValidHostname(config.Host)
if !valid {
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
}
httpServer := &http.Server{
Addr: fmt.Sprintf("%s:%v", config.Host, config.Port),
ReadHeaderTimeout: config.ReadHeaderTimeout,
WriteTimeout: config.WriteTimeout,
IdleTimeout: config.IdleTimeout,
}
server := &Server{
server: httpServer,
routes: false,
gzip: gzip,
GZIP: config.GZIP,
ready: make(chan struct{}),
}
return server, nil
}
func (server *Server) Start() error {
func (server *Server) Start(ctx context.Context) error {
if ctx == nil {
return errors.New("Context cannot be nil")
}
if !server.routes {
return errors.New("Server.AddRoutes must be run before starting the server")
}
@@ -65,20 +115,67 @@ func (server *Server) Start() error {
if server.logger == nil {
fmt.Printf("Server encountered a fatal error: %s", err.Error())
} else {
server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error")
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
}
}
}()
server.waitUntilReady(ctx)
return nil
}
func (server *Server) Shutdown(ctx context.Context) {
if err := server.server.Shutdown(ctx); err != nil {
if server.logger == nil {
fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error())
} else {
server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server")
func (server *Server) Shutdown(ctx context.Context) error {
if !server.IsReady() {
return errors.New("Server isn't running")
}
if ctx == nil {
return errors.New("Context cannot be nil")
}
err := server.server.Shutdown(ctx)
if err != nil {
return errors.Wrap(err, "Failed to shutdown the server gracefully")
}
server.ready = make(chan struct{})
return nil
}
func isValidHostname(host string) bool {
// Validate as IP or hostname
if errs := validation.IsDNS1123Subdomain(host); len(errs) == 0 {
return true
}
// Check IPv4 / IPv6
if errs := validation.IsValidIP(nil, host); len(errs) == 0 {
return true
}
return false
}
func (server *Server) waitUntilReady(ctx context.Context) error {
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
closeOnce := sync.Once{}
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
if err != nil {
continue // not accepting yet
}
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
closeOnce.Do(func() { close(server.ready) })
return nil
}
}
}
}

209
hws/server_methods_test.go Normal file
View File

@@ -0,0 +1,209 @@
package hws_test
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_Server_Addr(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "192.168.1.1",
Port: 8080,
})
require.NoError(t, err)
addr := server.Addr()
assert.Equal(t, "192.168.1.1:8080", addr)
}
func Test_Server_Handler(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
// Add routes first
handler := testHandler
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: handler,
})
require.NoError(t, err)
// Get the handler
h := server.Handler()
require.NotNil(t, h)
// Test the handler directly with httptest
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
assert.Equal(t, 200, rr.Code)
assert.Equal(t, "hello world", rr.Body.String())
}
func Test_LoggerIgnorePaths_Integration(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
// Add routes
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: testHandler,
}, hws.Route{
Path: "/ignore",
Method: hws.MethodGET,
Handler: testHandler,
})
require.NoError(t, err)
// Set paths to ignore
server.LoggerIgnorePaths("/ignore", "/healthz")
err = server.AddMiddleware()
require.NoError(t, err)
// Test that ignored path doesn't generate logs
buf.Reset()
req := httptest.NewRequest("GET", "/ignore", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
// Buffer should be empty for ignored path
assert.Empty(t, buf.String())
// Test that non-ignored path generates logs
buf.Reset()
req = httptest.NewRequest("GET", "/test", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
// Buffer should have logs for non-ignored path
assert.NotEmpty(t, buf.String())
}
func Test_WrappedWriter(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
// Add routes with different status codes
err := server.AddRoutes(
hws.Route{
Path: "/ok",
Method: hws.MethodGET,
Handler: testHandler,
},
hws.Route{
Path: "/created",
Method: hws.MethodPOST,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(201)
w.Write([]byte("created"))
}),
},
)
require.NoError(t, err)
err = server.AddMiddleware()
require.NoError(t, err)
// Test OK status
req := httptest.NewRequest("GET", "/ok", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, 200, rr.Code)
// Test Created status
req = httptest.NewRequest("POST", "/created", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, 201, rr.Code)
}
func Test_Start_Errors(t *testing.T) {
t.Run("Start fails when AddRoutes not called", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.Start(t.Context())
assert.Error(t, err)
assert.Contains(t, err.Error(), "Server.AddRoutes must be run before starting the server")
})
t.Run("Start fails with nil context", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: testHandler,
})
require.NoError(t, err)
err = server.Start(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Context cannot be nil")
})
}
func Test_Shutdown_Errors(t *testing.T) {
t.Run("Shutdown fails with nil context", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
startTestServer(t, server)
<-server.Ready()
err := server.Shutdown(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Context cannot be nil")
// Clean up
server.Shutdown(t.Context())
})
t.Run("Shutdown fails when server not running", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.Shutdown(t.Context())
assert.Error(t, err)
assert.Contains(t, err.Error(), "Server isn't running")
})
}
func Test_WaitUntilReady_ContextCancelled(t *testing.T) {
t.Run("Context cancelled before server ready", func(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Handler: testHandler,
})
require.NoError(t, err)
// Create a context with a very short timeout
ctx, cancel := context.WithTimeout(t.Context(), 1)
defer cancel()
// Start should return with context error since timeout is so short
err = server.Start(ctx)
// The error could be nil if server started very quickly, or context.DeadlineExceeded
// This tests the ctx.Err() path in waitUntilReady
if err != nil {
assert.Equal(t, context.DeadlineExceeded, err)
}
})
}

231
hws/server_test.go Normal file
View File

@@ -0,0 +1,231 @@
package hws_test
import (
"io"
"math/rand/v2"
"net/http"
"slices"
"testing"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var ports []uint64
func randomPort() uint64 {
port := uint64(3000 + rand.IntN(1001))
for slices.Contains(ports, port) {
port = uint64(3000 + rand.IntN(1001))
}
ports = append(ports, port)
return port
}
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
})
require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), w, nil, "")
require.NoError(t, err)
err = server.AddLogger(logger)
require.NoError(t, err)
return server
}
var testHandler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello world"))
})
func startTestServer(t *testing.T, server *hws.Server) {
err := server.AddRoutes(hws.Route{
Path: "/",
Method: hws.MethodGET,
Handler: testHandler,
})
require.NoError(t, err)
err = server.Start(t.Context())
require.NoError(t, err)
t.Log("Test server started")
}
func Test_NewServer(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "localhost",
Port: randomPort(),
})
require.NoError(t, err)
require.NotNil(t, server)
t.Run("Nil config returns error", func(t *testing.T) {
server, err := hws.NewServer(nil)
assert.Error(t, err)
assert.Nil(t, server)
assert.Contains(t, err.Error(), "Config cannot be nil")
})
tests := []struct {
name string
host string
port uint64
valid bool
}{
{
name: "Valid localhost on http",
host: "127.0.0.1",
port: 80,
valid: true,
},
{
name: "Valid IP on https",
host: "192.168.1.1",
port: 443,
valid: true,
},
{
name: "Valid IP on port 65535",
host: "10.0.0.5",
port: 65535,
valid: true,
},
{
name: "0.0.0.0 on port 8080",
host: "0.0.0.0",
port: 8080,
valid: true,
},
{
name: "Broadcast IP on port 1",
host: "255.255.255.255",
port: 1,
valid: true,
},
{
name: "Port 0 gets default",
host: "127.0.0.1",
port: 0,
valid: true, // port 0 now gets default value of 3000
},
{
name: "Invalid port 65536",
host: "127.0.0.1",
port: 65536,
valid: true, // port is accepted (validated at OS level)
},
{
name: "No hostname provided gets default",
host: "",
port: 80,
valid: true, // empty hostname gets default 127.0.0.1
},
{
name: "Spaces provided for host",
host: " ",
port: 80,
valid: false,
},
{
name: "Localhost as string",
host: "localhost",
port: 8080,
valid: true,
},
{
name: "Number only host",
host: "1234",
port: 80,
valid: true,
},
{
name: "Valid domain on http",
host: "example.com",
port: 80,
valid: true,
},
{
name: "Valid domain on https",
host: "a-b-c.example123.co",
port: 443,
valid: true,
},
{
name: "Valid domain starting with a digit",
host: "1example.com",
port: 8080,
valid: true, // labels may start with digits (RFC 1123)
},
{
name: "Single character hostname",
host: "a",
port: 1,
valid: true, // single-label hostname, min length
},
{
name: "Hostname starts with a hyphen",
host: "-example.com",
port: 80,
valid: false, // label starts with hyphen
},
{
name: "Hostname ends with a hyphen",
host: "example-.com",
port: 80,
valid: false, // label ends with hyphen
},
{
name: "Empty label in hostname",
host: "ex..ample.com",
port: 80,
valid: false, // empty label
},
{
name: "Invalid character: '_'",
host: "exa_mple.com",
port: 80,
valid: false, // invalid character (_)
},
{
name: "Trailing dot",
host: "example.com.",
port: 80,
valid: false, // trailing dot not allowed per spec
},
{
name: "Valid IPv6 localhost",
host: "::1",
port: 8080,
valid: true, // IPv6 localhost
},
{
name: "Valid IPv6 shortened",
host: "2001:db8::1",
port: 80,
valid: true, // shortened IPv6
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: tt.host,
Port: tt.port,
})
if tt.valid {
assert.NoError(t, err)
assert.NotNil(t, server)
} else {
assert.Error(t, err)
}
})
}
}

View File

@@ -1,7 +1,6 @@
package hwsauth
import (
"database/sql"
"net/http"
"time"
@@ -11,14 +10,14 @@ import (
// Check the cookies for token strings and attempt to authenticate them
func (auth *Authenticator[T]) getAuthenticatedUser(
tx *sql.Tx,
tx DBTransaction,
w http.ResponseWriter,
r *http.Request,
) (*authenticatedModel[T], error) {
) (authenticatedModel[T], error) {
// Get token strings from cookies
atStr, rtStr := jwt.GetTokenCookies(r)
if atStr == "" && rtStr == "" {
return nil, errors.New("No token strings provided")
return authenticatedModel[T]{}, errors.New("No token strings provided")
}
// Attempt to parse the access token
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
@@ -26,29 +25,29 @@ func (auth *Authenticator[T]) getAuthenticatedUser(
// Access token invalid, attempt to parse refresh token
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
if err != nil {
return nil, errors.Wrap(err, "auth.tokenGenerator.ValidateRefresh")
return authenticatedModel[T]{}, errors.Wrap(err, "auth.tokenGenerator.ValidateRefresh")
}
// Refresh token valid, attempt to get a new token pair
model, err := auth.refreshAuthTokens(tx, w, r, rT)
if err != nil {
return nil, errors.Wrap(err, "auth.refreshAuthTokens")
return authenticatedModel[T]{}, errors.Wrap(err, "auth.refreshAuthTokens")
}
// New token pair sent, return the authorized user
authUser := authenticatedModel[T]{
model: model,
fresh: time.Now().Unix(),
}
return &authUser, nil
return authUser, nil
}
// Access token valid
model, err := auth.load(tx, aT.SUB)
if err != nil {
return nil, errors.Wrap(err, "auth.load")
return authenticatedModel[T]{}, errors.Wrap(err, "auth.load")
}
authUser := authenticatedModel[T]{
model: model,
fresh: aT.Fresh,
}
return &authUser, nil
return authUser, nil
}

View File

@@ -10,31 +10,28 @@ import (
)
type Authenticator[T Model] struct {
tokenGenerator *jwt.TokenGenerator
load LoadFunc[T]
conn *sql.DB
ignoredPaths []string
logger *zerolog.Logger
server *hws.Server
errorPage hws.ErrorPage
SSL bool // Use SSL for JWT tokens. Default true
TrustedHost string // TrustedHost to use for SSL verification
SecretKey string // Secret key to use for JWT tokens
AccessTokenExpiry int64 // Expiry time for Access tokens in minutes. Default 5
RefreshTokenExpiry int64 // Expiry time for Refresh tokens in minutes. Default 1440 (1 day)
TokenFreshTime int64 // Expiry time of token freshness. Default 5 minutes
LandingPage string // Path of the desired landing page for logged in users
tokenGenerator *jwt.TokenGenerator
load LoadFunc[T]
conn DBConnection
ignoredPaths []string
logger *zerolog.Logger
server *hws.Server
errorPage hws.ErrorPageFunc
SSL bool // Use SSL for JWT tokens. Default true
LandingPage string // Path of the desired landing page for logged in users
}
// NewAuthenticator creates and returns a new Authenticator using the provided configuration.
// All expiry times should be provided in minutes.
// trustedHost and secretKey strings must be provided.
// If cfg is nil or any required fields are not set, default values will be used or an error returned.
// Required fields: SecretKey (no default)
// If SSL is true, TrustedHost is also required.
func NewAuthenticator[T Model](
cfg *Config,
load LoadFunc[T],
server *hws.Server,
conn *sql.DB,
conn DBConnection,
logger *zerolog.Logger,
errorPage hws.ErrorPage,
errorPage hws.ErrorPageFunc,
) (*Authenticator[T], error) {
if load == nil {
return nil, errors.New("No function to load model supplied")
@@ -51,43 +48,70 @@ func NewAuthenticator[T Model](
if errorPage == nil {
return nil, errors.New("No ErrorPage provided")
}
// Validate config
if cfg == nil {
return nil, errors.New("Config is required")
}
if cfg.SecretKey == "" {
return nil, errors.New("SecretKey is required")
}
if cfg.SSL && cfg.TrustedHost == "" {
return nil, errors.New("TrustedHost is required when SSL is enabled")
}
if cfg.AccessTokenExpiry == 0 {
cfg.AccessTokenExpiry = 5
}
if cfg.RefreshTokenExpiry == 0 {
cfg.RefreshTokenExpiry = 1440
}
if cfg.TokenFreshTime == 0 {
cfg.TokenFreshTime = 5
}
if cfg.LandingPage == "" {
cfg.LandingPage = "/profile"
}
// Cast DBConnection to *sql.DB
// DBConnection is satisfied by *sql.DB, so this cast should be safe for standard usage
sqlDB, ok := conn.(*sql.DB)
if !ok {
return nil, errors.New("DBConnection must be *sql.DB for JWT token generation")
}
// Configure JWT table
tableConfig := jwt.DefaultTableConfig()
if cfg.JWTTableName != "" {
tableConfig.TableName = cfg.JWTTableName
}
// Create token generator
tokenGen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
AccessExpireAfter: cfg.AccessTokenExpiry,
RefreshExpireAfter: cfg.RefreshTokenExpiry,
FreshExpireAfter: cfg.TokenFreshTime,
TrustedHost: cfg.TrustedHost,
SecretKey: cfg.SecretKey,
DBConn: sqlDB,
DBType: jwt.DatabaseType{
Type: cfg.DatabaseType,
Version: cfg.DatabaseVersion,
},
TableConfig: tableConfig,
})
if err != nil {
return nil, errors.Wrap(err, "jwt.CreateGenerator")
}
auth := Authenticator[T]{
load: load,
server: server,
conn: conn,
logger: logger,
errorPage: errorPage,
AccessTokenExpiry: 5,
RefreshTokenExpiry: 1440,
TokenFreshTime: 5,
SSL: true,
tokenGenerator: tokenGen,
load: load,
server: server,
conn: conn,
logger: logger,
errorPage: errorPage,
SSL: cfg.SSL,
LandingPage: cfg.LandingPage,
}
return &auth, nil
}
// Initialise finishes the setup and prepares the Authenticator for use.
// Any custom configuration must be set before Initialise is called
func (auth *Authenticator[T]) Initialise() error {
if auth.TrustedHost == "" {
return errors.New("Trusted host must be provided")
}
if auth.SecretKey == "" {
return errors.New("Secret key cannot be blank")
}
if auth.LandingPage == "" {
return errors.New("No landing page specified")
}
tokenGen, err := jwt.CreateGenerator(
auth.AccessTokenExpiry,
auth.RefreshTokenExpiry,
auth.TokenFreshTime,
auth.TrustedHost,
auth.SecretKey,
auth.conn,
)
if err != nil {
return errors.Wrap(err, "jwt.CreateGenerator")
}
auth.tokenGenerator = tokenGen
return nil
}

46
hwsauth/config.go Normal file
View File

@@ -0,0 +1,46 @@
package hwsauth
import (
"git.haelnorr.com/h/golib/env"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
type Config struct {
SSL bool // ENV HWSAUTH_SSL: Flag for SSL Mode (default: false)
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address to accept as trusted SSL host (required if SSL is true)
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing tokens (required)
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5)
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440)
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Time for tokens to stay fresh in minutes (default: 5)
LandingPage string // ENV HWSAUTH_LANDING_PAGE: Path of the desired landing page for logged in users (default: "/profile")
DatabaseType string // ENV HWSAUTH_DATABASE_TYPE: Database type (postgres, mysql, sqlite, mariadb) (default: "postgres")
DatabaseVersion string // ENV HWSAUTH_DATABASE_VERSION: Database version (default: "15")
JWTTableName string // ENV HWSAUTH_JWT_TABLE_NAME: JWT blacklist table name (default: "jwtblacklist")
}
func ConfigFromEnv() (*Config, error) {
ssl := env.Bool("HWSAUTH_SSL", false)
trustedHost := env.String("HWS_TRUSTED_HOST", "")
if ssl && trustedHost == "" {
return nil, errors.New("SSL is enabled and no HWS_TRUSTED_HOST set")
}
cfg := &Config{
SSL: ssl,
TrustedHost: trustedHost,
SecretKey: env.String("HWSAUTH_SECRET_KEY", ""),
AccessTokenExpiry: env.Int64("HWSAUTH_ACCESS_TOKEN_EXPIRY", 5),
RefreshTokenExpiry: env.Int64("HWSAUTH_REFRESH_TOKEN_EXPIRY", 1440),
TokenFreshTime: env.Int64("HWSAUTH_TOKEN_FRESH_TIME", 5),
LandingPage: env.String("HWSAUTH_LANDING_PAGE", "/profile"),
DatabaseType: env.String("HWSAUTH_DATABASE_TYPE", jwt.DatabasePostgreSQL),
DatabaseVersion: env.String("HWSAUTH_DATABASE_VERSION", "15"),
JWTTableName: env.String("HWSAUTH_JWT_TABLE_NAME", "jwtblacklist"),
}
if cfg.SecretKey == "" {
return nil, errors.New("Envar not set: HWSAUTH_SECRET_KEY")
}
return cfg, nil
}

27
hwsauth/db.go Normal file
View File

@@ -0,0 +1,27 @@
package hwsauth
import (
"context"
"database/sql"
)
// DBTransaction represents a database transaction that can be committed or rolled back.
// This interface can be implemented by standard library sql.Tx, or by ORM transactions
// from libraries like bun, gorm, sqlx, etc.
type DBTransaction interface {
Commit() error
Rollback() error
}
// DBConnection represents a database connection that can begin transactions.
// This interface can be implemented by standard library sql.DB, or by ORM connections
// from libraries like bun, gorm, sqlx, etc.
type DBConnection interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error)
}
// Ensure *sql.Tx implements DBTransaction
var _ DBTransaction = (*sql.Tx)(nil)
// Ensure *sql.DB implements DBConnection
var _ DBConnection = (*sql.DB)(nil)

View File

@@ -4,16 +4,24 @@ go 1.25.5
require (
git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/jwt v0.9.2
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hws v0.1.0
git.haelnorr.com/h/golib/jwt v0.9.2
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.34.0
)
replace git.haelnorr.com/h/golib/hws => ../hws
require (
git.haelnorr.com/h/golib/hlog v0.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
golang.org/x/sys v0.12.0 // indirect
k8s.io/apimachinery v0.35.0 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect
)

View File

@@ -1,7 +1,9 @@
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
git.haelnorr.com/h/golib/hws v0.1.0 h1:+0eNq1uGWrGfbS5AgHeGoGDjVfCWuaVu+1wBxgPqyOY=
git.haelnorr.com/h/golib/hws v0.1.0/go.mod h1:b2pbkMaebzmck9TxqGBGzTJPEcB5TWcEHwFknLE7dqM=
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
git.haelnorr.com/h/golib/jwt v0.9.2 h1:l1Ow7DPGACAU54CnMP/NlZjdc4nRD1wr3xZ8a7taRvU=
git.haelnorr.com/h/golib/jwt v0.9.2/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
@@ -9,6 +11,8 @@ github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
@@ -34,3 +38,9 @@ golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=
k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck=
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=

View File

@@ -1,14 +1,13 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/cookies"
"github.com/pkg/errors"
)
func (auth *Authenticator[T]) Logout(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
func (auth *Authenticator[T]) Logout(tx DBTransaction, w http.ResponseWriter, r *http.Request) error {
aT, rT, err := auth.getTokens(tx, r)
if err != nil {
return errors.Wrap(err, "auth.getTokens")

View File

@@ -23,7 +23,7 @@ func (auth *Authenticator[T]) authenticate() hws.MiddlewareFunc {
// Start the transaction
tx, err := auth.conn.BeginTx(ctx, nil)
if err != nil {
return nil, hws.NewError(http.StatusServiceUnavailable, "Unable to start transaction", err)
return nil, &hws.HWSError{Message: "Unable to start transaction", StatusCode: http.StatusServiceUnavailable, Error: err}
}
model, err := auth.getAuthenticatedUser(tx, w, r)
if err != nil {

View File

@@ -2,7 +2,6 @@ package hwsauth
import (
"context"
"database/sql"
)
type authenticatedModel[T Model] struct {
@@ -21,26 +20,39 @@ type Model interface {
type ContextLoader[T Model] func(ctx context.Context) T
type LoadFunc[T Model] func(tx *sql.Tx, id int) (T, error)
type LoadFunc[T Model] func(tx DBTransaction, id int) (T, error)
// Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m *authenticatedModel[T]) context.Context {
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
}
// Retrieve a user from the given context. Returns nil if not set
func getAuthorizedModel[T Model](ctx context.Context) *authenticatedModel[T] {
model, ok := ctx.Value("hwsauth context key authenticated-model").(*authenticatedModel[T])
if !ok {
return nil
func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[T], ok bool) {
defer func() {
if r := recover(); r != nil {
// panic happened, return ok = false
ok = false
model = authenticatedModel[T]{}
}
}()
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
if !cok {
return authenticatedModel[T]{}, false
}
return model
return model, true
}
func (auth *Authenticator[T]) CurrentModel(ctx context.Context) T {
model := getAuthorizedModel[T](ctx)
if model == nil {
auth.logger.Debug().Any("context", ctx).Msg("")
if ctx == nil {
return getNil[T]()
}
model, ok := getAuthorizedModel[T](ctx)
if !ok {
result := getNil[T]()
auth.logger.Debug().Any("model", result).Msg("")
return result
}
return model.model
}

View File

@@ -3,14 +3,33 @@ package hwsauth
import (
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
)
// Checks if the model is set in the context and shows 401 page if not logged in
func (auth *Authenticator[T]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
if model == nil {
auth.errorPage(http.StatusUnauthorized, w, r)
_, ok := getAuthorizedModel[T](r.Context())
if !ok {
page, err := auth.errorPage(http.StatusUnauthorized)
if err != nil {
auth.server.ThrowError(w, r, hws.HWSError{
Error: err,
Message: "Failed to get valid error page",
StatusCode: http.StatusInternalServerError,
RenderErrorPage: true,
})
}
err = page.Render(r.Context(), w)
if err != nil {
auth.server.ThrowError(w, r, hws.HWSError{
Error: err,
Message: "Failed to render error page",
StatusCode: http.StatusInternalServerError,
RenderErrorPage: true,
})
}
return
}
next.ServeHTTP(w, r)
@@ -21,8 +40,8 @@ func (auth *Authenticator[T]) LoginReq(next http.Handler) http.Handler {
// they are logged in
func (auth *Authenticator[T]) LogoutReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
if model != nil {
_, ok := getAuthorizedModel[T](r.Context())
if ok {
http.Redirect(w, r, auth.LandingPage, http.StatusFound)
return
}
@@ -30,9 +49,33 @@ func (auth *Authenticator[T]) LogoutReq(next http.Handler) http.Handler {
})
}
// FreshReq protects a route from access if the auth token is not fresh.
// A status code of 444 will be written to the header and the request will be terminated.
// As an example, this can be used on the client to show a confirm password dialog to refresh their login
func (auth *Authenticator[T]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
model, ok := getAuthorizedModel[T](r.Context())
if !ok {
page, err := auth.errorPage(http.StatusUnauthorized)
if err != nil {
auth.server.ThrowError(w, r, hws.HWSError{
Error: err,
Message: "Failed to get valid error page",
StatusCode: http.StatusInternalServerError,
RenderErrorPage: true,
})
}
err = page.Render(r.Context(), w)
if err != nil {
auth.server.ThrowError(w, r, hws.HWSError{
Error: err,
Message: "Failed to render error page",
StatusCode: http.StatusInternalServerError,
RenderErrorPage: true,
})
}
return
}
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
if !isFresh {
w.WriteHeader(444)

View File

@@ -1,14 +1,13 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
func (auth *Authenticator[T]) RefreshAuthTokens(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
func (auth *Authenticator[T]) RefreshAuthTokens(tx DBTransaction, w http.ResponseWriter, r *http.Request) error {
aT, rT, err := auth.getTokens(tx, r)
if err != nil {
return errors.Wrap(err, "getTokens")
@@ -32,7 +31,7 @@ func (auth *Authenticator[T]) RefreshAuthTokens(tx *sql.Tx, w http.ResponseWrite
// Get the tokens from the request
func (auth *Authenticator[T]) getTokens(
tx *sql.Tx,
tx DBTransaction,
r *http.Request,
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies
@@ -50,7 +49,7 @@ func (auth *Authenticator[T]) getTokens(
// Revoke the given token pair
func revokeTokenPair(
tx *sql.Tx,
tx DBTransaction,
aT *jwt.AccessToken,
rT *jwt.RefreshToken,
) error {

View File

@@ -1,7 +1,6 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/jwt"
@@ -10,7 +9,7 @@ import (
// Attempt to use a valid refresh token to generate a new token pair
func (auth *Authenticator[T]) refreshAuthTokens(
tx *sql.Tx,
tx DBTransaction,
w http.ResponseWriter,
r *http.Request,
rT *jwt.RefreshToken,

1
jwt/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
.claude/

21
jwt/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 haelnorr
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

105
jwt/README.md Normal file
View File

@@ -0,0 +1,105 @@
# JWT Package
[![Go Reference](https://pkg.go.dev/badge/git.haelnorr.com/h/golib/jwt.svg)](https://pkg.go.dev/git.haelnorr.com/h/golib/jwt)
JWT (JSON Web Token) generation and validation with database-backed token revocation support.
## Features
- 🔐 Access and refresh token generation
- ✅ Token validation with expiration checking
- 🚫 Token revocation via database blacklist
- 🗄️ Multi-database support (PostgreSQL, MySQL, SQLite, MariaDB)
- 🔧 Compatible with database/sql, GORM, and Bun
- 🤖 Automatic table creation and management
- 🧹 Database-native automatic cleanup
- 🔄 Token freshness tracking
- 💾 "Remember me" functionality
## Installation
```bash
go get git.haelnorr.com/h/golib/jwt
```
## Quick Start
```go
package main
import (
"context"
"database/sql"
"git.haelnorr.com/h/golib/jwt"
_ "github.com/lib/pq"
)
func main() {
// Open database
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
defer db.Close()
// Create a transaction getter function
txGetter := func(ctx context.Context) (jwt.DBTransaction, error) {
return db.Begin()
}
// Create token generator
gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
AccessExpireAfter: 15, // 15 minutes
RefreshExpireAfter: 1440, // 24 hours
FreshExpireAfter: 5, // 5 minutes
TrustedHost: "example.com",
SecretKey: "your-secret-key",
DB: db,
DBType: jwt.DatabaseType{
Type: jwt.DatabasePostgreSQL,
Version: "15",
},
TableConfig: jwt.DefaultTableConfig(),
}, txGetter)
if err != nil {
panic(err)
}
// Generate tokens
accessToken, _, _ := gen.NewAccess(42, true, false)
refreshToken, _, _ := gen.NewRefresh(42, false)
// Validate token
tx, _ := db.Begin()
token, _ := gen.ValidateAccess(tx, accessToken)
// Revoke token
token.Revoke(tx)
tx.Commit()
}
```
## Documentation
Comprehensive documentation is available in the [Wiki](https://git.haelnorr.com/h/golib/wiki/JWT).
### Key Topics
- [Configuration](https://git.haelnorr.com/h/golib/wiki/JWT#configuration)
- [Token Generation](https://git.haelnorr.com/h/golib/wiki/JWT#token-generation)
- [Token Validation](https://git.haelnorr.com/h/golib/wiki/JWT#token-validation)
- [Token Revocation](https://git.haelnorr.com/h/golib/wiki/JWT#token-revocation)
- [Cleanup](https://git.haelnorr.com/h/golib/wiki/JWT#cleanup)
- [Using with ORMs](https://git.haelnorr.com/h/golib/wiki/JWT#using-with-orms)
## Supported Databases
- PostgreSQL
- MySQL
- MariaDB
- SQLite
## License
See LICENSE file in the repository root.
## Contributing
Contributions are welcome! Please open an issue or submit a pull request.

View File

@@ -6,7 +6,12 @@ import (
"time"
)
// Get the value of the access and refresh tokens
// GetTokenCookies extracts access and refresh tokens from HTTP request cookies.
// Returns empty strings for any cookies that don't exist.
//
// Returns:
// - acc: The access token value from the "access" cookie (empty if not found)
// - ref: The refresh token value from the "refresh" cookie (empty if not found)
func GetTokenCookies(
r *http.Request,
) (acc string, ref string) {
@@ -25,7 +30,16 @@ func GetTokenCookies(
return accStr, refStr
}
// Set a token with the provided details
// setToken is an internal helper that sets a token cookie with the specified parameters.
// The cookie is HttpOnly for security and uses SameSite=Lax mode.
//
// Parameters:
// - w: HTTP response writer to set the cookie on
// - token: The token value to store in the cookie
// - scope: The cookie name ("access" or "refresh")
// - exp: Unix timestamp when the token expires
// - rememberme: If true, sets cookie expiration; if false, cookie is session-only
// - useSSL: If true, marks cookie as Secure (HTTPS only)
func setToken(
w http.ResponseWriter,
token string,
@@ -48,7 +62,21 @@ func setToken(
http.SetCookie(w, tokenCookie)
}
// Generate new tokens for the subject and set them as cookies
// SetTokenCookies generates new access and refresh tokens for a user and sets them as HTTP cookies.
// This is a convenience function that combines token generation with cookie setting.
// Cookies are HttpOnly and use SameSite=Lax for security.
//
// Parameters:
// - w: HTTP response writer to set cookies on
// - r: HTTP request (unused but kept for API consistency)
// - tokenGen: The TokenGenerator to use for creating tokens
// - subject: The user ID to generate tokens for
// - fresh: If true, marks the access token as fresh for sensitive operations
// - rememberMe: If true, tokens persist beyond browser session
// - useSSL: If true, marks cookies as Secure (HTTPS only)
//
// Returns an error if token generation fails. Cookies are only set if both tokens
// are generated successfully.
func SetTokenCookies(
w http.ResponseWriter,
r *http.Request,

66
jwt/database.go Normal file
View File

@@ -0,0 +1,66 @@
package jwt
import (
"context"
"database/sql"
)
// DBTransaction represents a database transaction that can execute queries.
// This interface is compatible with *sql.Tx and can be implemented by ORM transactions
// from libraries like GORM (gormDB.Begin()), Bun (bunDB.Begin()), etc.
type DBTransaction interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Commit() error
Rollback() error
}
// BeginTX represents a wrapper function that is used to start a transaction with any dependencies injected
type BeginTX func(ctx context.Context) (DBTransaction, error)
// DatabaseType specifies the database system and version being used.
type DatabaseType struct {
Type string // Database type: "postgres", "mysql", "sqlite", "mariadb"
Version string // Version string, e.g., "15.3", "8.0.32", "3.42.0"
}
// Predefined database type constants for easy configuration and validation.
const (
DatabasePostgreSQL = "postgres"
DatabaseMySQL = "mysql"
DatabaseSQLite = "sqlite"
DatabaseMariaDB = "mariadb"
)
// TableConfig configures the JWT blacklist table.
type TableConfig struct {
// TableName is the name of the blacklist table.
// Default: "jwtblacklist"
TableName string
// AutoCreate determines whether to automatically create the table if it doesn't exist.
// Default: true
AutoCreate bool
// EnableAutoCleanup configures database-native automatic cleanup of expired tokens.
// For PostgreSQL: Creates a cleanup function (requires external scheduler or pg_cron)
// For MySQL/MariaDB: Creates a database event
// For SQLite: No automatic cleanup (manual only)
// Default: true
EnableAutoCleanup bool
// CleanupInterval specifies how often automatic cleanup should run (in hours).
// Only used if EnableAutoCleanup is true.
// Default: 24 (daily cleanup)
CleanupInterval int
}
// DefaultTableConfig returns a TableConfig with sensible defaults.
func DefaultTableConfig() TableConfig {
return TableConfig{
TableName: "jwtblacklist",
AutoCreate: true,
EnableAutoCleanup: true,
CleanupInterval: 24,
}
}

150
jwt/doc.go Normal file
View File

@@ -0,0 +1,150 @@
// Package jwt provides JWT (JSON Web Token) generation and validation with token revocation support.
//
// This package implements JWT access and refresh tokens with the ability to revoke tokens
// using a database-backed blacklist. It supports multiple database backends including
// PostgreSQL, MySQL, SQLite, and MariaDB, and works with both standard library database/sql
// and popular ORMs like GORM and Bun.
//
// # Features
//
// - Access and refresh token generation
// - Token validation with expiration checking
// - Token revocation via database blacklist
// - Support for multiple database types (PostgreSQL, MySQL, SQLite, MariaDB)
// - Compatible with database/sql, GORM, and Bun ORMs
// - Automatic table creation and management
// - Database-native automatic cleanup (PostgreSQL functions, MySQL events)
// - Manual cleanup method for on-demand token cleanup
// - Token freshness tracking for sensitive operations
// - "Remember me" functionality with session vs persistent tokens
//
// # Basic Usage
//
// Create a token generator with database support:
//
// db, _ := sql.Open("postgres", "connection_string")
// gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
// AccessExpireAfter: 15, // 15 minutes
// RefreshExpireAfter: 1440, // 24 hours
// FreshExpireAfter: 5, // 5 minutes
// TrustedHost: "example.com",
// SecretKey: "your-secret-key",
// DB: db,
// DBType: jwt.DatabaseType{Type: jwt.DatabasePostgreSQL, Version: "15"},
// TableConfig: jwt.DefaultTableConfig(),
// })
//
// Generate tokens:
//
// accessToken, accessExp, err := gen.NewAccess(userID, true, false)
// refreshToken, refreshExp, err := gen.NewRefresh(userID, false)
//
// Validate tokens (using standard library):
//
// tx, _ := db.Begin()
// token, err := gen.ValidateAccess(tx, accessToken)
// if err != nil {
// // Token is invalid or revoked
// }
// tx.Commit()
//
// Validate tokens (using ORM like GORM):
//
// tx := gormDB.Begin()
// token, err := gen.ValidateAccess(tx.Statement.ConnPool, accessToken)
// // or with Bun: gen.ValidateAccess(bunDB.BeginTx(ctx, nil), accessToken)
// tx.Commit()
//
// Revoke tokens:
//
// tx, _ := db.Begin()
// err := token.Revoke(tx)
// tx.Commit()
//
// # Database Configuration
//
// The package automatically creates a blacklist table with the following schema:
//
// CREATE TABLE jwtblacklist (
// jti UUID PRIMARY KEY, -- Token unique identifier
// exp BIGINT NOT NULL, -- Expiration timestamp
// sub INT NOT NULL, -- Subject (user) ID
// created_at TIMESTAMP -- When token was blacklisted
// );
//
// # Cleanup
//
// For PostgreSQL, the package creates a cleanup function that can be called manually
// or scheduled with pg_cron:
//
// SELECT cleanup_jwtblacklist();
//
// For MySQL/MariaDB, the package creates a database event that runs automatically
// (requires event_scheduler to be enabled).
//
// Manual cleanup can be performed at any time:
//
// err := gen.Cleanup(context.Background())
//
// # Using with ORMs
//
// The package works with popular ORMs by using raw SQL queries. For GORM and Bun,
// wrap the underlying *sql.DB with NewDBConnection() when creating the generator:
//
// // GORM example - can use GORM transactions directly
// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{})
// sqlDB, _ := gormDB.DB()
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ...
// DB: sqlDB,
// })
// // Use GORM transaction
// tx := gormDB.Begin()
// token, _ := gen.ValidateAccess(tx.Statement.ConnPool, tokenString)
// tx.Commit()
//
// // Bun example - can use Bun transactions directly
// sqlDB, _ := sql.Open("postgres", dsn)
// bunDB := bun.NewDB(sqlDB, pgdialect.New())
// gen, _ := jwt.CreateGenerator(jwt.GeneratorConfig{
// // ... config ...
// DB: sqlDB,
// })
// // Use Bun transaction
// tx, _ := bunDB.BeginTx(context.Background(), nil)
// token, _ := gen.ValidateAccess(tx, tokenString)
// tx.Commit()
//
// # Token Freshness
//
// Tokens can be marked as "fresh" for sensitive operations. Fresh tokens are typically
// required for actions like changing passwords or email addresses:
//
// token, err := gen.ValidateAccess(exec, tokenString)
// if time.Now().Unix() > token.Fresh {
// // Token is not fresh, require re-authentication
// }
//
// # Custom Table Names
//
// You can customize the blacklist table name:
//
// config := jwt.DefaultTableConfig()
// config.TableName = "my_token_blacklist"
//
// # Disabling Database Features
//
// To use JWT without revocation support (no database):
//
// gen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
// AccessExpireAfter: 15,
// RefreshExpireAfter: 1440,
// FreshExpireAfter: 5,
// TrustedHost: "example.com",
// SecretKey: "your-secret-key",
// DB: nil, // No database
// })
//
// When DB is nil, revocation features are disabled and token validation
// will not check the blacklist.
package jwt

View File

@@ -1,62 +1,135 @@
package jwt
import (
"context"
"database/sql"
"errors"
"time"
pkgerrors "github.com/pkg/errors"
)
type TokenGenerator struct {
accessExpireAfter int64 // Access Token expiry time in minutes
refreshExpireAfter int64 // Refresh Token expiry time in minutes
freshExpireAfter int64 // Token freshness expiry time in minutes
trustedHost string // Trusted hostname to use for the tokens
secretKey string // Secret key to use for token hashing
dbConn *sql.DB // Database handle for token blacklisting
accessExpireAfter int64 // Access Token expiry time in minutes
refreshExpireAfter int64 // Refresh Token expiry time in minutes
freshExpireAfter int64 // Token freshness expiry time in minutes
trustedHost string // Trusted hostname to use for the tokens
secretKey string // Secret key to use for token hashing
beginTx BeginTX // Database transaction getter for token blacklisting
tableConfig TableConfig // Table configuration
tableManager *TableManager // Table lifecycle manager
}
// GeneratorConfig holds configuration for creating a TokenGenerator.
type GeneratorConfig struct {
// AccessExpireAfter is the access token expiry time in minutes.
AccessExpireAfter int64
// RefreshExpireAfter is the refresh token expiry time in minutes.
RefreshExpireAfter int64
// FreshExpireAfter is the token freshness expiry time in minutes.
FreshExpireAfter int64
// TrustedHost is the trusted hostname to use for the tokens.
TrustedHost string
// SecretKey is the secret key to use for token hashing.
SecretKey string
// DB is the database connection. Can be nil to disable token revocation.
// When using ORMs like GORM or Bun, pass the underlying *sql.DB.
DB *sql.DB
// DBType specifies the database type and version for proper table management.
// Only required if DB is not nil.
DBType DatabaseType
// TableConfig configures the blacklist table name and behavior.
// Only required if DB is not nil.
TableConfig TableConfig
}
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
// All expiry times should be provided in minutes.
// trustedHost and secretKey strings must be provided.
// dbConn can be nil, but doing this will disable token revocation
func CreateGenerator(
accessExpireAfter int64,
refreshExpireAfter int64,
freshExpireAfter int64,
trustedHost string,
secretKey string,
dbConn *sql.DB,
) (gen *TokenGenerator, err error) {
if accessExpireAfter <= 0 {
func CreateGenerator(config GeneratorConfig, txGetter BeginTX) (gen *TokenGenerator, err error) {
if config.AccessExpireAfter <= 0 {
return nil, errors.New("accessExpireAfter must be greater than 0")
}
if refreshExpireAfter <= 0 {
if config.RefreshExpireAfter <= 0 {
return nil, errors.New("refreshExpireAfter must be greater than 0")
}
if freshExpireAfter <= 0 {
if config.FreshExpireAfter <= 0 {
return nil, errors.New("freshExpireAfter must be greater than 0")
}
if trustedHost == "" {
if config.TrustedHost == "" {
return nil, errors.New("trustedHost cannot be an empty string")
}
if secretKey == "" {
if config.SecretKey == "" {
return nil, errors.New("secretKey cannot be an empty string")
}
if dbConn != nil {
err := dbConn.Ping()
if err != nil {
return nil, errors.New("Failed to ping database")
var tableManager *TableManager
if config.DB != nil {
// Create table manager
tableManager = NewTableManager(config.DB, config.DBType, config.TableConfig)
// Create table if AutoCreate is enabled
if config.TableConfig.AutoCreate {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = tableManager.CreateTable(ctx)
if err != nil {
return nil, pkgerrors.Wrap(err, "failed to create blacklist table")
}
}
// Setup automatic cleanup if enabled
if config.TableConfig.EnableAutoCleanup {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = tableManager.SetupAutoCleanup(ctx)
if err != nil {
return nil, pkgerrors.Wrap(err, "failed to setup automatic cleanup")
}
}
// TODO: check if jwtblacklist table exists
// TODO: create jwtblacklist table if not existing
}
return &TokenGenerator{
accessExpireAfter: accessExpireAfter,
refreshExpireAfter: refreshExpireAfter,
freshExpireAfter: freshExpireAfter,
trustedHost: trustedHost,
secretKey: secretKey,
dbConn: dbConn,
accessExpireAfter: config.AccessExpireAfter,
refreshExpireAfter: config.RefreshExpireAfter,
freshExpireAfter: config.FreshExpireAfter,
trustedHost: config.TrustedHost,
secretKey: config.SecretKey,
beginTx: txGetter,
tableConfig: config.TableConfig,
tableManager: tableManager,
}, nil
}
// Cleanup manually removes expired tokens from the blacklist table.
// This method should be called periodically if automatic cleanup is not enabled,
// or can be called on-demand regardless of automatic cleanup settings.
func (gen *TokenGenerator) Cleanup(ctx context.Context) error {
if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function")
}
tx, err := gen.beginTx(ctx)
if err != nil {
return pkgerrors.Wrap(err, "failed to begin transaction")
}
tableName := gen.tableConfig.TableName
currentTime := time.Now().Unix()
query := "DELETE FROM " + tableName + " WHERE exp < ?"
_, err = tx.Exec(query, currentTime)
if err != nil {
return pkgerrors.Wrap(err, "failed to cleanup expired tokens")
}
return nil
}

View File

@@ -1,6 +1,7 @@
package jwt
import (
"context"
"testing"
"github.com/DATA-DOG/go-sqlmock"
@@ -8,14 +9,16 @@ import (
)
func TestCreateGenerator_Success_NoDB(t *testing.T) {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"secret",
nil,
)
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "secret",
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
}, nil)
require.NoError(t, err)
require.NotNil(t, gen)
@@ -26,14 +29,62 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
require.NoError(t, err)
defer db.Close()
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"secret",
db,
)
config := DefaultTableConfig()
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "secret",
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
}, txGetter)
require.NoError(t, err)
require.NotNil(t, gen)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCreateGenerator_WithDB_AutoCreate(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Mock table doesn't exist
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}))
// Mock CREATE TABLE
mock.ExpectExec("CREATE TABLE IF NOT EXISTS jwtblacklist").
WillReturnResult(sqlmock.NewResult(0, 0))
// Mock cleanup function creation
mock.ExpectExec("CREATE OR REPLACE FUNCTION cleanup_jwtblacklist").
WillReturnResult(sqlmock.NewResult(0, 0))
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "secret",
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
}, txGetter)
require.NoError(t, err)
require.NotNil(t, gen)
@@ -42,49 +93,118 @@ func TestCreateGenerator_Success_WithDB(t *testing.T) {
func TestCreateGenerator_InvalidInputs(t *testing.T) {
tests := []struct {
name string
fn func() error
name string
config GeneratorConfig
}{
{
"access expiry <= 0",
func() error {
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
return err
GeneratorConfig{
AccessExpireAfter: 0,
RefreshExpireAfter: 1,
FreshExpireAfter: 1,
TrustedHost: "h",
SecretKey: "s",
},
},
{
"refresh expiry <= 0",
func() error {
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
return err
GeneratorConfig{
AccessExpireAfter: 1,
RefreshExpireAfter: 0,
FreshExpireAfter: 1,
TrustedHost: "h",
SecretKey: "s",
},
},
{
"fresh expiry <= 0",
func() error {
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
return err
GeneratorConfig{
AccessExpireAfter: 1,
RefreshExpireAfter: 1,
FreshExpireAfter: 0,
TrustedHost: "h",
SecretKey: "s",
},
},
{
"empty trustedHost",
func() error {
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
return err
GeneratorConfig{
AccessExpireAfter: 1,
RefreshExpireAfter: 1,
FreshExpireAfter: 1,
TrustedHost: "",
SecretKey: "s",
},
},
{
"empty secretKey",
func() error {
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
return err
GeneratorConfig{
AccessExpireAfter: 1,
RefreshExpireAfter: 1,
FreshExpireAfter: 1,
TrustedHost: "h",
SecretKey: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Error(t, tt.fn())
_, err := CreateGenerator(tt.config, nil)
require.Error(t, err)
})
}
}
func TestCleanup_NoDB(t *testing.T) {
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "secret",
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
}, nil)
require.NoError(t, err)
err = gen.Cleanup(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "No DB provided")
}
func TestCleanup_Success(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
config := DefaultTableConfig()
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "secret",
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
}, txGetter)
require.NoError(t, err)
// Mock transaction begin and DELETE query
mock.ExpectBegin()
mock.ExpectExec("DELETE FROM jwtblacklist WHERE exp").
WillReturnResult(sqlmock.NewResult(0, 5))
err = gen.Cleanup(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}

View File

@@ -1,38 +1,54 @@
package jwt
import (
"database/sql"
"fmt"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
if gen.dbConn == nil {
// revoke is an internal method that adds a token to the blacklist database.
// Once revoked, the token will fail validation checks even if it hasn't expired.
// This operation must be performed within a database transaction.
func (gen *TokenGenerator) revoke(tx DBTransaction, t Token) error {
if gen.beginTx == nil {
return errors.New("No DB provided, unable to use this function")
}
tableName := gen.tableConfig.TableName
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := tx.Exec(query, jti, exp)
sub := t.GetSUB()
query := fmt.Sprintf("INSERT INTO %s (jti, exp, sub) VALUES (?, ?, ?)", tableName)
_, err := tx.Exec(query, jti.String(), exp, sub)
if err != nil {
return errors.Wrap(err, "tx.Exec")
return errors.Wrap(err, "tx.ExecContext")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
if gen.dbConn == nil {
// checkNotRevoked is an internal method that queries the blacklist to verify
// a token hasn't been revoked. Returns true if the token is valid (not blacklisted),
// false if it has been revoked. This operation must be performed within a database transaction.
func (gen *TokenGenerator) checkNotRevoked(tx DBTransaction, t Token) (bool, error) {
if gen.beginTx == nil {
return false, errors.New("No DB provided, unable to use this function")
}
tableName := gen.tableConfig.TableName
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(query, jti)
query := fmt.Sprintf("SELECT 1 FROM %s WHERE jti = ? LIMIT 1", tableName)
rows, err := tx.Query(query, jti.String())
if err != nil {
return false, errors.Wrap(err, "tx.Query")
return false, errors.Wrap(err, "tx.QueryContext")
}
defer rows.Close()
revoked := rows.Next()
return !revoked, nil
exists := rows.Next()
if err := rows.Err(); err != nil {
return false, errors.Wrap(err, "rows iteration")
}
return !exists, nil
}

View File

@@ -12,19 +12,48 @@ import (
)
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
nil,
)
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "supersecret",
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
}, nil)
require.NoError(t, err)
return gen
}
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, *sql.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
config := DefaultTableConfig()
config.AutoCreate = false
config.EnableAutoCleanup = false
txGetter := func(ctx context.Context) (DBTransaction, error) {
return db.Begin()
}
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "supersecret",
DB: db,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: config,
}, txGetter)
require.NoError(t, err)
return gen, db, mock, func() { db.Close() }
}
func TestNoDBFail(t *testing.T) {
jti := uuid.New()
exp := time.Now().Add(time.Hour).Unix()
@@ -32,42 +61,48 @@ func TestNoDBFail(t *testing.T) {
token := AccessToken{
JTI: jti,
EXP: exp,
SUB: 42,
gen: &TokenGenerator{},
}
// Create a nil transaction (can't revoke without DB)
var tx *sql.Tx = nil
// Revoke should fail due to no DB
err := token.Revoke(&sql.Tx{})
err := token.Revoke(tx)
require.Error(t, err)
// CheckNotRevoked should fail
_, err = token.CheckNotRevoked(&sql.Tx{})
_, err = token.CheckNotRevoked(tx)
require.Error(t, err)
}
func TestRevokeAndCheckNotRevoked(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
jti := uuid.New()
exp := time.Now().Add(time.Hour).Unix()
sub := 42
token := AccessToken{
JTI: jti,
EXP: exp,
SUB: sub,
gen: gen,
}
// Revoke expectations
mock.ExpectBegin()
mock.ExpectExec(`INSERT INTO jwtblacklist`).
WithArgs(jti, exp).
WithArgs(jti.String(), exp, sub).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti).
WithArgs(jti.String()).
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit()
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
tx, err := db.Begin()
defer tx.Rollback()
require.NoError(t, err)

212
jwt/tablemanager.go Normal file
View File

@@ -0,0 +1,212 @@
package jwt
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
)
// TableManager handles table creation, existence checks, and cleanup configuration.
type TableManager struct {
dbType DatabaseType
tableConfig TableConfig
db *sql.DB
}
// NewTableManager creates a new TableManager instance.
func NewTableManager(db *sql.DB, dbType DatabaseType, config TableConfig) *TableManager {
return &TableManager{
dbType: dbType,
tableConfig: config,
db: db,
}
}
// CreateTable creates the blacklist table if it doesn't exist.
func (tm *TableManager) CreateTable(ctx context.Context) error {
exists, err := tm.tableExists(ctx)
if err != nil {
return errors.Wrap(err, "failed to check if table exists")
}
if exists {
return nil // Table already exists
}
createSQL, err := tm.getCreateTableSQL()
if err != nil {
return err
}
_, err = tm.db.ExecContext(ctx, createSQL)
if err != nil {
return errors.Wrapf(err, "failed to create table %s", tm.tableConfig.TableName)
}
return nil
}
// tableExists checks if the blacklist table exists in the database.
func (tm *TableManager) tableExists(ctx context.Context) (bool, error) {
tableName := tm.tableConfig.TableName
var query string
var args []interface{}
switch tm.dbType.Type {
case DatabasePostgreSQL:
query = `
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = $1
`
args = []interface{}{tableName}
case DatabaseMySQL, DatabaseMariaDB:
query = `
SELECT 1 FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = ?
`
args = []interface{}{tableName}
case DatabaseSQLite:
query = `
SELECT 1 FROM sqlite_master
WHERE type = 'table'
AND name = ?
`
args = []interface{}{tableName}
default:
return false, errors.Errorf("unsupported database type: %s", tm.dbType.Type)
}
rows, err := tm.db.QueryContext(ctx, query, args...)
if err != nil {
return false, errors.Wrap(err, "failed to check table existence")
}
defer rows.Close()
return rows.Next(), nil
}
// getCreateTableSQL returns the CREATE TABLE statement for the given database type.
func (tm *TableManager) getCreateTableSQL() (string, error) {
tableName := tm.tableConfig.TableName
switch tm.dbType.Type {
case DatabasePostgreSQL:
return fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
jti UUID PRIMARY KEY,
exp BIGINT NOT NULL,
sub INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
`, tableName, tableName, tableName, tableName, tableName), nil
case DatabaseMySQL, DatabaseMariaDB:
return fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
jti CHAR(36) PRIMARY KEY,
exp BIGINT NOT NULL,
sub INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_exp (exp),
INDEX idx_sub (sub)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
`, tableName), nil
case DatabaseSQLite:
return fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
jti TEXT PRIMARY KEY,
exp INTEGER NOT NULL,
sub INTEGER NOT NULL,
created_at INTEGER DEFAULT (strftime('%%s', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_%s_exp ON %s(exp);
CREATE INDEX IF NOT EXISTS idx_%s_sub ON %s(sub);
`, tableName, tableName, tableName, tableName, tableName), nil
default:
return "", errors.Errorf("unsupported database type: %s", tm.dbType.Type)
}
}
// SetupAutoCleanup configures database-native automatic cleanup of expired tokens.
func (tm *TableManager) SetupAutoCleanup(ctx context.Context) error {
if !tm.tableConfig.EnableAutoCleanup {
return nil
}
switch tm.dbType.Type {
case DatabasePostgreSQL:
return tm.setupPostgreSQLCleanup(ctx)
case DatabaseMySQL, DatabaseMariaDB:
return tm.setupMySQLCleanup(ctx)
case DatabaseSQLite:
// SQLite doesn't support automatic cleanup
return nil
default:
return errors.Errorf("unsupported database type: %s", tm.dbType.Type)
}
}
// setupPostgreSQLCleanup creates a cleanup function for PostgreSQL.
// Note: This creates a function but does not schedule it. You need to use pg_cron
// or an external scheduler to call this function periodically.
func (tm *TableManager) setupPostgreSQLCleanup(ctx context.Context) error {
tableName := tm.tableConfig.TableName
functionName := fmt.Sprintf("cleanup_%s", tableName)
createFunctionSQL := fmt.Sprintf(`
CREATE OR REPLACE FUNCTION %s()
RETURNS void AS $$
BEGIN
DELETE FROM %s WHERE exp < EXTRACT(EPOCH FROM NOW());
END;
$$ LANGUAGE plpgsql;
`, functionName, tableName)
_, err := tm.db.ExecContext(ctx, createFunctionSQL)
if err != nil {
return errors.Wrap(err, "failed to create cleanup function")
}
// Note: Actual scheduling requires pg_cron extension or external tools
// Users should call this function periodically using:
// SELECT cleanup_jwtblacklist();
return nil
}
// setupMySQLCleanup creates a MySQL event for automatic cleanup.
// Note: Requires event_scheduler to be enabled in MySQL/MariaDB configuration.
func (tm *TableManager) setupMySQLCleanup(ctx context.Context) error {
tableName := tm.tableConfig.TableName
eventName := fmt.Sprintf("cleanup_%s_event", tableName)
interval := tm.tableConfig.CleanupInterval
// Drop existing event if it exists
dropEventSQL := fmt.Sprintf("DROP EVENT IF EXISTS %s", eventName)
_, err := tm.db.ExecContext(ctx, dropEventSQL)
if err != nil {
return errors.Wrap(err, "failed to drop existing event")
}
// Create new event
createEventSQL := fmt.Sprintf(`
CREATE EVENT %s
ON SCHEDULE EVERY %d HOUR
DO
DELETE FROM %s WHERE exp < UNIX_TIMESTAMP()
`, eventName, interval, tableName)
_, err = tm.db.ExecContext(ctx, createEventSQL)
if err != nil {
return errors.Wrapf(err, "failed to create cleanup event (ensure event_scheduler is enabled)")
}
return nil
}

221
jwt/tablemanager_test.go Normal file
View File

@@ -0,0 +1,221 @@
package jwt
import (
"context"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestNewTableManager(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
require.NotNil(t, tm)
}
func TestGetCreateTableSQL_PostgreSQL(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
sql, err := tm.getCreateTableSQL()
require.NoError(t, err)
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
require.Contains(t, sql, "jti UUID PRIMARY KEY")
require.Contains(t, sql, "exp BIGINT NOT NULL")
require.Contains(t, sql, "sub INTEGER NOT NULL")
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_jwtblacklist_exp")
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_jwtblacklist_sub")
}
func TestGetCreateTableSQL_MySQL(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabaseMySQL, Version: "8.0"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
sql, err := tm.getCreateTableSQL()
require.NoError(t, err)
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
require.Contains(t, sql, "jti CHAR(36) PRIMARY KEY")
require.Contains(t, sql, "exp BIGINT NOT NULL")
require.Contains(t, sql, "sub INT NOT NULL")
require.Contains(t, sql, "INDEX idx_exp")
require.Contains(t, sql, "ENGINE=InnoDB")
}
func TestGetCreateTableSQL_SQLite(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabaseSQLite, Version: "3.42"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
sql, err := tm.getCreateTableSQL()
require.NoError(t, err)
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS jwtblacklist")
require.Contains(t, sql, "jti TEXT PRIMARY KEY")
require.Contains(t, sql, "exp INTEGER NOT NULL")
require.Contains(t, sql, "sub INTEGER NOT NULL")
}
func TestGetCreateTableSQL_CustomTableName(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := TableConfig{
TableName: "custom_blacklist",
AutoCreate: true,
EnableAutoCleanup: false,
CleanupInterval: 24,
}
tm := NewTableManager(db, dbType, config)
sql, err := tm.getCreateTableSQL()
require.NoError(t, err)
require.Contains(t, sql, "CREATE TABLE IF NOT EXISTS custom_blacklist")
require.Contains(t, sql, "CREATE INDEX IF NOT EXISTS idx_custom_blacklist_exp")
}
func TestGetCreateTableSQL_UnsupportedDB(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: "unsupported", Version: "1.0"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
sql, err := tm.getCreateTableSQL()
require.Error(t, err)
require.Empty(t, sql)
require.Contains(t, err.Error(), "unsupported database type")
}
func TestTableExists_PostgreSQL(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
// Test table exists
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
exists, err := tm.tableExists(context.Background())
require.NoError(t, err)
require.True(t, exists)
// Test table doesn't exist
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}))
exists, err = tm.tableExists(context.Background())
require.NoError(t, err)
require.False(t, exists)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCreateTable_AlreadyExists(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
// Mock table exists check
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
err = tm.CreateTable(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestCreateTable_Success(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
// Mock table doesn't exist
mock.ExpectQuery("SELECT 1 FROM information_schema.tables").
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}))
// Mock CREATE TABLE
mock.ExpectExec("CREATE TABLE IF NOT EXISTS jwtblacklist").
WillReturnResult(sqlmock.NewResult(0, 0))
err = tm.CreateTable(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestSetupAutoCleanup_Disabled(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabasePostgreSQL, Version: "15"}
config := TableConfig{
TableName: "jwtblacklist",
AutoCreate: true,
EnableAutoCleanup: false,
CleanupInterval: 24,
}
tm := NewTableManager(db, dbType, config)
err = tm.SetupAutoCleanup(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestSetupAutoCleanup_SQLite(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
dbType := DatabaseType{Type: DatabaseSQLite, Version: "3.42"}
config := DefaultTableConfig()
tm := NewTableManager(db, dbType, config)
// SQLite doesn't support auto-cleanup, should return nil
err = tm.SetupAutoCleanup(context.Background())
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}

View File

@@ -8,7 +8,21 @@ import (
"github.com/pkg/errors"
)
// Generates an access token for the provided subject
// NewAccess generates a new JWT access token for the specified subject (user).
//
// Parameters:
// - subjectID: The user ID or subject identifier to associate with the token
// - fresh: If true, marks the token as "fresh" for sensitive operations.
// Fresh tokens are typically required for actions like changing passwords
// or email addresses. The token remains fresh until FreshExpireAfter minutes.
// - rememberMe: If true, the token is persistent (TTL="exp") and will be stored
// with an expiration date. If false, it's session-only (TTL="session") and
// expires when the browser closes.
//
// Returns:
// - tokenString: The signed JWT token string
// - expiresIn: Unix timestamp when the token expires
// - err: Any error encountered during token generation
func (gen *TokenGenerator) NewAccess(
subjectID int,
fresh bool,
@@ -47,7 +61,19 @@ func (gen *TokenGenerator) NewAccess(
return signedToken, expiresAt, nil
}
// Generates a refresh token for the provided user
// NewRefresh generates a new JWT refresh token for the specified subject (user).
// Refresh tokens are used to obtain new access tokens without re-authentication.
//
// Parameters:
// - subjectID: The user ID or subject identifier to associate with the token
// - rememberMe: If true, the token is persistent (TTL="exp") and will be stored
// with an expiration date. If false, it's session-only (TTL="session") and
// expires when the browser closes.
//
// Returns:
// - tokenStr: The signed JWT token string
// - exp: Unix timestamp when the token expires
// - err: Any error encountered during token generation
func (gen *TokenGenerator) NewRefresh(
subjectID int,
rememberMe bool,

View File

@@ -7,14 +7,16 @@ import (
)
func newTestGenerator(t *testing.T) *TokenGenerator {
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
nil,
)
gen, err := CreateGenerator(GeneratorConfig{
AccessExpireAfter: 15,
RefreshExpireAfter: 60,
FreshExpireAfter: 5,
TrustedHost: "example.com",
SecretKey: "supersecret",
DB: nil,
DBType: DatabaseType{Type: DatabasePostgreSQL, Version: "15"},
TableConfig: DefaultTableConfig(),
}, nil)
require.NoError(t, err)
return gen
}

View File

@@ -1,20 +1,39 @@
package jwt
import (
"database/sql"
"github.com/google/uuid"
)
// Token is the common interface implemented by both AccessToken and RefreshToken.
// It provides methods to access token claims and manage token revocation.
type Token interface {
// GetJTI returns the unique token identifier (JTI claim)
GetJTI() uuid.UUID
// GetEXP returns the expiration timestamp (EXP claim)
GetEXP() int64
// GetSUB returns the subject/user ID (SUB claim)
GetSUB() int
// GetScope returns the token scope ("access" or "refresh")
GetScope() string
Revoke(*sql.Tx) error
CheckNotRevoked(*sql.Tx) (bool, error)
// Revoke adds this token to the blacklist, preventing future use.
// Must be called within a database transaction context.
// Accepts any transaction type that implements DBTransaction interface.
Revoke(DBTransaction) error
// CheckNotRevoked verifies that this token has not been blacklisted.
// Returns true if the token is valid, false if revoked.
// Must be called within a database transaction context.
// Accepts any transaction type that implements DBTransaction interface.
CheckNotRevoked(DBTransaction) (bool, error)
}
// Access token
// AccessToken represents a JWT access token with all its claims.
// Access tokens are short-lived and used for authenticating API requests.
// They can be marked as "fresh" for sensitive operations like password changes.
type AccessToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
@@ -27,7 +46,9 @@ type AccessToken struct {
gen *TokenGenerator
}
// Refresh token
// RefreshToken represents a JWT refresh token with all its claims.
// Refresh tokens are longer-lived and used to obtain new access tokens
// without requiring the user to re-authenticate.
type RefreshToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
@@ -51,21 +72,27 @@ func (a AccessToken) GetEXP() int64 {
func (r RefreshToken) GetEXP() int64 {
return r.EXP
}
func (a AccessToken) GetSUB() int {
return a.SUB
}
func (r RefreshToken) GetSUB() int {
return r.SUB
}
func (a AccessToken) GetScope() string {
return a.Scope
}
func (r RefreshToken) GetScope() string {
return r.Scope
}
func (a AccessToken) Revoke(tx *sql.Tx) error {
func (a AccessToken) Revoke(tx DBTransaction) error {
return a.gen.revoke(tx, a)
}
func (r RefreshToken) Revoke(tx *sql.Tx) error {
func (r RefreshToken) Revoke(tx DBTransaction) error {
return r.gen.revoke(tx, r)
}
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
func (a AccessToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return a.gen.checkNotRevoked(tx, a)
}
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
func (r RefreshToken) CheckNotRevoked(tx DBTransaction) (bool, error) {
return r.gen.checkNotRevoked(tx, r)
}

View File

@@ -1,16 +1,32 @@
package jwt
import (
"database/sql"
"github.com/pkg/errors"
)
// Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
// ValidateAccess parses and validates a JWT access token string.
//
// This method performs comprehensive validation including:
// - Signature verification using the secret key
// - Expiration time checking (token must not be expired)
// - Issuer verification (must match trusted host)
// - Scope verification (must be "access" token)
// - Revocation status check (if database is configured)
//
// The validation must be performed within a database transaction context to ensure
// consistency when checking the blacklist. If no database is configured, the
// revocation check is skipped.
//
// Parameters:
// - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate
//
// Returns:
// - *AccessToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateAccess(
tx *sql.Tx,
tx DBTransaction,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
@@ -69,20 +85,38 @@ func (gen *TokenGenerator) ValidateAccess(
}
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil {
if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.dbConn != nil {
if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil
}
// Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
// ValidateRefresh parses and validates a JWT refresh token string.
//
// This method performs comprehensive validation including:
// - Signature verification using the secret key
// - Expiration time checking (token must not be expired)
// - Issuer verification (must match trusted host)
// - Scope verification (must be "refresh" token)
// - Revocation status check (if database is configured)
//
// The validation must be performed within a database transaction context to ensure
// consistency when checking the blacklist. If no database is configured, the
// revocation check is skipped.
//
// Parameters:
// - tx: Database transaction for checking token revocation status.
// Accepts *sql.Tx or any ORM transaction implementing DBTransaction interface.
// - tokenString: The JWT token string to validate
//
// Returns:
// - *RefreshToken: The validated token with all claims, or nil if validation fails
// - error: Detailed error if validation fails (expired, revoked, invalid signature, etc.)
func (gen *TokenGenerator) ValidateRefresh(
tx *sql.Tx,
tx DBTransaction,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
@@ -136,10 +170,10 @@ func (gen *TokenGenerator) ValidateRefresh(
}
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil {
if err != nil && gen.beginTx != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
if !valid && gen.dbConn != nil {
if !valid && gen.beginTx != nil {
return nil, errors.New("Token has been revoked")
}
return token, nil

View File

@@ -1,7 +1,6 @@
package jwt
import (
"context"
"database/sql"
"testing"
@@ -9,23 +8,6 @@ import (
"github.com/stretchr/testify/require"
)
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
gen, err := CreateGenerator(
15,
60,
5,
"example.com",
"supersecret",
db,
)
require.NoError(t, err)
return gen, mock, func() { db.Close() }
}
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
@@ -35,7 +17,7 @@ func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
}
func TestValidateAccess_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewAccess(42, true, false)
@@ -44,7 +26,7 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
@@ -61,14 +43,17 @@ func TestValidateAccess_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err)
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
// Use nil transaction for no-db case
var tx *sql.Tx = nil
token, err := gen.ValidateAccess(tx, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope)
}
func TestValidateRefresh_Success(t *testing.T) {
gen, mock, cleanup := newGeneratorWithMockDB(t)
gen, db, mock, cleanup := newGeneratorWithMockDB(t)
defer cleanup()
tokenStr, _, err := gen.NewRefresh(42, false)
@@ -76,7 +61,7 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg())
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
@@ -93,7 +78,10 @@ func TestValidateRefresh_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err)
token, err := gen.ValidateRefresh(nil, tokenStr)
// Use nil transaction for no-db case
var tx *sql.Tx = nil
token, err := gen.ValidateRefresh(tx, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope)
@@ -102,7 +90,10 @@ func TestValidateRefresh_NoDB(t *testing.T) {
func TestValidateAccess_EmptyToken(t *testing.T) {
gen := newTestGenerator(t)
_, err := gen.ValidateAccess(nil, "")
// Use nil transaction
var tx *sql.Tx = nil
_, err := gen.ValidateAccess(tx, "")
require.Error(t, err)
}
@@ -113,6 +104,9 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
tokenStr, _, err := gen.NewAccess(1, false, false)
require.NoError(t, err)
_, err = gen.ValidateRefresh(nil, tokenStr)
// Use nil transaction
var tx *sql.Tx = nil
_, err = gen.ValidateRefresh(tx, tokenStr)
require.Error(t, err)
}