updated hws.ThrowError to not return an error and log it to console instead
fixed errors_test fixed tests
This commit is contained in:
@@ -13,12 +13,12 @@ import (
|
|||||||
func Test_ConfigFromEnv(t *testing.T) {
|
func Test_ConfigFromEnv(t *testing.T) {
|
||||||
t.Run("Default values when no env vars set", func(t *testing.T) {
|
t.Run("Default values when no env vars set", func(t *testing.T) {
|
||||||
// Clear any existing env vars
|
// Clear any existing env vars
|
||||||
os.Unsetenv("HWS_HOST")
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
os.Unsetenv("HWS_PORT")
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
os.Unsetenv("HWS_GZIP")
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom host", func(t *testing.T) {
|
t.Run("Custom host", func(t *testing.T) {
|
||||||
os.Setenv("HWS_HOST", "192.168.1.1")
|
_ = os.Setenv("HWS_HOST", "192.168.1.1")
|
||||||
defer os.Unsetenv("HWS_HOST")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom port", func(t *testing.T) {
|
t.Run("Custom port", func(t *testing.T) {
|
||||||
os.Setenv("HWS_PORT", "8080")
|
_ = os.Setenv("HWS_PORT", "8080")
|
||||||
defer os.Unsetenv("HWS_PORT")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GZIP enabled", func(t *testing.T) {
|
t.Run("GZIP enabled", func(t *testing.T) {
|
||||||
os.Setenv("HWS_GZIP", "true")
|
_ = os.Setenv("HWS_GZIP", "true")
|
||||||
defer os.Unsetenv("HWS_GZIP")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom timeouts", func(t *testing.T) {
|
t.Run("Custom timeouts", func(t *testing.T) {
|
||||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||||
os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
_ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||||
os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
_ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||||
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
defer func() {
|
||||||
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("All custom values", func(t *testing.T) {
|
t.Run("All custom values", func(t *testing.T) {
|
||||||
os.Setenv("HWS_HOST", "0.0.0.0")
|
_ = os.Setenv("HWS_HOST", "0.0.0.0")
|
||||||
os.Setenv("HWS_PORT", "9000")
|
_ = os.Setenv("HWS_PORT", "9000")
|
||||||
os.Setenv("HWS_GZIP", "true")
|
_ = os.Setenv("HWS_GZIP", "true")
|
||||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||||
os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
_ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||||
os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
_ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||||
defer func() {
|
defer func() {
|
||||||
os.Unsetenv("HWS_HOST")
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
os.Unsetenv("HWS_PORT")
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
os.Unsetenv("HWS_GZIP")
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error to use with Server.ThrowError
|
// HWSError wraps an error with other information for use with HWS features
|
||||||
type HWSError struct {
|
type HWSError struct {
|
||||||
StatusCode int // HTTP Status code
|
StatusCode int // HTTP Status code
|
||||||
Message string // Error message
|
Message string // Error message
|
||||||
@@ -41,7 +41,7 @@ type ErrorPage interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddErrorPage registers a handler that returns an ErrorPage
|
// AddErrorPage registers a handler that returns an ErrorPage
|
||||||
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
||||||
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
|||||||
return errors.New("Render method of the error page did not write anything to the response writer")
|
return errors.New("Render method of the error page did not write anything to the response writer")
|
||||||
}
|
}
|
||||||
|
|
||||||
server.errorPage = pageFunc
|
s.errorPage = pageFunc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
|||||||
// the error with the level specified by the HWSError.
|
// the error with the level specified by the HWSError.
|
||||||
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
||||||
// and the request chain should be terminated.
|
// and the request chain should be terminated.
|
||||||
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
|
||||||
|
err := s.throwError(w, r, error)
|
||||||
|
if err != nil {
|
||||||
|
s.LogError(error)
|
||||||
|
s.LogError(HWSError{
|
||||||
|
Message: "Error occured during throwError",
|
||||||
|
Error: errors.Wrap(err, "s.throwError"),
|
||||||
|
Level: ErrorERROR,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||||
if error.StatusCode <= 0 {
|
if error.StatusCode <= 0 {
|
||||||
return errors.New("HWSError.StatusCode cannot be 0.")
|
return errors.New("HWSError.StatusCode cannot be 0.")
|
||||||
}
|
}
|
||||||
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
|
|||||||
if r == nil {
|
if r == nil {
|
||||||
return errors.New("Request cannot be nil")
|
return errors.New("Request cannot be nil")
|
||||||
}
|
}
|
||||||
if !server.IsReady() {
|
if !s.IsReady() {
|
||||||
return errors.New("ThrowError called before server started")
|
return errors.New("ThrowError called before server started")
|
||||||
}
|
}
|
||||||
w.WriteHeader(error.StatusCode)
|
w.WriteHeader(error.StatusCode)
|
||||||
server.LogError(error)
|
s.LogError(error)
|
||||||
if server.errorPage == nil {
|
if s.errorPage == nil {
|
||||||
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if error.RenderErrorPage {
|
if error.RenderErrorPage {
|
||||||
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||||
errPage, err := server.errorPage(error)
|
errPage, err := s.errorPage(error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||||
}
|
}
|
||||||
err = errPage.Render(r.Context(), w)
|
err = errPage.Render(r.Context(), w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
s.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
server.LogFatal(err)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -14,22 +14,26 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type goodPage struct{}
|
type (
|
||||||
type badPage struct{}
|
goodPage struct{}
|
||||||
|
badPage struct{}
|
||||||
|
)
|
||||||
|
|
||||||
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return goodPage{}, nil
|
return goodPage{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return badPage{}, nil
|
return badPage{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return nil, errors.New("I'm an error")
|
return nil, errors.New("I'm an error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
||||||
w.Write([]byte("Test write to ResponseWriter"))
|
_, err := w.Write([]byte("Test write to ResponseWriter"))
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
||||||
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
t.Run("Server not started", func(t *testing.T) {
|
t.Run("Server not started", func(t *testing.T) {
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
buf.Reset()
|
||||||
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "Error",
|
Message: "Error",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
// ThrowError logs errors internally when validation fails
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "ThrowError called before server started")
|
||||||
})
|
})
|
||||||
|
|
||||||
startTestServer(t, server)
|
startTestServer(t, server)
|
||||||
defer server.Shutdown(t.Context())
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
request *http.Request
|
request *http.Request
|
||||||
error hws.HWSError
|
error hws.HWSError
|
||||||
valid bool
|
expectLogItem string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "No HWSError.Status code",
|
name: "No HWSError.Status code",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{},
|
error: hws.HWSError{},
|
||||||
valid: false,
|
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Negative HWSError.Status code",
|
name: "Negative HWSError.Status code",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{StatusCode: -1},
|
error: hws.HWSError{StatusCode: -1},
|
||||||
valid: false,
|
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No HWSError.Message",
|
name: "No HWSError.Message",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||||
valid: false,
|
expectLogItem: "HWSError.Message cannot be empty",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No HWSError.Error",
|
name: "No HWSError.Error",
|
||||||
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
},
|
},
|
||||||
valid: false,
|
expectLogItem: "HWSError.Error cannot be nil",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No request provided",
|
name: "No request provided",
|
||||||
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
},
|
},
|
||||||
valid: false,
|
expectLogItem: "Request cannot be nil",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Valid",
|
name: "Valid",
|
||||||
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
},
|
},
|
||||||
valid: true,
|
expectLogItem: "An error occured",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf.Reset()
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
err := server.ThrowError(rr, tt.request, tt.error)
|
server.ThrowError(rr, tt.request, tt.error)
|
||||||
if tt.valid {
|
// ThrowError no longer returns errors; check logs instead
|
||||||
assert.NoError(t, err)
|
output := buf.String()
|
||||||
} else {
|
assert.Contains(t, output, tt.expectLogItem)
|
||||||
t.Log(err)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
t.Run("Log level set correctly", func(t *testing.T) {
|
t.Run("Log level set correctly", func(t *testing.T) {
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
Level: hws.ErrorWARN,
|
Level: hws.ErrorWARN,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
_, err := buf.ReadString([]byte(" ")[0])
|
||||||
_, err = buf.ReadString([]byte(" ")[0])
|
require.NoError(t, err)
|
||||||
loglvl, err := buf.ReadString([]byte(" ")[0])
|
loglvl, err := buf.ReadString([]byte(" ")[0])
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if loglvl != "\x1b[33mWRN\x1b[0m " {
|
assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
|
||||||
err = errors.New("Log level not set correctly")
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
err = server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = buf.ReadString([]byte(" ")[0])
|
_, err = buf.ReadString([]byte(" ")[0])
|
||||||
|
require.NoError(t, err)
|
||||||
loglvl, err = buf.ReadString([]byte(" ")[0])
|
loglvl, err = buf.ReadString([]byte(" ")[0])
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if loglvl != "\x1b[31mERR\x1b[0m " {
|
assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
|
||||||
err = errors.New("Log level not set correctly")
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
||||||
// Must be run before adding the error page to the test server
|
// Must be run before adding the error page to the test server
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body != "" {
|
assert.Empty(t, body, "Error page should not render when no error page is set")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
t.Run("Error page renders", func(t *testing.T) {
|
t.Run("Error page renders", func(t *testing.T) {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
// Adding the error page will carry over to all future tests and cant be undone
|
// Adding the error page will carry over to all future tests and cant be undone
|
||||||
server.AddErrorPage(goodRender)
|
err := server.AddErrorPage(goodRender)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
require.NoError(t, err)
|
||||||
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body == "" {
|
assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
|
t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
|
||||||
// Error page already added to server
|
// Error page already added to server
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body != "" {
|
assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
server.Shutdown(t.Context())
|
err := server.Shutdown(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
|
t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
|
||||||
server, err := hws.NewServer(&hws.Config{
|
server, err := hws.NewServer(&hws.Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: randomPort(),
|
Port: randomPort(),
|
||||||
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
err = server.Start(t.Context())
|
err = server.Start(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-server.Ready()
|
<-server.Ready()
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err = server.ThrowError(rr, req, hws.HWSError{
|
// Should not panic when no logger is present
|
||||||
StatusCode: http.StatusInternalServerError,
|
assert.NotPanics(t, func() {
|
||||||
Message: "An error occured",
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
Error: errors.New("Error"),
|
StatusCode: http.StatusInternalServerError,
|
||||||
})
|
Message: "An error occured",
|
||||||
assert.NoError(t, err)
|
Error: errors.New("Error"),
|
||||||
|
})
|
||||||
|
}, "ThrowError should not panic when no logger is present")
|
||||||
|
err = server.Shutdown(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||||
return func() (interface{}, error) {
|
return func() (any, error) {
|
||||||
return ConfigFromEnv()
|
return ConfigFromEnv()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,23 +43,12 @@ func (s *Server) LogError(err HWSError) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) LogFatal(err error) {
|
|
||||||
if err == nil {
|
|
||||||
err = errors.New("LogFatal was called with a nil error")
|
|
||||||
}
|
|
||||||
if server.logger == nil {
|
|
||||||
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddLogger adds a logger to the server to use for request logging.
|
// AddLogger adds a logger to the server to use for request logging.
|
||||||
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
func (s *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||||
if hlogger == nil {
|
if hlogger == nil {
|
||||||
return errors.New("unable to add logger, no logger provided")
|
return errors.New("unable to add logger, no logger provided")
|
||||||
}
|
}
|
||||||
server.logger = &logger{
|
s.logger = &logger{
|
||||||
logger: hlogger,
|
logger: hlogger,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -68,7 +57,7 @@ func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
|||||||
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||||
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||||
// Useful for ignoring requests to CSS files or favicons
|
// Useful for ignoring requests to CSS files or favicons
|
||||||
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
func (s *Server) LoggerIgnorePaths(paths ...string) error {
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
u, err := url.Parse(path)
|
u, err := url.Parse(path)
|
||||||
valid := err == nil &&
|
valid := err == nil &&
|
||||||
@@ -80,7 +69,7 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
|||||||
return fmt.Errorf("invalid path: '%s'", path)
|
return fmt.Errorf("invalid path: '%s'", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
server.logger.ignoredPaths = prepareGlobs(paths)
|
s.logger.ignoredPaths = prepareGlobs(paths)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("http://example.com/path")
|
err := server.LoggerIgnorePaths("http://example.com/path")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Invalid path with host", func(t *testing.T) {
|
t.Run("Invalid path with host", func(t *testing.T) {
|
||||||
@@ -207,7 +207,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
err := server.LoggerIgnorePaths("//example.com/path")
|
err := server.LoggerIgnorePaths("//example.com/path")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -217,7 +217,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("/path?query=value")
|
err := server.LoggerIgnorePaths("/path?query=value")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Invalid path with fragment", func(t *testing.T) {
|
t.Run("Invalid path with fragment", func(t *testing.T) {
|
||||||
@@ -226,7 +226,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("/path#fragment")
|
err := server.LoggerIgnorePaths("/path#fragment")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Valid paths", func(t *testing.T) {
|
t.Run("Valid paths", func(t *testing.T) {
|
||||||
|
|||||||
@@ -5,35 +5,37 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Middleware func(h http.Handler) http.Handler
|
type (
|
||||||
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
Middleware func(h http.Handler) http.Handler
|
||||||
|
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||||
|
)
|
||||||
|
|
||||||
// Server.AddMiddleware registers all the middleware.
|
// AddMiddleware registers all the middleware.
|
||||||
// Middleware will be run in the order that they are provided.
|
// Middleware will be run in the order that they are provided.
|
||||||
// Can only be called once
|
// Can only be called once
|
||||||
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
func (s *Server) AddMiddleware(middleware ...Middleware) error {
|
||||||
if !server.routes {
|
if !s.routes {
|
||||||
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||||
}
|
}
|
||||||
if server.middleware {
|
if s.middleware {
|
||||||
return errors.New("Server.AddMiddleware already called")
|
return errors.New("Server.AddMiddleware already called")
|
||||||
}
|
}
|
||||||
// RUN LOGGING MIDDLEWARE FIRST
|
// RUN LOGGING MIDDLEWARE FIRST
|
||||||
server.server.Handler = logging(server.server.Handler, server.logger)
|
s.server.Handler = logging(s.server.Handler, s.logger)
|
||||||
|
|
||||||
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||||
for i := len(middleware); i > 0; i-- {
|
for i := len(middleware); i > 0; i-- {
|
||||||
server.server.Handler = middleware[i-1](server.server.Handler)
|
s.server.Handler = middleware[i-1](s.server.Handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RUN GZIP
|
// RUN GZIP
|
||||||
if server.GZIP {
|
if s.GZIP {
|
||||||
server.server.Handler = addgzip(server.server.Handler)
|
s.server.Handler = addgzip(s.server.Handler)
|
||||||
}
|
}
|
||||||
// RUN TIMER MIDDLEWARE LAST
|
// RUN TIMER MIDDLEWARE LAST
|
||||||
server.server.Handler = startTimer(server.server.Handler)
|
s.server.Handler = startTimer(s.server.Handler)
|
||||||
|
|
||||||
server.middleware = true
|
s.middleware = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
|||||||
// and returns a new request and optional HWSError.
|
// and returns a new request and optional HWSError.
|
||||||
// If a HWSError is returned, server.ThrowError will be called.
|
// If a HWSError is returned, server.ThrowError will be called.
|
||||||
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
||||||
func (server *Server) NewMiddleware(
|
func (s *Server) NewMiddleware(
|
||||||
middlewareFunc MiddlewareFunc,
|
middlewareFunc MiddlewareFunc,
|
||||||
) Middleware {
|
) Middleware {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
newReq, herr := middlewareFunc(w, r)
|
newReq, herr := middlewareFunc(w, r)
|
||||||
if herr != nil {
|
if herr != nil {
|
||||||
server.ThrowError(w, r, *herr)
|
s.ThrowError(w, r, *herr)
|
||||||
if herr.RenderErrorPage {
|
if herr.RenderErrorPage {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
func (c contextKey) String() string {
|
||||||
|
return "hws context key " + string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestTimerCtxKey = contextKey("request-timer")
|
||||||
|
|
||||||
// Set the start time of the request
|
// Set the start time of the request
|
||||||
func setStart(ctx context.Context, time time.Time) context.Context {
|
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||||
return context.WithValue(ctx, "hws context key request-timer", time)
|
return context.WithValue(ctx, requestTimerCtxKey, time)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the start time of the request
|
// Get the start time of the request
|
||||||
func getStartTime(ctx context.Context) (time.Time, error) {
|
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||||
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
|
||||||
if !ok {
|
if !ok {
|
||||||
return time.Time{}, errors.New("Failed to get start time of request")
|
return time.Time{}, errors.New("failed to get start time of request")
|
||||||
}
|
}
|
||||||
return start, nil
|
return start, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
|
|||||||
}
|
}
|
||||||
_, exists := s.notifier.clients.getClient(nt.Target)
|
_, exists := s.notifier.clients.getClient(nt.Target)
|
||||||
if !exists {
|
if !exists {
|
||||||
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
|
|||||||
clients, exists := s.notifier.clients.clientsIDMap[altID]
|
clients, exists := s.notifier.clients.clientsIDMap[altID]
|
||||||
s.notifier.clients.lock.RUnlock()
|
s.notifier.clients.lock.RUnlock()
|
||||||
if !exists {
|
if !exists {
|
||||||
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID)
|
err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
|
||||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: 0,
|
Port: 0,
|
||||||
|
ShutdownDelay: 0, // No delay for tests
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := NewServer(cfg)
|
server, err := NewServer(cfg)
|
||||||
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < 3; i++ {
|
for range 3 {
|
||||||
<-ticker.C
|
<-ticker.C
|
||||||
server.NotifySub(notify.Notification{
|
server.NotifySub(notify.Notification{
|
||||||
Target: client.sub.ID,
|
Target: client.sub.ID,
|
||||||
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
|||||||
defer close(stop)
|
defer close(stop)
|
||||||
|
|
||||||
// Send 10 notifications quickly (buffer is 10)
|
// Send 10 notifications quickly (buffer is 10)
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
server.NotifySub(notify.Notification{
|
server.NotifySub(notify.Notification{
|
||||||
Target: client.sub.ID,
|
Target: client.sub.ID,
|
||||||
Message: "Burst message",
|
Message: "Burst message",
|
||||||
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Client should receive all 10
|
// Client should receive all 10
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
select {
|
select {
|
||||||
case <-notifications:
|
case <-notifications:
|
||||||
// Received
|
// Received
|
||||||
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
|||||||
defer close(stop)
|
defer close(stop)
|
||||||
|
|
||||||
// Fill buffer completely (buffer is 10)
|
// Fill buffer completely (buffer is 10)
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
server.NotifySub(notify.Notification{
|
server.NotifySub(notify.Notification{
|
||||||
Target: client.sub.ID,
|
Target: client.sub.ID,
|
||||||
Message: "Fill buffer",
|
Message: "Fill buffer",
|
||||||
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
|||||||
Message: "Timeout message",
|
Message: "Timeout message",
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wait for timeout
|
// Wait for timeout (5s timeout + small buffer)
|
||||||
time.Sleep(6 * time.Second)
|
time.Sleep(5100 * time.Millisecond)
|
||||||
|
|
||||||
// Check failure count (should be 1)
|
// Check failure count (should be 1)
|
||||||
fails := atomic.LoadInt32(&client.consecutiveFails)
|
fails := atomic.LoadInt32(&client.consecutiveFails)
|
||||||
require.Equal(t, int32(1), fails, "Should have 1 timeout")
|
require.Equal(t, int32(1), fails, "Should have 1 timeout")
|
||||||
|
|
||||||
// Now read all buffered messages
|
// Now read all buffered messages
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
<-notifications
|
<-notifications
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
|
|||||||
defer close(stop)
|
defer close(stop)
|
||||||
|
|
||||||
// Fill buffer and never read to cause 5 consecutive timeouts
|
// Fill buffer and never read to cause 5 consecutive timeouts
|
||||||
for i := 0; i < 20; i++ {
|
for range 20 {
|
||||||
server.NotifySub(notify.Notification{
|
server.NotifySub(notify.Notification{
|
||||||
Target: client.sub.ID,
|
Target: client.sub.ID,
|
||||||
Message: "Timeout message",
|
Message: "Timeout message",
|
||||||
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
clients := make([]*Client, 100)
|
clients := make([]*Client, 100)
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := range 100 {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(index int) {
|
go func(index int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
|||||||
messageCount := 50
|
messageCount := 50
|
||||||
|
|
||||||
// Send from multiple goroutines
|
// Send from multiple goroutines
|
||||||
for i := 0; i < messageCount; i++ {
|
for i := range messageCount {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(index int) {
|
go func(index int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
|||||||
// This is expected behavior - we're testing thread safety, not guaranteed delivery
|
// This is expected behavior - we're testing thread safety, not guaranteed delivery
|
||||||
// Just verify we receive at least some messages without panicking or deadlocking
|
// Just verify we receive at least some messages without panicking or deadlocking
|
||||||
received := 0
|
received := 0
|
||||||
timeout := time.After(2 * time.Second)
|
timeout := time.After(500 * time.Millisecond)
|
||||||
for received < messageCount {
|
for received < messageCount {
|
||||||
select {
|
select {
|
||||||
case <-notifications:
|
case <-notifications:
|
||||||
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
|
|||||||
server := newTestServerWithNotifier(t)
|
server := newTestServerWithNotifier(t)
|
||||||
|
|
||||||
// Create some clients
|
// Create some clients
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
client, _ := server.GetClient("", "")
|
client, _ := server.GetClient("", "")
|
||||||
// Set some to be old
|
// Set some to be old
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
@@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
// Create a few clients and read from them
|
// Create a few clients and read from them
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
client, _ := server.GetClient("", "")
|
client, _ := server.GetClient("", "")
|
||||||
notifications, stop := client.Listen()
|
notifications, stop := client.Listen()
|
||||||
defer close(stop)
|
defer close(stop)
|
||||||
|
|
||||||
// Actively read messages
|
// Actively read messages
|
||||||
timeout := time.After(2 * time.Second)
|
timeout := time.After(200 * time.Millisecond)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-notifications:
|
case <-notifications:
|
||||||
// Keep reading
|
// Keep reading
|
||||||
case <-timeout:
|
case <-timeout:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a few notifications
|
// Send a few notifications
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
for range 10 {
|
||||||
defer wg.Done()
|
|
||||||
for j := 0; j < 20; j++ {
|
|
||||||
server.NotifyAll(notify.Notification{
|
server.NotifyAll(notify.Notification{
|
||||||
Message: "Stress test",
|
Message: "Stress test",
|
||||||
})
|
})
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
|
|||||||
require.NotNil(t, stop)
|
require.NotNil(t, stop)
|
||||||
|
|
||||||
// notifications should be receive-only
|
// notifications should be receive-only
|
||||||
_, ok := interface{}(notifications).(<-chan notify.Notification)
|
_, ok := any(notifications).(<-chan notify.Notification)
|
||||||
require.True(t, ok, "notifications should be receive-only channel")
|
require.True(t, ok, "notifications should be receive-only channel")
|
||||||
|
|
||||||
// stop should be closeable
|
// stop should be closeable
|
||||||
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
|
|||||||
defer close(stop)
|
defer close(stop)
|
||||||
|
|
||||||
// Send 10 messages without reading (buffer size is 10)
|
// Send 10 messages without reading (buffer size is 10)
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
server.NotifySub(notify.Notification{
|
server.NotifySub(notify.Notification{
|
||||||
Target: client.sub.ID,
|
Target: client.sub.ID,
|
||||||
Message: "Buffered",
|
Message: "Buffered",
|
||||||
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
|
|||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
// Read all 10
|
// Read all 10
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
select {
|
select {
|
||||||
case <-notifications:
|
case <-notifications:
|
||||||
// Success
|
// Success
|
||||||
|
|||||||
@@ -30,13 +30,13 @@ const (
|
|||||||
MethodPATCH Method = "PATCH"
|
MethodPATCH Method = "PATCH"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server.AddRoutes registers the page handlers for the server.
|
// AddRoutes registers the page handlers for the server.
|
||||||
// At least one route must be provided.
|
// At least one route must be provided.
|
||||||
// If any route patterns (path + method) are defined multiple times, the first
|
// If any route patterns (path + method) are defined multiple times, the first
|
||||||
// instance will be added and any additional conflicts will be discarded.
|
// instance will be added and any additional conflicts will be discarded.
|
||||||
func (server *Server) AddRoutes(routes ...Route) error {
|
func (s *Server) AddRoutes(routes ...Route) error {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
return errors.New("No routes provided")
|
return errors.New("no routes provided")
|
||||||
}
|
}
|
||||||
patterns := []string{}
|
patterns := []string{}
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
|||||||
}
|
}
|
||||||
for _, method := range route.Methods {
|
for _, method := range route.Methods {
|
||||||
if !validMethod(method) {
|
if !validMethod(method) {
|
||||||
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
|
return fmt.Errorf("invalid method %s for path %s", method, route.Path)
|
||||||
}
|
}
|
||||||
if route.Handler == nil {
|
if route.Handler == nil {
|
||||||
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
|
return fmt.Errorf("no handler provided for %s %s", method, route.Path)
|
||||||
}
|
}
|
||||||
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
||||||
if slices.Contains(patterns, pattern) {
|
if slices.Contains(patterns, pattern) {
|
||||||
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server.server.Handler = mux
|
s.server.Handler = mux
|
||||||
server.routes = true
|
s.routes = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
server := createTestServer(t, &buf)
|
server := createTestServer(t, &buf)
|
||||||
err := server.AddRoutes()
|
err := server.AddRoutes()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "No routes provided")
|
assert.Contains(t, err.Error(), "no routes provided")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Single valid route", func(t *testing.T) {
|
t.Run("Single valid route", func(t *testing.T) {
|
||||||
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid method")
|
assert.Contains(t, err.Error(), "invalid method")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("No handler provided", func(t *testing.T) {
|
t.Run("No handler provided", func(t *testing.T) {
|
||||||
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
Handler: nil,
|
Handler: nil,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "No handler provided")
|
assert.Contains(t, err.Error(), "no handler provided")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
||||||
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid method")
|
assert.Contains(t, err.Error(), "invalid method")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
||||||
|
|||||||
@@ -26,14 +26,14 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ready returns a channel that is closed when the server is started
|
// Ready returns a channel that is closed when the server is started
|
||||||
func (server *Server) Ready() <-chan struct{} {
|
func (s *Server) Ready() <-chan struct{} {
|
||||||
return server.ready
|
return s.ready
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsReady checks if the server is running
|
// IsReady checks if the server is running
|
||||||
func (server *Server) IsReady() bool {
|
func (s *Server) IsReady() bool {
|
||||||
select {
|
select {
|
||||||
case <-server.ready:
|
case <-s.ready:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the server's network address
|
// Addr returns the server's network address
|
||||||
func (server *Server) Addr() string {
|
func (s *Server) Addr() string {
|
||||||
return server.server.Addr
|
return s.server.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns the server's HTTP handler for testing purposes
|
// Handler returns the server's HTTP handler for testing purposes
|
||||||
func (server *Server) Handler() http.Handler {
|
func (s *Server) Handler() http.Handler {
|
||||||
return server.server.Handler
|
return s.server.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a new hws.Server with the specified configuration.
|
// NewServer returns a new hws.Server with the specified configuration.
|
||||||
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
|
|||||||
|
|
||||||
valid := isValidHostname(config.Host)
|
valid := isValidHostname(config.Host)
|
||||||
if !valid {
|
if !valid {
|
||||||
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
|
return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) Start(ctx context.Context) error {
|
func (s *Server) Start(ctx context.Context) error {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return errors.New("Context cannot be nil")
|
return errors.New("Context cannot be nil")
|
||||||
}
|
}
|
||||||
if !server.routes {
|
if !s.routes {
|
||||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||||
}
|
}
|
||||||
if !server.middleware {
|
if !s.middleware {
|
||||||
err := server.AddMiddleware()
|
err := s.AddMiddleware()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "server.AddMiddleware")
|
return errors.Wrap(err, "server.AddMiddleware")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server.startNotifier()
|
s.startNotifier()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if server.logger == nil {
|
if s.logger == nil {
|
||||||
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
fmt.Printf("Listening for requests on %s", s.server.Addr)
|
||||||
} else {
|
} else {
|
||||||
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
|
||||||
}
|
}
|
||||||
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
if server.logger == nil {
|
if s.logger == nil {
|
||||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||||
} else {
|
} else {
|
||||||
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server.waitUntilReady(ctx)
|
s.waitUntilReady(ctx)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) Shutdown(ctx context.Context) error {
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down")
|
if s.logger != nil {
|
||||||
server.NotifyAll(notify.Notification{
|
s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
|
||||||
|
}
|
||||||
|
s.NotifyAll(notify.Notification{
|
||||||
Title: "Shutting down",
|
Title: "Shutting down",
|
||||||
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay),
|
Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
|
||||||
Level: LevelShutdown,
|
Level: LevelShutdown,
|
||||||
})
|
})
|
||||||
<-time.NewTimer(server.shutdowndelay).C
|
<-time.NewTimer(s.shutdowndelay).C
|
||||||
if !server.IsReady() {
|
if !s.IsReady() {
|
||||||
return errors.New("Server isn't running")
|
return errors.New("Server isn't running")
|
||||||
}
|
}
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return errors.New("Context cannot be nil")
|
return errors.New("Context cannot be nil")
|
||||||
}
|
}
|
||||||
err := server.server.Shutdown(ctx)
|
err := s.server.Shutdown(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
||||||
}
|
}
|
||||||
server.closeNotifier()
|
s.closeNotifier()
|
||||||
server.ready = make(chan struct{})
|
s.ready = make(chan struct{})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) waitUntilReady(ctx context.Context) error {
|
func (s *Server) waitUntilReady(ctx context.Context) error {
|
||||||
ticker := time.NewTicker(50 * time.Millisecond)
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
resp, err := http.Get("http://" + s.server.Addr + "/healthz")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue // not accepting yet
|
continue // not accepting yet
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
if resp.StatusCode == http.StatusOK {
|
||||||
closeOnce.Do(func() { close(server.ready) })
|
closeOnce.Do(func() { close(s.ready) })
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,8 +26,9 @@ func randomPort() uint64 {
|
|||||||
|
|
||||||
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
|
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
|
||||||
server, err := hws.NewServer(&hws.Config{
|
server, err := hws.NewServer(&hws.Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: randomPort(),
|
Port: randomPort(),
|
||||||
|
ShutdownDelay: 0, // No delay for tests
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user