package hws_test

import (
	"bytes"
	"net/http"
	"net/http/httptest"
	"testing"

	"git.haelnorr.com/h/golib/hws"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Test_AddMiddleware checks that Server.AddMiddleware is rejected before any
// routes have been registered, and accepted afterwards with zero, one, or
// several middlewares.
func Test_AddMiddleware(t *testing.T) {
	var buf bytes.Buffer

	t.Run("Cannot add middleware before routes", func(t *testing.T) {
		server := createTestServer(t, &buf)
		err := server.AddMiddleware()
		// require (not assert): if err were nil, err.Error() below would
		// panic instead of reporting a clean test failure.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "Server.AddRoutes must be called before")
	})

	t.Run("Can add middleware after routes", func(t *testing.T) {
		server := createTestServer(t, &buf)
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		err := server.AddRoutes(hws.Route{
			Path:    "/test",
			Method:  hws.MethodGET,
			Handler: handler,
		})
		require.NoError(t, err)

		err = server.AddMiddleware()
		assert.NoError(t, err)
	})

	t.Run("Can add custom middleware", func(t *testing.T) {
		server := createTestServer(t, &buf)
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		err := server.AddRoutes(hws.Route{
			Path:    "/test",
			Method:  hws.MethodGET,
			Handler: handler,
		})
		require.NoError(t, err)

		customMiddleware := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Custom", "test")
				next.ServeHTTP(w, r)
			})
		}
		err = server.AddMiddleware(customMiddleware)
		assert.NoError(t, err)
	})

	t.Run("Can add multiple middlewares", func(t *testing.T) {
		server := createTestServer(t, &buf)
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		err := server.AddRoutes(hws.Route{
			Path:    "/test",
			Method:  hws.MethodGET,
			Handler: handler,
		})
		require.NoError(t, err)

		middleware1 := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				next.ServeHTTP(w, r)
			})
		}
		middleware2 := func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				next.ServeHTTP(w, r)
			})
		}
		err = server.AddMiddleware(middleware1, middleware2)
		assert.NoError(t, err)
	})
}

// Test_NewMiddleware exercises Server.NewMiddleware. It covers the pass-through
// case, an error without RenderErrorPage (the next handler still runs), an
// error with RenderErrorPage (the next handler is skipped), and request
// modification by the middleware function.
func Test_NewMiddleware(t *testing.T) {
	var buf bytes.Buffer

	t.Run("NewMiddleware without error", func(t *testing.T) {
		server := createTestServer(t, &buf)
		middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
			// Pass the request through unchanged.
			return r, nil
		}
		middleware := server.NewMiddleware(middlewareFunc)
		// require: middleware is invoked below, so a nil value must stop here.
		require.NotNil(t, middleware)

		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("success")) // writes to a ResponseRecorder do not fail
		})
		wrappedHandler := middleware(handler)
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		rr := httptest.NewRecorder()
		wrappedHandler.ServeHTTP(rr, req)

		assert.Equal(t, http.StatusOK, rr.Code)
		assert.Equal(t, "success", rr.Body.String())
	})

	t.Run("NewMiddleware with error but no render", func(t *testing.T) {
		server := createTestServer(t, &buf)
		// Register a route first so the server is in its post-AddRoutes state.
		handlerCalled := false
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handlerCalled = true
			w.WriteHeader(http.StatusOK)
		})
		err := server.AddRoutes(hws.Route{
			Path:    "/test",
			Method:  hws.MethodGET,
			Handler: handler,
		})
		require.NoError(t, err)

		middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
			return r, &hws.HWSError{
				StatusCode:      http.StatusBadRequest,
				Message:         "Test error",
				Error:           assert.AnError,
				RenderErrorPage: false,
			}
		}
		middleware := server.NewMiddleware(middlewareFunc)
		wrappedHandler := middleware(handler)
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		rr := httptest.NewRecorder()
		wrappedHandler.ServeHTTP(rr, req)

		// Without RenderErrorPage the error must not short-circuit the chain.
		assert.True(t, handlerCalled, "next handler should run when RenderErrorPage is false")
		assert.Equal(t, http.StatusOK, rr.Code)
	})

	t.Run("NewMiddleware with error and render", func(t *testing.T) {
		server := createTestServer(t, &buf)
		// Register a route first so the server is in its post-AddRoutes state.
		handlerCalled := false
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			handlerCalled = true
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("should not reach")) // writes to a ResponseRecorder do not fail
		})
		err := server.AddRoutes(hws.Route{
			Path:    "/test",
			Method:  hws.MethodGET,
			Handler: handler,
		})
		require.NoError(t, err)

		middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
			return r, &hws.HWSError{
				StatusCode:      http.StatusForbidden,
				Message:         "Access denied",
				Error:           assert.AnError,
				RenderErrorPage: true,
			}
		}
		middleware := server.NewMiddleware(middlewareFunc)
		wrappedHandler := middleware(handler)
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		rr := httptest.NewRecorder()
		wrappedHandler.ServeHTTP(rr, req)

		// The next handler must be skipped; the response is empty or an error
		// page, but never the handler's output.
		assert.False(t, handlerCalled, "next handler should not run when RenderErrorPage is true")
		assert.NotContains(t, rr.Body.String(), "should not reach")
	})

	t.Run("NewMiddleware can modify request", func(t *testing.T) {
		server := createTestServer(t, &buf)
		middlewareFunc := func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
			// Mark the request so the downstream handler can observe the change.
			r.Header.Set("X-Modified", "true")
			return r, nil
		}
		middleware := server.NewMiddleware(middlewareFunc)

		var capturedHeader string
		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			capturedHeader = r.Header.Get("X-Modified")
			w.WriteHeader(http.StatusOK)
		})
		wrappedHandler := middleware(handler)
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		rr := httptest.NewRecorder()
		wrappedHandler.ServeHTTP(rr, req)

		assert.Equal(t, "true", capturedHeader)
	})
}

// Test_Middleware_Ordering registers two middlewares that record when they
// run.
//
// NOTE(review): no request is driven through the server here, so the actual
// execution order is NOT verified yet. The test only checks two things:
// registration succeeds, and registration does not itself invoke the
// middlewares. TODO: serve a request through the server's composed handler
// and assert the expected order.
func Test_Middleware_Ordering(t *testing.T) {
	var buf bytes.Buffer
	server := createTestServer(t, &buf)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	err := server.AddRoutes(hws.Route{
		Path:    "/test",
		Method:  hws.MethodGET,
		Handler: handler,
	})
	require.NoError(t, err)

	var order []string
	middleware1 := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			order = append(order, "middleware1")
			next.ServeHTTP(w, r)
		})
	}
	middleware2 := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			order = append(order, "middleware2")
			next.ServeHTTP(w, r)
		})
	}

	err = server.AddMiddleware(middleware1, middleware2)
	require.NoError(t, err)

	// Registering middleware wraps handlers but must not execute them;
	// order is only populated once a request is served.
	assert.Empty(t, order)
}