Compare commits
2 Commits
hwsauth/v0
...
hwsauth/v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 525b3b1396 | |||
| 563908bbb4 |
@@ -13,12 +13,12 @@ import (
|
||||
func Test_ConfigFromEnv(t *testing.T) {
|
||||
t.Run("Default values when no env vars set", func(t *testing.T) {
|
||||
// Clear any existing env vars
|
||||
os.Unsetenv("HWS_HOST")
|
||||
os.Unsetenv("HWS_PORT")
|
||||
os.Unsetenv("HWS_GZIP")
|
||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom host", func(t *testing.T) {
|
||||
os.Setenv("HWS_HOST", "192.168.1.1")
|
||||
defer os.Unsetenv("HWS_HOST")
|
||||
_ = os.Setenv("HWS_HOST", "192.168.1.1")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom port", func(t *testing.T) {
|
||||
os.Setenv("HWS_PORT", "8080")
|
||||
defer os.Unsetenv("HWS_PORT")
|
||||
_ = os.Setenv("HWS_PORT", "8080")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("GZIP enabled", func(t *testing.T) {
|
||||
os.Setenv("HWS_GZIP", "true")
|
||||
defer os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Setenv("HWS_GZIP", "true")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom timeouts", func(t *testing.T) {
|
||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||
os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||
os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||
_ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||
_ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("All custom values", func(t *testing.T) {
|
||||
os.Setenv("HWS_HOST", "0.0.0.0")
|
||||
os.Setenv("HWS_PORT", "9000")
|
||||
os.Setenv("HWS_GZIP", "true")
|
||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||
os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||
os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||
_ = os.Setenv("HWS_HOST", "0.0.0.0")
|
||||
_ = os.Setenv("HWS_PORT", "9000")
|
||||
_ = os.Setenv("HWS_GZIP", "true")
|
||||
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||
_ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||
_ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||
defer func() {
|
||||
os.Unsetenv("HWS_HOST")
|
||||
os.Unsetenv("HWS_PORT")
|
||||
os.Unsetenv("HWS_GZIP")
|
||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Error to use with Server.ThrowError
|
||||
// HWSError wraps an error with other information for use with HWS features
|
||||
type HWSError struct {
|
||||
StatusCode int // HTTP Status code
|
||||
Message string // Error message
|
||||
@@ -41,7 +41,7 @@ type ErrorPage interface {
|
||||
}
|
||||
|
||||
// AddErrorPage registers a handler that returns an ErrorPage
|
||||
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
||||
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
return errors.New("Render method of the error page did not write anything to the response writer")
|
||||
}
|
||||
|
||||
server.errorPage = pageFunc
|
||||
s.errorPage = pageFunc
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
// the error with the level specified by the HWSError.
|
||||
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
||||
// and the request chain should be terminated.
|
||||
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||
func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
|
||||
err := s.throwError(w, r, error)
|
||||
if err != nil {
|
||||
s.LogError(error)
|
||||
s.LogError(HWSError{
|
||||
Message: "Error occured during throwError",
|
||||
Error: errors.Wrap(err, "s.throwError"),
|
||||
Level: ErrorERROR,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||
if error.StatusCode <= 0 {
|
||||
return errors.New("HWSError.StatusCode cannot be 0.")
|
||||
}
|
||||
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
|
||||
if r == nil {
|
||||
return errors.New("Request cannot be nil")
|
||||
}
|
||||
if !server.IsReady() {
|
||||
if !s.IsReady() {
|
||||
return errors.New("ThrowError called before server started")
|
||||
}
|
||||
w.WriteHeader(error.StatusCode)
|
||||
server.LogError(error)
|
||||
if server.errorPage == nil {
|
||||
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||
s.LogError(error)
|
||||
if s.errorPage == nil {
|
||||
s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||
return nil
|
||||
}
|
||||
if error.RenderErrorPage {
|
||||
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||
errPage, err := server.errorPage(error)
|
||||
s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||
errPage, err := s.errorPage(error)
|
||||
if err != nil {
|
||||
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||
s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||
}
|
||||
err = errPage.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||
s.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||
}
|
||||
} else {
|
||||
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||
s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
server.LogFatal(err)
|
||||
}
|
||||
|
||||
@@ -14,22 +14,26 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type goodPage struct{}
|
||||
type badPage struct{}
|
||||
type (
|
||||
goodPage struct{}
|
||||
badPage struct{}
|
||||
)
|
||||
|
||||
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return goodPage{}, nil
|
||||
}
|
||||
|
||||
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return badPage{}, nil
|
||||
}
|
||||
|
||||
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return nil, errors.New("I'm an error")
|
||||
}
|
||||
|
||||
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
||||
w.Write([]byte("Test write to ResponseWriter"))
|
||||
return nil
|
||||
_, err := w.Write([]byte("Test write to ResponseWriter"))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
||||
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
t.Run("Server not started", func(t *testing.T) {
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
buf.Reset()
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Error",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
// ThrowError logs errors internally when validation fails
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "ThrowError called before server started")
|
||||
})
|
||||
|
||||
startTestServer(t, server)
|
||||
defer server.Shutdown(t.Context())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
error hws.HWSError
|
||||
valid bool
|
||||
name string
|
||||
request *http.Request
|
||||
error hws.HWSError
|
||||
expectLogItem string
|
||||
}{
|
||||
{
|
||||
name: "No HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{},
|
||||
valid: false,
|
||||
name: "No HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{},
|
||||
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||
},
|
||||
{
|
||||
name: "Negative HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: -1},
|
||||
valid: false,
|
||||
name: "Negative HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: -1},
|
||||
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||
},
|
||||
{
|
||||
name: "No HWSError.Message",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||
valid: false,
|
||||
name: "No HWSError.Message",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||
expectLogItem: "HWSError.Message cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "No HWSError.Error",
|
||||
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
},
|
||||
valid: false,
|
||||
expectLogItem: "HWSError.Error cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "No request provided",
|
||||
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
},
|
||||
valid: false,
|
||||
expectLogItem: "Request cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "Valid",
|
||||
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
},
|
||||
valid: true,
|
||||
expectLogItem: "An error occured",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf.Reset()
|
||||
rr := httptest.NewRecorder()
|
||||
err := server.ThrowError(rr, tt.request, tt.error)
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
t.Log(err)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
server.ThrowError(rr, tt.request, tt.error)
|
||||
// ThrowError no longer returns errors; check logs instead
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, tt.expectLogItem)
|
||||
})
|
||||
}
|
||||
t.Run("Log level set correctly", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
Level: hws.ErrorWARN,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = buf.ReadString([]byte(" ")[0])
|
||||
_, err := buf.ReadString([]byte(" ")[0])
|
||||
require.NoError(t, err)
|
||||
loglvl, err := buf.ReadString([]byte(" ")[0])
|
||||
assert.NoError(t, err)
|
||||
if loglvl != "\x1b[33mWRN\x1b[0m " {
|
||||
err = errors.New("Log level not set correctly")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
|
||||
|
||||
buf.Reset()
|
||||
err = server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = buf.ReadString([]byte(" ")[0])
|
||||
require.NoError(t, err)
|
||||
loglvl, err = buf.ReadString([]byte(" ")[0])
|
||||
assert.NoError(t, err)
|
||||
if loglvl != "\x1b[31mERR\x1b[0m " {
|
||||
err = errors.New("Log level not set correctly")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
|
||||
})
|
||||
|
||||
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
||||
// Must be run before adding the error page to the test server
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body != "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.Empty(t, body, "Error page should not render when no error page is set")
|
||||
})
|
||||
t.Run("Error page renders", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// Adding the error page will carry over to all future tests and cant be undone
|
||||
server.AddErrorPage(goodRender)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
err := server.AddErrorPage(goodRender)
|
||||
require.NoError(t, err)
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body == "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
|
||||
})
|
||||
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
|
||||
t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
|
||||
// Error page already added to server
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body != "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
|
||||
})
|
||||
server.Shutdown(t.Context())
|
||||
err := server.Shutdown(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
|
||||
t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
|
||||
server, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
|
||||
err = server.Start(t.Context())
|
||||
require.NoError(t, err)
|
||||
<-server.Ready()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err = server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should not panic when no logger is present
|
||||
assert.NotPanics(t, func() {
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
}, "ThrowError should not panic when no logger is present")
|
||||
err = server.Shutdown(t.Context())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||
return func() (any, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,23 +43,12 @@ func (s *Server) LogError(err HWSError) {
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) LogFatal(err error) {
|
||||
if err == nil {
|
||||
err = errors.New("LogFatal was called with a nil error")
|
||||
}
|
||||
if server.logger == nil {
|
||||
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
|
||||
return
|
||||
}
|
||||
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
|
||||
}
|
||||
|
||||
// AddLogger adds a logger to the server to use for request logging.
|
||||
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||
func (s *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||
if hlogger == nil {
|
||||
return errors.New("unable to add logger, no logger provided")
|
||||
}
|
||||
server.logger = &logger{
|
||||
s.logger = &logger{
|
||||
logger: hlogger,
|
||||
}
|
||||
return nil
|
||||
@@ -68,7 +57,7 @@ func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||
// Useful for ignoring requests to CSS files or favicons
|
||||
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
func (s *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
for _, path := range paths {
|
||||
u, err := url.Parse(path)
|
||||
valid := err == nil &&
|
||||
@@ -80,7 +69,7 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
return fmt.Errorf("invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
server.logger.ignoredPaths = prepareGlobs(paths)
|
||||
s.logger.ignoredPaths = prepareGlobs(paths)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ func Test_LogFatal(t *testing.T) {
|
||||
// Note: We cannot actually test Fatal() as it calls os.Exit()
|
||||
// Testing this would require subprocess testing which is overly complex
|
||||
// These tests document the expected behavior and verify the function signatures exist
|
||||
|
||||
|
||||
t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) {
|
||||
_, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
@@ -197,7 +197,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("http://example.com/path")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Invalid path with host", func(t *testing.T) {
|
||||
@@ -207,7 +207,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
err := server.LoggerIgnorePaths("//example.com/path")
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -217,7 +217,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("/path?query=value")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Invalid path with fragment", func(t *testing.T) {
|
||||
@@ -226,7 +226,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("/path#fragment")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Valid paths", func(t *testing.T) {
|
||||
|
||||
@@ -5,35 +5,37 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Middleware func(h http.Handler) http.Handler
|
||||
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||
type (
|
||||
Middleware func(h http.Handler) http.Handler
|
||||
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||
)
|
||||
|
||||
// Server.AddMiddleware registers all the middleware.
|
||||
// AddMiddleware registers all the middleware.
|
||||
// Middleware will be run in the order that they are provided.
|
||||
// Can only be called once
|
||||
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
if !server.routes {
|
||||
func (s *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
if !s.routes {
|
||||
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||
}
|
||||
if server.middleware {
|
||||
if s.middleware {
|
||||
return errors.New("Server.AddMiddleware already called")
|
||||
}
|
||||
// RUN LOGGING MIDDLEWARE FIRST
|
||||
server.server.Handler = logging(server.server.Handler, server.logger)
|
||||
s.server.Handler = logging(s.server.Handler, s.logger)
|
||||
|
||||
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||
for i := len(middleware); i > 0; i-- {
|
||||
server.server.Handler = middleware[i-1](server.server.Handler)
|
||||
s.server.Handler = middleware[i-1](s.server.Handler)
|
||||
}
|
||||
|
||||
// RUN GZIP
|
||||
if server.GZIP {
|
||||
server.server.Handler = addgzip(server.server.Handler)
|
||||
if s.GZIP {
|
||||
s.server.Handler = addgzip(s.server.Handler)
|
||||
}
|
||||
// RUN TIMER MIDDLEWARE LAST
|
||||
server.server.Handler = startTimer(server.server.Handler)
|
||||
s.server.Handler = startTimer(s.server.Handler)
|
||||
|
||||
server.middleware = true
|
||||
s.middleware = true
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
// and returns a new request and optional HWSError.
|
||||
// If a HWSError is returned, server.ThrowError will be called.
|
||||
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
||||
func (server *Server) NewMiddleware(
|
||||
func (s *Server) NewMiddleware(
|
||||
middlewareFunc MiddlewareFunc,
|
||||
) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
newReq, herr := middlewareFunc(w, r)
|
||||
if herr != nil {
|
||||
server.ThrowError(w, r, *herr)
|
||||
s.ThrowError(w, r, *herr)
|
||||
if herr.RenderErrorPage {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
|
||||
)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "hws context key " + string(c)
|
||||
}
|
||||
|
||||
var requestTimerCtxKey = contextKey("request-timer")
|
||||
|
||||
// Set the start time of the request
|
||||
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||
return context.WithValue(ctx, "hws context key request-timer", time)
|
||||
return context.WithValue(ctx, requestTimerCtxKey, time)
|
||||
}
|
||||
|
||||
// Get the start time of the request
|
||||
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
||||
start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
|
||||
if !ok {
|
||||
return time.Time{}, errors.New("Failed to get start time of request")
|
||||
return time.Time{}, errors.New("failed to get start time of request")
|
||||
}
|
||||
return start, nil
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
|
||||
}
|
||||
_, exists := s.notifier.clients.getClient(nt.Target)
|
||||
if !exists {
|
||||
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||
err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||
return
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
|
||||
clients, exists := s.notifier.clients.clientsIDMap[altID]
|
||||
s.notifier.clients.lock.RUnlock()
|
||||
if !exists {
|
||||
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID)
|
||||
err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
|
||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
|
||||
cfg := &Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: 0,
|
||||
Host: "127.0.0.1",
|
||||
Port: 0,
|
||||
ShutdownDelay: 0, // No delay for tests
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg)
|
||||
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
<-ticker.C
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Send 10 notifications quickly (buffer is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Burst message",
|
||||
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
||||
}
|
||||
|
||||
// Client should receive all 10
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Received
|
||||
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Fill buffer completely (buffer is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Fill buffer",
|
||||
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
||||
Message: "Timeout message",
|
||||
})
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(6 * time.Second)
|
||||
// Wait for timeout (5s timeout + small buffer)
|
||||
time.Sleep(5100 * time.Millisecond)
|
||||
|
||||
// Check failure count (should be 1)
|
||||
fails := atomic.LoadInt32(&client.consecutiveFails)
|
||||
require.Equal(t, int32(1), fails, "Should have 1 timeout")
|
||||
|
||||
// Now read all buffered messages
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
<-notifications
|
||||
}
|
||||
|
||||
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Fill buffer and never read to cause 5 consecutive timeouts
|
||||
for i := 0; i < 20; i++ {
|
||||
for range 20 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Timeout message",
|
||||
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
clients := make([]*Client, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
||||
messageCount := 50
|
||||
|
||||
// Send from multiple goroutines
|
||||
for i := 0; i < messageCount; i++ {
|
||||
for i := range messageCount {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
||||
// This is expected behavior - we're testing thread safety, not guaranteed delivery
|
||||
// Just verify we receive at least some messages without panicking or deadlocking
|
||||
received := 0
|
||||
timeout := time.After(2 * time.Second)
|
||||
timeout := time.After(500 * time.Millisecond)
|
||||
for received < messageCount {
|
||||
select {
|
||||
case <-notifications:
|
||||
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
|
||||
server := newTestServerWithNotifier(t)
|
||||
|
||||
// Create some clients
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
client, _ := server.GetClient("", "")
|
||||
// Set some to be old
|
||||
if i%2 == 0 {
|
||||
@@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Create a few clients and read from them
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 5 {
|
||||
wg.Go(func() {
|
||||
client, _ := server.GetClient("", "")
|
||||
notifications, stop := client.Listen()
|
||||
defer close(stop)
|
||||
|
||||
// Actively read messages
|
||||
timeout := time.After(2 * time.Second)
|
||||
timeout := time.After(200 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Keep reading
|
||||
// Keep reading
|
||||
case <-timeout:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Send a few notifications
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 20; j++ {
|
||||
wg.Go(func() {
|
||||
for range 10 {
|
||||
server.NotifyAll(notify.Notification{
|
||||
Message: "Stress test",
|
||||
})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
|
||||
require.NotNil(t, stop)
|
||||
|
||||
// notifications should be receive-only
|
||||
_, ok := interface{}(notifications).(<-chan notify.Notification)
|
||||
_, ok := any(notifications).(<-chan notify.Notification)
|
||||
require.True(t, ok, "notifications should be receive-only channel")
|
||||
|
||||
// stop should be closeable
|
||||
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Send 10 messages without reading (buffer size is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Buffered",
|
||||
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Read all 10
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Success
|
||||
|
||||
@@ -30,13 +30,13 @@ const (
|
||||
MethodPATCH Method = "PATCH"
|
||||
)
|
||||
|
||||
// Server.AddRoutes registers the page handlers for the server.
|
||||
// AddRoutes registers the page handlers for the server.
|
||||
// At least one route must be provided.
|
||||
// If any route patterns (path + method) are defined multiple times, the first
|
||||
// instance will be added and any additional conflicts will be discarded.
|
||||
func (server *Server) AddRoutes(routes ...Route) error {
|
||||
func (s *Server) AddRoutes(routes ...Route) error {
|
||||
if len(routes) == 0 {
|
||||
return errors.New("No routes provided")
|
||||
return errors.New("no routes provided")
|
||||
}
|
||||
patterns := []string{}
|
||||
mux := http.NewServeMux()
|
||||
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
||||
}
|
||||
for _, method := range route.Methods {
|
||||
if !validMethod(method) {
|
||||
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
|
||||
return fmt.Errorf("invalid method %s for path %s", method, route.Path)
|
||||
}
|
||||
if route.Handler == nil {
|
||||
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
|
||||
return fmt.Errorf("no handler provided for %s %s", method, route.Path)
|
||||
}
|
||||
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
||||
if slices.Contains(patterns, pattern) {
|
||||
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
||||
}
|
||||
}
|
||||
|
||||
server.server.Handler = mux
|
||||
server.routes = true
|
||||
s.server.Handler = mux
|
||||
s.routes = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
server := createTestServer(t, &buf)
|
||||
err := server.AddRoutes()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No routes provided")
|
||||
assert.Contains(t, err.Error(), "no routes provided")
|
||||
})
|
||||
|
||||
t.Run("Single valid route", func(t *testing.T) {
|
||||
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
Handler: handler,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid method")
|
||||
assert.Contains(t, err.Error(), "invalid method")
|
||||
})
|
||||
|
||||
t.Run("No handler provided", func(t *testing.T) {
|
||||
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
Handler: nil,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No handler provided")
|
||||
assert.Contains(t, err.Error(), "no handler provided")
|
||||
})
|
||||
|
||||
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
||||
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
|
||||
Handler: handler,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid method")
|
||||
assert.Contains(t, err.Error(), "invalid method")
|
||||
})
|
||||
|
||||
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
||||
|
||||
@@ -26,14 +26,14 @@ type Server struct {
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the server is started
|
||||
func (server *Server) Ready() <-chan struct{} {
|
||||
return server.ready
|
||||
func (s *Server) Ready() <-chan struct{} {
|
||||
return s.ready
|
||||
}
|
||||
|
||||
// IsReady checks if the server is running
|
||||
func (server *Server) IsReady() bool {
|
||||
func (s *Server) IsReady() bool {
|
||||
select {
|
||||
case <-server.ready:
|
||||
case <-s.ready:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
func (server *Server) Addr() string {
|
||||
return server.server.Addr
|
||||
func (s *Server) Addr() string {
|
||||
return s.server.Addr
|
||||
}
|
||||
|
||||
// Handler returns the server's HTTP handler for testing purposes
|
||||
func (server *Server) Handler() http.Handler {
|
||||
return server.server.Handler
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return s.server.Handler
|
||||
}
|
||||
|
||||
// NewServer returns a new hws.Server with the specified configuration.
|
||||
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
|
||||
|
||||
valid := isValidHostname(config.Host)
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
|
||||
return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (server *Server) Start(ctx context.Context) error {
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
if !server.routes {
|
||||
if !s.routes {
|
||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||
}
|
||||
if !server.middleware {
|
||||
err := server.AddMiddleware()
|
||||
if !s.middleware {
|
||||
err := s.AddMiddleware()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "server.AddMiddleware")
|
||||
}
|
||||
}
|
||||
|
||||
server.startNotifier()
|
||||
s.startNotifier()
|
||||
|
||||
go func() {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
||||
if s.logger == nil {
|
||||
fmt.Printf("Listening for requests on %s", s.server.Addr)
|
||||
} else {
|
||||
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
||||
s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
|
||||
}
|
||||
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
if server.logger == nil {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
if s.logger == nil {
|
||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||
} else {
|
||||
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||
s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
server.waitUntilReady(ctx)
|
||||
s.waitUntilReady(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) Shutdown(ctx context.Context) error {
|
||||
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down")
|
||||
server.NotifyAll(notify.Notification{
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.logger != nil {
|
||||
s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
|
||||
}
|
||||
s.NotifyAll(notify.Notification{
|
||||
Title: "Shutting down",
|
||||
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay),
|
||||
Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
|
||||
Level: LevelShutdown,
|
||||
})
|
||||
<-time.NewTimer(server.shutdowndelay).C
|
||||
if !server.IsReady() {
|
||||
<-time.NewTimer(s.shutdowndelay).C
|
||||
if !s.IsReady() {
|
||||
return errors.New("Server isn't running")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
err := server.server.Shutdown(ctx)
|
||||
err := s.server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
||||
}
|
||||
server.closeNotifier()
|
||||
server.ready = make(chan struct{})
|
||||
s.closeNotifier()
|
||||
s.ready = make(chan struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (server *Server) waitUntilReady(ctx context.Context) error {
|
||||
func (s *Server) waitUntilReady(ctx context.Context) error {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
|
||||
return ctx.Err()
|
||||
|
||||
case <-ticker.C:
|
||||
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
||||
resp, err := http.Get("http://" + s.server.Addr + "/healthz")
|
||||
if err != nil {
|
||||
continue // not accepting yet
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
closeOnce.Do(func() { close(server.ready) })
|
||||
closeOnce.Do(func() { close(s.ready) })
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,9 @@ func randomPort() uint64 {
|
||||
|
||||
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
|
||||
server, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
ShutdownDelay: 0, // No delay for tests
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||
return func() (any, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,13 +6,15 @@ require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||
git.haelnorr.com/h/golib/hws v0.3.0
|
||||
git.haelnorr.com/h/golib/hws v0.5.0
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
|
||||
@@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
|
||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
|
||||
@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
|
||||
return tm.ID
|
||||
}
|
||||
|
||||
type TestTransaction struct {
|
||||
}
|
||||
type TestTransaction struct{}
|
||||
|
||||
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
|
||||
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
||||
// Clear environment variables
|
||||
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
||||
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer func() {
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
}()
|
||||
|
||||
_, err := ConfigFromEnv()
|
||||
assert.Error(t, err)
|
||||
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
// This test mainly checks that the function doesn't panic and has right call signature
|
||||
// The actual JWT functionality is tested in jwt package itself
|
||||
assert.NotPanics(t, func() {
|
||||
auth.Login(w, r, user, rememberMe)
|
||||
err := auth.Login(w, r, user, rememberMe)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
return fmt.Errorf("invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
auth.ignoredPaths = prepareGlobs(paths)
|
||||
|
||||
@@ -38,7 +38,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
Error: errors.Wrap(err, "auth.beginTx"),
|
||||
}
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
// Type assert to TX - safe because user's beginTx should return their TX type
|
||||
txTyped, ok := tx.(TX)
|
||||
if !ok {
|
||||
@@ -64,7 +66,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
Msg("Failed to authenticate user")
|
||||
return r, nil
|
||||
}
|
||||
tx.Commit()
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Failed to commit transaction",
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: errors.Wrap(err, "tx.Commit"),
|
||||
}
|
||||
}
|
||||
authContext := setAuthenticatedModel(r.Context(), model)
|
||||
newReq := r.WithContext(authContext)
|
||||
return newReq, nil
|
||||
|
||||
@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
|
||||
// }
|
||||
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "hwsauth context key" + string(c)
|
||||
}
|
||||
|
||||
var authenticatedModelContextKey = contextKey("authenticated-model")
|
||||
|
||||
// Return a new context with the user added in
|
||||
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
|
||||
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
||||
return context.WithValue(ctx, authenticatedModelContextKey, m)
|
||||
}
|
||||
|
||||
// Retrieve a user from the given context. Returns nil if not set
|
||||
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
|
||||
model = authenticatedModel[T]{}
|
||||
}
|
||||
}()
|
||||
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
|
||||
model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
|
||||
if !cok {
|
||||
return authenticatedModel[T]{}, false
|
||||
}
|
||||
|
||||
@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||
|
||||
Reference in New Issue
Block a user