diff --git a/hws/config_test.go b/hws/config_test.go index 246825b..021c430 100644 --- a/hws/config_test.go +++ b/hws/config_test.go @@ -13,12 +13,12 @@ import ( func Test_ConfigFromEnv(t *testing.T) { t.Run("Default values when no env vars set", func(t *testing.T) { // Clear any existing env vars - os.Unsetenv("HWS_HOST") - os.Unsetenv("HWS_PORT") - os.Unsetenv("HWS_GZIP") - os.Unsetenv("HWS_READ_HEADER_TIMEOUT") - os.Unsetenv("HWS_WRITE_TIMEOUT") - os.Unsetenv("HWS_IDLE_TIMEOUT") + _ = os.Unsetenv("HWS_HOST") + _ = os.Unsetenv("HWS_PORT") + _ = os.Unsetenv("HWS_GZIP") + _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT") + _ = os.Unsetenv("HWS_WRITE_TIMEOUT") + _ = os.Unsetenv("HWS_IDLE_TIMEOUT") config, err := hws.ConfigFromEnv() require.NoError(t, err) @@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) { }) t.Run("Custom host", func(t *testing.T) { - os.Setenv("HWS_HOST", "192.168.1.1") - defer os.Unsetenv("HWS_HOST") + _ = os.Setenv("HWS_HOST", "192.168.1.1") + defer func() { + _ = os.Unsetenv("HWS_HOST") + }() config, err := hws.ConfigFromEnv() require.NoError(t, err) @@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) { }) t.Run("Custom port", func(t *testing.T) { - os.Setenv("HWS_PORT", "8080") - defer os.Unsetenv("HWS_PORT") + _ = os.Setenv("HWS_PORT", "8080") + defer func() { + _ = os.Unsetenv("HWS_PORT") + }() config, err := hws.ConfigFromEnv() require.NoError(t, err) @@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) { }) t.Run("GZIP enabled", func(t *testing.T) { - os.Setenv("HWS_GZIP", "true") - defer os.Unsetenv("HWS_GZIP") + _ = os.Setenv("HWS_GZIP", "true") + defer func() { + _ = os.Unsetenv("HWS_GZIP") + }() config, err := hws.ConfigFromEnv() require.NoError(t, err) @@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) { }) t.Run("Custom timeouts", func(t *testing.T) { - os.Setenv("HWS_READ_HEADER_TIMEOUT", "5") - os.Setenv("HWS_WRITE_TIMEOUT", "30") - os.Setenv("HWS_IDLE_TIMEOUT", "300") - defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT") - defer os.Unsetenv("HWS_WRITE_TIMEOUT") - defer os.Unsetenv("HWS_IDLE_TIMEOUT") + _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5") + _ = os.Setenv("HWS_WRITE_TIMEOUT", "30") + _ = os.Setenv("HWS_IDLE_TIMEOUT", "300") + defer func() { + _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT") + _ = os.Unsetenv("HWS_WRITE_TIMEOUT") + _ = os.Unsetenv("HWS_IDLE_TIMEOUT") + }() config, err := hws.ConfigFromEnv() require.NoError(t, err) @@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) { }) t.Run("All custom values", func(t *testing.T) { - os.Setenv("HWS_HOST", "0.0.0.0") - os.Setenv("HWS_PORT", "9000") - os.Setenv("HWS_GZIP", "true") - os.Setenv("HWS_READ_HEADER_TIMEOUT", "3") - os.Setenv("HWS_WRITE_TIMEOUT", "15") - os.Setenv("HWS_IDLE_TIMEOUT", "180") + _ = os.Setenv("HWS_HOST", "0.0.0.0") + _ = os.Setenv("HWS_PORT", "9000") + _ = os.Setenv("HWS_GZIP", "true") + _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3") + _ = os.Setenv("HWS_WRITE_TIMEOUT", "15") + _ = os.Setenv("HWS_IDLE_TIMEOUT", "180") defer func() { - os.Unsetenv("HWS_HOST") - os.Unsetenv("HWS_PORT") - os.Unsetenv("HWS_GZIP") - os.Unsetenv("HWS_READ_HEADER_TIMEOUT") - os.Unsetenv("HWS_WRITE_TIMEOUT") - os.Unsetenv("HWS_IDLE_TIMEOUT") + _ = os.Unsetenv("HWS_HOST") + _ = os.Unsetenv("HWS_PORT") + _ = os.Unsetenv("HWS_GZIP") + _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT") + _ = os.Unsetenv("HWS_WRITE_TIMEOUT") + _ = os.Unsetenv("HWS_IDLE_TIMEOUT") }() config, err := hws.ConfigFromEnv() diff --git a/hws/errors.go b/hws/errors.go index d58a4bf..30afabb 100644 --- a/hws/errors.go +++ b/hws/errors.go @@ -9,7 +9,7 @@ import ( "github.com/pkg/errors" ) -// Error to use with Server.ThrowError +// HWSError wraps an error with other information for use with HWS features type HWSError struct { StatusCode int // HTTP Status code Message string // Error message @@ -41,7 +41,7 @@ type ErrorPage interface { } // AddErrorPage registers a handler that returns an ErrorPage -func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { +func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error { rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError}) @@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { return errors.New("Render method of the error page did not write anything to the response writer") } - server.errorPage = pageFunc + s.errorPage = pageFunc return nil } @@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { // the error with the level specified by the HWSError. // If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter // and the request chain should be terminated. -func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error { +func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) { + err := s.throwError(w, r, error) + if err != nil { + s.LogError(error) + s.LogError(HWSError{ + Message: "Error occured during throwError", + Error: errors.Wrap(err, "s.throwError"), + Level: ErrorERROR, + }) + } +} + +func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error { if error.StatusCode <= 0 { return errors.New("HWSError.StatusCode cannot be 0.") } @@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H if r == nil { return errors.New("Request cannot be nil") } - if !server.IsReady() { + if !s.IsReady() { return errors.New("ThrowError called before server started") } w.WriteHeader(error.StatusCode) - server.LogError(error) - if server.errorPage == nil { - server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG}) + s.LogError(error) + if s.errorPage == nil { + s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG}) return nil } if error.RenderErrorPage { - server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG}) - errPage, err := server.errorPage(error) + s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG}) + errPage, err := s.errorPage(error) if err != nil { - server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err}) + s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err}) } err = errPage.Render(r.Context(), w) if err != nil { - server.LogError(HWSError{Message: "Failed to render error page", Error: err}) + s.LogError(HWSError{Message: "Failed to render error page", Error: err}) } } else { - server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG}) + s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG}) } return nil } - -func (server *Server) ThrowFatal(w http.ResponseWriter, err error) { - w.WriteHeader(http.StatusInternalServerError) - server.LogFatal(err) -} diff --git a/hws/errors_test.go b/hws/errors_test.go index 1f6cc70..7c3b1cf 100644 --- a/hws/errors_test.go +++ b/hws/errors_test.go @@ -14,22 +14,26 @@ import ( "github.com/stretchr/testify/require" ) -type goodPage struct{} -type badPage struct{} +type ( + goodPage struct{} + badPage struct{} +) func goodRender(error hws.HWSError) (hws.ErrorPage, error) { return goodPage{}, nil } + func badRender1(error hws.HWSError) (hws.ErrorPage, error) { return badPage{}, nil } + func badRender2(error hws.HWSError) (hws.ErrorPage, error) { return nil, errors.New("I'm an error") } func (g goodPage) Render(ctx context.Context, w io.Writer) error { - w.Write([]byte("Test write to ResponseWriter")) - return nil + _, err := w.Write([]byte("Test write to ResponseWriter")) + return err } func (b badPage) Render(ctx context.Context, w io.Writer) error { @@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) t.Run("Server not started", func(t *testing.T) { - err := server.ThrowError(rr, req, hws.HWSError{ + buf.Reset() + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "Error", Error: errors.New("Error"), }) - assert.Error(t, err) + // ThrowError logs errors internally when validation fails + output := buf.String() + assert.Contains(t, output, "ThrowError called before server started") }) startTestServer(t, server) - defer server.Shutdown(t.Context()) tests := []struct { - name string - request *http.Request - error hws.HWSError - valid bool + name string + request *http.Request + error hws.HWSError + expectLogItem string }{ { - name: "No HWSError.Status code", - request: nil, - error: hws.HWSError{}, - valid: false, + name: "No HWSError.Status code", + request: nil, + error: hws.HWSError{}, + expectLogItem: "HWSError.StatusCode cannot be 0", }, { - name: "Negative HWSError.Status code", - request: nil, - error: hws.HWSError{StatusCode: -1}, - valid: false, + name: "Negative HWSError.Status code", + request: nil, + error: hws.HWSError{StatusCode: -1}, + expectLogItem: "HWSError.StatusCode cannot be 0", }, { - name: "No HWSError.Message", - request: nil, - error: hws.HWSError{StatusCode: http.StatusInternalServerError}, - valid: false, + name: "No HWSError.Message", + request: nil, + error: hws.HWSError{StatusCode: http.StatusInternalServerError}, + expectLogItem: "HWSError.Message cannot be empty", }, { name: "No HWSError.Error", @@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) { StatusCode: http.StatusInternalServerError, Message: "An error occured", }, - valid: false, + expectLogItem: "HWSError.Error cannot be nil", }, { name: "No request provided", @@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) { Message: "An error occured", Error: errors.New("Error"), }, - valid: false, + expectLogItem: "Request cannot be nil", }, { name: "Valid", @@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) { Message: "An error occured", Error: errors.New("Error"), }, - valid: true, + expectLogItem: "An error occured", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + buf.Reset() rr := httptest.NewRecorder() - err := server.ThrowError(rr, tt.request, tt.error) - if tt.valid { - assert.NoError(t, err) - } else { - t.Log(err) - assert.Error(t, err) - } + server.ThrowError(rr, tt.request, tt.error) + // ThrowError no longer returns errors; check logs instead + output := buf.String() + assert.Contains(t, output, tt.expectLogItem) }) } t.Run("Log level set correctly", func(t *testing.T) { buf.Reset() rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) - err := server.ThrowError(rr, req, hws.HWSError{ + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "An error occured", Error: errors.New("Error"), Level: hws.ErrorWARN, }) - assert.NoError(t, err) - _, err = buf.ReadString([]byte(" ")[0]) + _, err := buf.ReadString([]byte(" ")[0]) + require.NoError(t, err) loglvl, err := buf.ReadString([]byte(" ")[0]) - assert.NoError(t, err) - if loglvl != "\x1b[33mWRN\x1b[0m " { - err = errors.New("Log level not set correctly") - } - assert.NoError(t, err) + require.NoError(t, err) + assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN") + buf.Reset() - err = server.ThrowError(rr, req, hws.HWSError{ + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "An error occured", Error: errors.New("Error"), }) - assert.NoError(t, err) _, err = buf.ReadString([]byte(" ")[0]) + require.NoError(t, err) loglvl, err = buf.ReadString([]byte(" ")[0]) - assert.NoError(t, err) - if loglvl != "\x1b[31mERR\x1b[0m " { - err = errors.New("Log level not set correctly") - } - assert.NoError(t, err) + require.NoError(t, err) + assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified") }) t.Run("Error page doesnt render if no error page set", func(t *testing.T) { // Must be run before adding the error page to the test server rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) - err := server.ThrowError(rr, req, hws.HWSError{ + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "An error occured", Error: errors.New("Error"), RenderErrorPage: true, }) - assert.NoError(t, err) body := rr.Body.String() - if body != "" { - assert.Error(t, nil) - } + assert.Empty(t, body, "Error page should not render when no error page is set") }) t.Run("Error page renders", func(t *testing.T) { rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) // Adding the error page will carry over to all future tests and cant be undone - server.AddErrorPage(goodRender) - err := server.ThrowError(rr, req, hws.HWSError{ + err := server.AddErrorPage(goodRender) + require.NoError(t, err) + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "An error occured", Error: errors.New("Error"), RenderErrorPage: true, }) - assert.NoError(t, err) body := rr.Body.String() - if body == "" { - assert.Error(t, nil) - } + assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true") }) - t.Run("Error page doesnt render if no told to render", func(t *testing.T) { + t.Run("Error page doesnt render if not told to render", func(t *testing.T) { // Error page already added to server rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) - err := server.ThrowError(rr, req, hws.HWSError{ + server.ThrowError(rr, req, hws.HWSError{ StatusCode: http.StatusInternalServerError, Message: "An error occured", Error: errors.New("Error"), }) - assert.NoError(t, err) body := rr.Body.String() - if body != "" { - assert.Error(t, nil) - } + assert.Empty(t, body, "Error page should not render when RenderErrorPage is false") }) - server.Shutdown(t.Context()) + err := server.Shutdown(t.Context()) + require.NoError(t, err) - t.Run("Doesn't error if no logger added to server", func(t *testing.T) { + t.Run("Doesn't panic if no logger added to server", func(t *testing.T) { server, err := hws.NewServer(&hws.Config{ Host: "127.0.0.1", Port: randomPort(), @@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) { err = server.Start(t.Context()) require.NoError(t, err) <-server.Ready() + rr := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) - err = server.ThrowError(rr, req, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured", - Error: errors.New("Error"), - }) - assert.NoError(t, err) + // Should not panic when no logger is present + assert.NotPanics(t, func() { + server.ThrowError(rr, req, hws.HWSError{ + StatusCode: http.StatusInternalServerError, + Message: "An error occured", + Error: errors.New("Error"), + }) + }, "ThrowError should not panic when no logger is present") + err = server.Shutdown(t.Context()) + require.NoError(t, err) }) } diff --git a/hws/ezconf.go b/hws/ezconf.go index 4c25c16..1c90d7b 100644 --- a/hws/ezconf.go +++ b/hws/ezconf.go @@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string { } // ConfigFunc returns the ConfigFromEnv function for ezconf -func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) { - return func() (interface{}, error) { +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { return ConfigFromEnv() } } diff --git a/hws/logger.go b/hws/logger.go index cd3d15e..136558b 100644 --- a/hws/logger.go +++ b/hws/logger.go @@ -43,23 +43,12 @@ func (s *Server) LogError(err HWSError) { } } -func (server *Server) LogFatal(err error) { - if err == nil { - err = errors.New("LogFatal was called with a nil error") - } - if server.logger == nil { - fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error()) - return - } - server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured") -} - // AddLogger adds a logger to the server to use for request logging. -func (server *Server) AddLogger(hlogger *hlog.Logger) error { +func (s *Server) AddLogger(hlogger *hlog.Logger) error { if hlogger == nil { return errors.New("unable to add logger, no logger provided") } - server.logger = &logger{ + s.logger = &logger{ logger: hlogger, } return nil @@ -68,7 +57,7 @@ func (server *Server) AddLogger(hlogger *hlog.Logger) error { // LoggerIgnorePaths sets a list of URL paths to ignore logging for. // Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL // Useful for ignoring requests to CSS files or favicons -func (server *Server) LoggerIgnorePaths(paths ...string) error { +func (s *Server) LoggerIgnorePaths(paths ...string) error { for _, path := range paths { u, err := url.Parse(path) valid := err == nil && @@ -80,7 +69,7 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error { return fmt.Errorf("invalid path: '%s'", path) } } - server.logger.ignoredPaths = prepareGlobs(paths) + s.logger.ignoredPaths = prepareGlobs(paths) return nil } diff --git a/hws/logger_test.go b/hws/logger_test.go index ce0eb6e..45c36e1 100644 --- a/hws/logger_test.go +++ b/hws/logger_test.go @@ -169,7 +169,7 @@ func Test_LogFatal(t *testing.T) { // Note: We cannot actually test Fatal() as it calls os.Exit() // Testing this would require subprocess testing which is overly complex // These tests document the expected behavior and verify the function signatures exist - + t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) { _, err := hws.NewServer(&hws.Config{ Host: "127.0.0.1", @@ -197,7 +197,7 @@ func Test_LoggerIgnorePaths(t *testing.T) { err := server.LoggerIgnorePaths("http://example.com/path") assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid path") + assert.Contains(t, err.Error(), "invalid path") }) t.Run("Invalid path with host", func(t *testing.T) { @@ -207,7 +207,7 @@ func Test_LoggerIgnorePaths(t *testing.T) { err := server.LoggerIgnorePaths("//example.com/path") assert.Error(t, err) if err != nil { - assert.Contains(t, err.Error(), "Invalid path") + assert.Contains(t, err.Error(), "invalid path") } }) @@ -217,7 +217,7 @@ func Test_LoggerIgnorePaths(t *testing.T) { err := server.LoggerIgnorePaths("/path?query=value") assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid path") + assert.Contains(t, err.Error(), "invalid path") }) t.Run("Invalid path with fragment", func(t *testing.T) { @@ -226,7 +226,7 @@ func Test_LoggerIgnorePaths(t *testing.T) { err := server.LoggerIgnorePaths("/path#fragment") assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid path") + assert.Contains(t, err.Error(), "invalid path") }) t.Run("Valid paths", func(t *testing.T) { diff --git a/hws/middleware.go b/hws/middleware.go index 1ecb2ad..5df6c33 100644 --- a/hws/middleware.go +++ b/hws/middleware.go @@ -5,35 +5,37 @@ import ( "net/http" ) -type Middleware func(h http.Handler) http.Handler -type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError) +type ( + Middleware func(h http.Handler) http.Handler + MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError) +) -// Server.AddMiddleware registers all the middleware. +// AddMiddleware registers all the middleware. // Middleware will be run in the order that they are provided. // Can only be called once -func (server *Server) AddMiddleware(middleware ...Middleware) error { - if !server.routes { +func (s *Server) AddMiddleware(middleware ...Middleware) error { + if !s.routes { return errors.New("Server.AddRoutes must be called before Server.AddMiddleware") } - if server.middleware { + if s.middleware { return errors.New("Server.AddMiddleware already called") } // RUN LOGGING MIDDLEWARE FIRST - server.server.Handler = logging(server.server.Handler, server.logger) + s.server.Handler = logging(s.server.Handler, s.logger) // LOOP PROVIDED MIDDLEWARE IN REVERSE order for i := len(middleware); i > 0; i-- { - server.server.Handler = middleware[i-1](server.server.Handler) + s.server.Handler = middleware[i-1](s.server.Handler) } // RUN GZIP - if server.GZIP { - server.server.Handler = addgzip(server.server.Handler) + if s.GZIP { + s.server.Handler = addgzip(s.server.Handler) } // RUN TIMER MIDDLEWARE LAST - server.server.Handler = startTimer(server.server.Handler) + s.server.Handler = startTimer(s.server.Handler) - server.middleware = true + s.middleware = true return nil } @@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error { // and returns a new request and optional HWSError. // If a HWSError is returned, server.ThrowError will be called. // If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered -func (server *Server) NewMiddleware( +func (s *Server) NewMiddleware( middlewareFunc MiddlewareFunc, ) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { newReq, herr := middlewareFunc(w, r) if herr != nil { - server.ThrowError(w, r, *herr) + s.ThrowError(w, r, *herr) if herr.RenderErrorPage { return } diff --git a/hws/middleware_timer.go b/hws/middleware_timer.go index df0e690..6965a4e 100644 --- a/hws/middleware_timer.go +++ b/hws/middleware_timer.go @@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler { ) } +type contextKey string + +func (c contextKey) String() string { + return "hws context key " + string(c) +} + +var requestTimerCtxKey = contextKey("request-timer") + // Set the start time of the request func setStart(ctx context.Context, time time.Time) context.Context { - return context.WithValue(ctx, "hws context key request-timer", time) + return context.WithValue(ctx, requestTimerCtxKey, time) } // Get the start time of the request func getStartTime(ctx context.Context) (time.Time, error) { - start, ok := ctx.Value("hws context key request-timer").(time.Time) + start, ok := ctx.Value(requestTimerCtxKey).(time.Time) if !ok { - return time.Time{}, errors.New("Failed to get start time of request") + return time.Time{}, errors.New("failed to get start time of request") } return start, nil } diff --git a/hws/notify.go b/hws/notify.go index b774a66..a560c9a 100644 --- a/hws/notify.go +++ b/hws/notify.go @@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) { } _, exists := s.notifier.clients.getClient(nt.Target) if !exists { - err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target) + err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) return } @@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) { clients, exists := s.notifier.clients.clientsIDMap[altID] s.notifier.clients.lock.RUnlock() if !exists { - err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID) + err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) return } diff --git a/hws/notify_test.go b/hws/notify_test.go index fc82446..f281a2b 100644 --- a/hws/notify_test.go +++ b/hws/notify_test.go @@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server { t.Helper() cfg := &Config{ - Host: "127.0.0.1", - Port: 0, + Host: "127.0.0.1", + Port: 0, + ShutdownDelay: 0, // No delay for tests } server, err := NewServer(cfg) @@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) { done := make(chan bool) go func() { - for i := 0; i < 3; i++ { + for range 3 { <-ticker.C server.NotifySub(notify.Notification{ Target: client.sub.ID, @@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) { defer close(stop) // Send 10 notifications quickly (buffer is 10) - for i := 0; i < 10; i++ { + for range 10 { server.NotifySub(notify.Notification{ Target: client.sub.ID, Message: "Burst message", @@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) { } // Client should receive all 10 - for i := 0; i < 10; i++ { + for i := range 10 { select { case <-notifications: // Received @@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) { defer close(stop) // Fill buffer completely (buffer is 10) - for i := 0; i < 10; i++ { + for range 10 { server.NotifySub(notify.Notification{ Target: client.sub.ID, Message: "Fill buffer", @@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) { Message: "Timeout message", }) - // Wait for timeout - time.Sleep(6 * time.Second) + // Wait for timeout (5s timeout + small buffer) + time.Sleep(5100 * time.Millisecond) // Check failure count (should be 1) fails := atomic.LoadInt32(&client.consecutiveFails) require.Equal(t, int32(1), fails, "Should have 1 timeout") // Now read all buffered messages - for i := 0; i < 10; i++ { + for range 10 { <-notifications } @@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) { defer close(stop) // Fill buffer and never read to cause 5 consecutive timeouts - for i := 0; i < 20; i++ { + for range 20 { server.NotifySub(notify.Notification{ Target: client.sub.ID, Message: "Timeout message", @@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) { var wg sync.WaitGroup clients := make([]*Client, 100) - for i := 0; i < 100; i++ { + for i := range 100 { wg.Add(1) go func(index int) { defer wg.Done() @@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) { messageCount := 50 // Send from multiple goroutines - for i := 0; i < messageCount; i++ { + for i := range messageCount { wg.Add(1) go func(index int) { defer wg.Done() @@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) { // This is expected behavior - we're testing thread safety, not guaranteed delivery // Just verify we receive at least some messages without panicking or deadlocking received := 0 - timeout := time.After(2 * time.Second) + timeout := time.After(500 * time.Millisecond) for received < messageCount { select { case <-notifications: @@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) { server := newTestServerWithNotifier(t) // Create some clients - for i := 0; i < 10; i++ { + for i := range 10 { client, _ := server.GetClient("", "") // Set some to be old if i%2 == 0 { @@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) { var wg sync.WaitGroup // Create a few clients and read from them - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() + for range 5 { + wg.Go(func() { client, _ := server.GetClient("", "") notifications, stop := client.Listen() defer close(stop) // Actively read messages - timeout := time.After(2 * time.Second) + timeout := time.After(200 * time.Millisecond) for { select { case <-notifications: - // Keep reading + // Keep reading case <-timeout: return } } - }() + }) } // Send a few notifications - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 20; j++ { + wg.Go(func() { + for range 10 { server.NotifyAll(notify.Notification{ Message: "Stress test", }) - time.Sleep(50 * time.Millisecond) + time.Sleep(10 * time.Millisecond) } - }() - + }) wg.Wait() } @@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) { require.NotNil(t, stop) // notifications should be receive-only - _, ok := interface{}(notifications).(<-chan notify.Notification) + _, ok := any(notifications).(<-chan notify.Notification) require.True(t, ok, "notifications should be receive-only channel") // stop should be closeable @@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) { defer close(stop) // Send 10 messages without reading (buffer size is 10) - for i := 0; i < 10; i++ { + for range 10 { server.NotifySub(notify.Notification{ Target: client.sub.ID, Message: "Buffered", @@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) { time.Sleep(100 * time.Millisecond) // Read all 10 - for i := 0; i < 10; i++ { + for i := range 10 { select { case <-notifications: // Success diff --git a/hws/routes.go b/hws/routes.go index 22e59f9..fe6b62e 100644 --- a/hws/routes.go +++ b/hws/routes.go @@ -30,13 +30,13 @@ const ( MethodPATCH Method = "PATCH" ) -// Server.AddRoutes registers the page handlers for the server. +// AddRoutes registers the page handlers for the server. // At least one route must be provided. // If any route patterns (path + method) are defined multiple times, the first // instance will be added and any additional conflicts will be discarded. -func (server *Server) AddRoutes(routes ...Route) error { +func (s *Server) AddRoutes(routes ...Route) error { if len(routes) == 0 { - return errors.New("No routes provided") + return errors.New("no routes provided") } patterns := []string{} mux := http.NewServeMux() @@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error { } for _, method := range route.Methods { if !validMethod(method) { - return fmt.Errorf("Invalid method %s for path %s", method, route.Path) + return fmt.Errorf("invalid method %s for path %s", method, route.Path) } if route.Handler == nil { - return fmt.Errorf("No handler provided for %s %s", method, route.Path) + return fmt.Errorf("no handler provided for %s %s", method, route.Path) } pattern := fmt.Sprintf("%s %s", method, route.Path) if slices.Contains(patterns, pattern) { @@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error { } } - server.server.Handler = mux - server.routes = true + s.server.Handler = mux + s.routes = true return nil } diff --git a/hws/routes_test.go b/hws/routes_test.go index f643920..161d07b 100644 --- a/hws/routes_test.go +++ b/hws/routes_test.go @@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) { server := createTestServer(t, &buf) err := server.AddRoutes() assert.Error(t, err) - assert.Contains(t, err.Error(), "No routes provided") + assert.Contains(t, err.Error(), "no routes provided") }) t.Run("Single valid route", func(t *testing.T) { @@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) { Handler: handler, }) assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid method") + assert.Contains(t, err.Error(), "invalid method") }) t.Run("No handler provided", func(t *testing.T) { @@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) { Handler: nil, }) assert.Error(t, err) - assert.Contains(t, err.Error(), "No handler provided") + assert.Contains(t, err.Error(), "no handler provided") }) t.Run("All HTTP methods are valid", func(t *testing.T) { @@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) { Handler: handler, }) assert.Error(t, err) - assert.Contains(t, err.Error(), "Invalid method") + assert.Contains(t, err.Error(), "invalid method") }) t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) { diff --git a/hws/server.go b/hws/server.go index c5acb78..967f5c8 100644 --- a/hws/server.go +++ b/hws/server.go @@ -26,14 +26,14 @@ type Server struct { } // Ready returns a channel that is closed when the server is started -func (server *Server) Ready() <-chan struct{} { - return server.ready +func (s *Server) Ready() <-chan struct{} { + return s.ready } // IsReady checks if the server is running -func (server *Server) IsReady() bool { +func (s *Server) IsReady() bool { select { - case <-server.ready: + case <-s.ready: return true default: return false @@ -41,13 +41,13 @@ func (server *Server) IsReady() bool { } // Addr returns the server's network address -func (server *Server) Addr() string { - return server.server.Addr +func (s *Server) Addr() string { + return s.server.Addr } // Handler returns the server's HTTP handler for testing purposes -func (server *Server) Handler() http.Handler { - return server.server.Handler +func (s *Server) Handler() http.Handler { + return s.server.Handler } // NewServer returns a new hws.Server with the specified configuration. @@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) { valid := isValidHostname(config.Host) if !valid { - return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host) + return nil, fmt.Errorf("hostname '%s' is not valid", config.Host) } httpServer := &http.Server{ @@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) { return server, nil } -func (server *Server) Start(ctx context.Context) error { +func (s *Server) Start(ctx context.Context) error { if ctx == nil { return errors.New("Context cannot be nil") } - if !server.routes { + if !s.routes { return errors.New("Server.AddRoutes must be run before starting the server") } - if !server.middleware { - err := server.AddMiddleware() + if !s.middleware { + err := s.AddMiddleware() if err != nil { return errors.Wrap(err, "server.AddMiddleware") } } - server.startNotifier() + s.startNotifier() go func() { - if server.logger == nil { - fmt.Printf("Listening for requests on %s", server.server.Addr) + if s.logger == nil { + fmt.Printf("Listening for requests on %s", s.server.Addr) } else { - server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests") + s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests") } - if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - if server.logger == nil { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if s.logger == nil { fmt.Printf("Server encountered a fatal error: %s", err.Error()) } else { - server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"}) + s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"}) } } }() - server.waitUntilReady(ctx) + s.waitUntilReady(ctx) return nil } -func (server *Server) Shutdown(ctx context.Context) error { - server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down") - server.NotifyAll(notify.Notification{ +func (s *Server) Shutdown(ctx context.Context) error { + if s.logger != nil { + s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down") + } + s.NotifyAll(notify.Notification{ Title: "Shutting down", - Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay), + Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay), Level: LevelShutdown, }) - <-time.NewTimer(server.shutdowndelay).C - if !server.IsReady() { + <-time.NewTimer(s.shutdowndelay).C + if !s.IsReady() { return errors.New("Server isn't running") } if ctx == nil { return errors.New("Context cannot be nil") } - err := server.server.Shutdown(ctx) + err := s.server.Shutdown(ctx) if err != nil { return errors.Wrap(err, "Failed to shutdown the server gracefully") } - server.closeNotifier() - server.ready = make(chan struct{}) + s.closeNotifier() + s.ready = make(chan struct{}) return nil } @@ -168,7 +170,7 @@ func isValidHostname(host string) bool { return false } -func (server *Server) waitUntilReady(ctx context.Context) error { +func (s *Server) waitUntilReady(ctx context.Context) error { ticker := time.NewTicker(50 * time.Millisecond) defer ticker.Stop() @@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error { return ctx.Err() case <-ticker.C: - resp, err := http.Get("http://" + server.server.Addr + "/healthz") + resp, err := http.Get("http://" + s.server.Addr + "/healthz") if err != nil { continue // not accepting yet } resp.Body.Close() if resp.StatusCode == http.StatusOK { - closeOnce.Do(func() { close(server.ready) }) + closeOnce.Do(func() { close(s.ready) }) return nil } } diff --git a/hws/server_test.go b/hws/server_test.go index a2dbbc8..68078cd 100644 --- a/hws/server_test.go +++ b/hws/server_test.go @@ -26,8 +26,9 @@ func randomPort() uint64 { func createTestServer(t *testing.T, w io.Writer) *hws.Server { server, err := hws.NewServer(&hws.Config{ - Host: "127.0.0.1", - Port: randomPort(), + Host: "127.0.0.1", + Port: randomPort(), + ShutdownDelay: 0, // No delay for tests }) require.NoError(t, err)