Compare commits

...

3 Commits

10 changed files with 258 additions and 30 deletions

View File

@@ -51,6 +51,12 @@ func main() {
Method: hws.MethodGET,
Handler: http.HandlerFunc(getUserHandler),
},
{
// Single route handling multiple HTTP methods
Path: "/api/resource",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST, hws.MethodPUT},
Handler: http.HandlerFunc(resourceHandler),
},
}
// Add routes and middleware
@@ -73,6 +79,18 @@ func getUserHandler(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
w.Write([]byte("User ID: " + id))
}
func resourceHandler(w http.ResponseWriter, r *http.Request) {
// Handle GET, POST, and PUT for the same path
switch r.Method {
case "GET":
w.Write([]byte("Getting resource"))
case "POST":
w.Write([]byte("Creating resource"))
case "PUT":
w.Write([]byte("Updating resource"))
}
}
```
## Documentation

View File

@@ -74,6 +74,18 @@
// },
// }
//
// A single route can handle multiple HTTP methods using the Methods field:
//
// routes := []hws.Route{
// {
// Path: "/api/resource",
// Methods: []hws.Method{hws.MethodGET, hws.MethodPOST, hws.MethodPUT},
// Handler: http.HandlerFunc(resourceHandler),
// },
// }
//
// Note: The Methods field takes precedence over Method if both are provided.
//
// Path parameters can be accessed using r.PathValue():
//
// func getUser(w http.ResponseWriter, r *http.Request) {

View File

@@ -4,11 +4,15 @@ import (
"errors"
"fmt"
"net/http"
"slices"
)
type Route struct {
Path string // Absolute path to the requested resource
Method Method // HTTP Method
Path string // Absolute path to the requested resource
Method Method // HTTP Method
// Methods is an optional slice of Methods to use, if more than one can use the same handler.
// Will take precedence over the Method field if provided
Methods []Method
Handler http.Handler // Handler to use for the request
}
@@ -28,21 +32,33 @@ const (
// Server.AddRoutes registers the page handlers for the server.
// At least one route must be provided.
// If any route patterns (path + method) are defined multiple times, the first
// instance will be added and any additional conflicts will be discarded.
func (server *Server) AddRoutes(routes ...Route) error {
if len(routes) == 0 {
return errors.New("No routes provided")
}
patterns := []string{}
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
for _, route := range routes {
if !validMethod(route.Method) {
return fmt.Errorf("Invalid method %s for path %s", route.Method, route.Path)
if len(route.Methods) == 0 {
route.Methods = []Method{route.Method}
}
if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", route.Method, route.Path)
for _, method := range route.Methods {
if !validMethod(method) {
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
}
if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
}
pattern := fmt.Sprintf("%s %s", method, route.Path)
if slices.Contains(patterns, pattern) {
continue
}
patterns = append(patterns, pattern)
mux.Handle(pattern, route.Handler)
}
pattern := fmt.Sprintf("%s %s", route.Method, route.Path)
mux.Handle(pattern, route.Handler)
}
server.server.Handler = mux

View File

@@ -122,6 +122,111 @@ func Test_AddRoutes(t *testing.T) {
})
}
func Test_AddRoutes_MultipleMethods(t *testing.T) {
var buf bytes.Buffer
t.Run("Single route with multiple methods", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method + " response"))
})
err := server.AddRoutes(hws.Route{
Path: "/api/resource",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST, hws.MethodPUT},
Handler: handler,
})
require.NoError(t, err)
// Test GET request
req := httptest.NewRequest("GET", "/api/resource", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "GET response", rr.Body.String())
// Test POST request
req = httptest.NewRequest("POST", "/api/resource", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "POST response", rr.Body.String())
// Test PUT request
req = httptest.NewRequest("PUT", "/api/resource", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "PUT response", rr.Body.String())
})
t.Run("Methods field takes precedence over Method field", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET, // This should be ignored
Methods: []hws.Method{hws.MethodPOST, hws.MethodPUT},
Handler: handler,
})
require.NoError(t, err)
// GET should not work (Method field ignored)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
// POST should work (from Methods field)
req = httptest.NewRequest("POST", "/test", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// PUT should work (from Methods field)
req = httptest.NewRequest("PUT", "/test", nil)
rr = httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
t.Run("Invalid method in Methods slice", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Methods: []hws.Method{hws.MethodGET, hws.Method("INVALID")},
Handler: handler,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method")
})
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
server := createTestServer(t, &buf)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
err := server.AddRoutes(hws.Route{
Path: "/test",
Method: hws.MethodGET,
Methods: []hws.Method{}, // Empty slice
Handler: handler,
})
require.NoError(t, err)
// GET should work (from Method field)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
server.Handler().ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
}
func Test_Routes_EndToEnd(t *testing.T) {
var buf bytes.Buffer
server := createTestServer(t, &buf)

View File

@@ -2,6 +2,7 @@ package hwsauth
import (
"net/http"
"reflect"
"time"
"git.haelnorr.com/h/golib/jwt"
@@ -45,6 +46,9 @@ func (auth *Authenticator[T, TX]) getAuthenticatedUser(
if err != nil {
return authenticatedModel[T]{}, errors.Wrap(err, "auth.load")
}
if reflect.ValueOf(model).IsNil() {
return authenticatedModel[T]{}, errors.New("no user matching JWT in database")
}
authUser := authenticatedModel[T]{
model: model,
fresh: aT.Fresh,

View File

@@ -1,6 +1,11 @@
package hwsauth
import (
"context"
"database/sql"
"os"
"time"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/jwt"
@@ -30,6 +35,7 @@ func NewAuthenticator[T Model, TX DBTransaction](
beginTx BeginTX,
logger *hlog.Logger,
errorPage hws.ErrorPageFunc,
db *sql.DB,
) (*Authenticator[T, TX], error) {
if load == nil {
return nil, errors.New("No function to load model supplied")
@@ -55,7 +61,10 @@ func NewAuthenticator[T Model, TX DBTransaction](
return nil, errors.New("SecretKey is required")
}
if cfg.SSL && cfg.TrustedHost == "" {
return nil, errors.New("TrustedHost is required when SSL is enabled")
cfg.SSL = false // Disable SSL if TrustedHost is not configured
}
if cfg.TrustedHost == "" {
cfg.TrustedHost = "localhost" // Default TrustedHost for JWT
}
if cfg.AccessTokenExpiry == 0 {
cfg.AccessTokenExpiry = 5
@@ -69,12 +78,35 @@ func NewAuthenticator[T Model, TX DBTransaction](
if cfg.LandingPage == "" {
cfg.LandingPage = "/profile"
}
if cfg.DatabaseType == "" {
cfg.DatabaseType = "postgres"
}
if cfg.DatabaseVersion == "" {
cfg.DatabaseVersion = "15"
}
if db == nil {
return nil, errors.New("No Database provided")
}
// Test database connectivity
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, errors.Wrap(err, "database connection test failed")
}
// Configure JWT table
tableConfig := jwt.DefaultTableConfig()
if cfg.JWTTableName != "" {
tableConfig.TableName = cfg.JWTTableName
}
// Disable auto-creation for tests
// Check for test environment or mock database
if os.Getenv("GO_TEST") == "1" {
tableConfig.AutoCreate = false
tableConfig.EnableAutoCleanup = false
}
// Create token generator
tokenGen, err := jwt.CreateGenerator(jwt.GeneratorConfig{
@@ -87,6 +119,7 @@ func NewAuthenticator[T Model, TX DBTransaction](
Type: cfg.DatabaseType,
Version: cfg.DatabaseVersion,
},
DB: db,
TableConfig: tableConfig,
}, beginTx)
if err != nil {

View File

@@ -8,6 +8,7 @@ require (
git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.3.0
git.haelnorr.com/h/golib/jwt v0.10.1
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1
)

View File

@@ -2,16 +2,10 @@ git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDO
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.9.1 h1:9VmE/IQTfD8LAEyTbUCZLy/+8PbcHA1Kob/WQHRHKzc=
git.haelnorr.com/h/golib/hlog v0.9.1/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.2.0 h1:MR2Tu2qPaW+/oK8aXFJLRFaYZIHgKiex3t3zE41cu1U=
git.haelnorr.com/h/golib/hws v0.2.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E=
git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
@@ -26,6 +20,7 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=

View File

@@ -10,6 +10,7 @@ import (
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -47,6 +48,28 @@ func (tep TestErrorPage) Render(ctx context.Context, w io.Writer) error {
return nil
}
// createMockDB creates a mock SQL database for testing
func createMockDB() (*sql.DB, sqlmock.Sqlmock, error) {
db, mock, err := sqlmock.New()
if err != nil {
return nil, nil, err
}
// Expect a ping to succeed for database connectivity test
mock.ExpectPing()
// Expect table existence check (returns a row = table exists)
mock.ExpectQuery(`SELECT 1 FROM information_schema\.tables WHERE table_schema = 'public' AND table_name = \$1`).
WithArgs("jwtblacklist").
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
// Expect cleanup function creation
mock.ExpectExec(`CREATE OR REPLACE FUNCTION cleanup_jwtblacklist\(\) RETURNS void AS \$\$ BEGIN DELETE FROM jwtblacklist WHERE exp < EXTRACT\(EPOCH FROM NOW\(\)\); END; \$\$ LANGUAGE plpgsql;`).
WillReturnResult(sqlmock.NewResult(0, 0))
return db, mock, nil
}
func TestGetNil(t *testing.T) {
var zero TestModel
result := getNil[TestModel]()
@@ -209,12 +232,13 @@ func TestNewAuthenticator_NilConfig(t *testing.T) {
}
auth, err := NewAuthenticator(
nil,
nil, // cfg
load,
server,
beginTx,
logger,
errorPage,
nil, // db
)
assert.Error(t, err)
@@ -246,6 +270,7 @@ func TestNewAuthenticator_MissingSecretKey(t *testing.T) {
beginTx,
logger,
errorPage,
nil, // db - will fail before db check since SecretKey is missing
)
assert.Error(t, err)
@@ -274,6 +299,7 @@ func TestNewAuthenticator_NilLoadFunction(t *testing.T) {
beginTx,
logger,
errorPage,
nil, // db
)
assert.Error(t, err)
@@ -299,6 +325,10 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
return TestErrorPage{}, nil
}
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
auth, err := NewAuthenticator(
cfg,
load,
@@ -306,17 +336,19 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
beginTx,
logger,
errorPage,
db,
)
assert.Error(t, err)
assert.Nil(t, auth)
assert.Contains(t, err.Error(), "TrustedHost is required when SSL is enabled")
require.NoError(t, err)
require.NotNil(t, auth)
assert.Equal(t, false, auth.SSL)
assert.Equal(t, "/profile", auth.LandingPage)
}
func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) {
func TestNewAuthenticator_NilDatabase(t *testing.T) {
cfg := &Config{
SecretKey: "test-secret",
TrustedHost: "example.com",
SecretKey: "test-secret",
}
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
@@ -338,13 +370,12 @@ func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) {
beginTx,
logger,
errorPage,
nil, // db
)
require.NoError(t, err)
require.NotNil(t, auth)
assert.Equal(t, false, auth.SSL)
assert.Equal(t, "/profile", auth.LandingPage)
assert.Error(t, err)
assert.Nil(t, auth)
assert.Contains(t, err.Error(), "No Database provided")
}
func TestModelInterface(t *testing.T) {
@@ -376,6 +407,10 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
return TestErrorPage{}, nil
}
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
auth, err := NewAuthenticator(
cfg,
load,
@@ -383,6 +418,7 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
beginTx,
logger,
errorPage,
db,
)
require.NoError(t, err)
@@ -416,6 +452,10 @@ func TestLogin_BasicFunctionality(t *testing.T) {
return TestErrorPage{}, nil
}
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
auth, err := NewAuthenticator(
cfg,
load,
@@ -423,6 +463,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
beginTx,
logger,
errorPage,
db,
)
require.NoError(t, err)

View File

@@ -2,6 +2,7 @@ package hwsauth
import (
"net/http"
"reflect"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
@@ -18,7 +19,9 @@ func (auth *Authenticator[T, TX]) refreshAuthTokens(
if err != nil {
return getNil[T](), errors.Wrap(err, "auth.load")
}
if reflect.ValueOf(model).IsNil() {
return getNil[T](), errors.New("no user matching JWT in database")
}
rememberMe := map[string]bool{
"session": false,
"exp": true,