From 378bd8006dcf0ffa6c4969ad4a1847fcd247f121 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Fri, 23 Jan 2026 12:33:13 +1100 Subject: [PATCH] updated hws: expanded error page functionality --- hws/errors.go | 8 ++++---- hws/errors_test.go | 6 +++--- hws/server_methods_test.go | 6 ++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/hws/errors.go b/hws/errors.go index f1b4681..d58a4bf 100644 --- a/hws/errors.go +++ b/hws/errors.go @@ -31,7 +31,7 @@ const ( // ErrorPageFunc is a function that returns an ErrorPage with the specified HTTP Status code // This will be called by the server when it needs to render an error page -type ErrorPageFunc func(errorCode int) (ErrorPage, error) +type ErrorPageFunc func(error HWSError) (ErrorPage, error) // ErrorPage must implement a Render() function that takes in a context and ResponseWriter, // and should write a reponse as output to the ResponseWriter. @@ -40,11 +40,11 @@ type ErrorPage interface { Render(ctx context.Context, w io.Writer) error } -// TODO: add test for ErrorPageFunc that returns an error +// AddErrorPage registers a handler that returns an ErrorPage func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) - page, err := pageFunc(http.StatusInternalServerError) + page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError}) if err != nil { return errors.Wrap(err, "An error occured when trying to get the error page") } @@ -88,7 +88,7 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H } if error.RenderErrorPage { server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG}) - errPage, err := server.errorPage(error.StatusCode) + errPage, err := server.errorPage(error) if err != nil { server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err}) } diff --git a/hws/errors_test.go b/hws/errors_test.go index e4d6307..1f6cc70 100644 --- a/hws/errors_test.go +++ b/hws/errors_test.go @@ -17,13 +17,13 @@ import ( type goodPage struct{} type badPage struct{} -func goodRender(code int) (hws.ErrorPage, error) { +func goodRender(error hws.HWSError) (hws.ErrorPage, error) { return goodPage{}, nil } -func badRender1(code int) (hws.ErrorPage, error) { +func badRender1(error hws.HWSError) (hws.ErrorPage, error) { return badPage{}, nil } -func badRender2(code int) (hws.ErrorPage, error) { +func badRender2(error hws.HWSError) (hws.ErrorPage, error) { return nil, errors.New("I'm an error") } diff --git a/hws/server_methods_test.go b/hws/server_methods_test.go index 0eab81a..c8f0457 100644 --- a/hws/server_methods_test.go +++ b/hws/server_methods_test.go @@ -149,7 +149,8 @@ func Test_Start_Errors(t *testing.T) { }) require.NoError(t, err) - err = server.Start(t.Context()) + var nilCtx context.Context = nil + err = server.Start(nilCtx) assert.Error(t, err) assert.Contains(t, err.Error(), "Context cannot be nil") }) @@ -163,7 +164,8 @@ func Test_Shutdown_Errors(t *testing.T) { startTestServer(t, server) <-server.Ready() - err := server.Shutdown(t.Context()) + var nilCtx context.Context = nil + err := server.Shutdown(nilCtx) assert.Error(t, err) assert.Contains(t, err.Error(), "Context cannot be nil")