Compare commits

...

4 Commits

Author SHA1 Message Date
05be28d7f3 fixed fatal bug after access token expires 2026-02-07 17:58:02 +11:00
8f7c87cef2 added extracheck to hwsauth 2026-02-07 16:42:08 +11:00
525b3b1396 updated to use new hws version 2026-02-03 19:11:59 +11:00
563908bbb4 updated hws.ThrowError to not return an error and log it to console instead
fixed errors_test

fixed tests
2026-02-03 18:43:31 +11:00
24 changed files with 356 additions and 289 deletions

View File

@@ -13,12 +13,12 @@ import (
func Test_ConfigFromEnv(t *testing.T) {
t.Run("Default values when no env vars set", func(t *testing.T) {
// Clear any existing env vars
os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT")
_ = os.Unsetenv("HWS_HOST")
_ = os.Unsetenv("HWS_PORT")
_ = os.Unsetenv("HWS_GZIP")
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
})
t.Run("Custom host", func(t *testing.T) {
os.Setenv("HWS_HOST", "192.168.1.1")
defer os.Unsetenv("HWS_HOST")
_ = os.Setenv("HWS_HOST", "192.168.1.1")
defer func() {
_ = os.Unsetenv("HWS_HOST")
}()
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
})
t.Run("Custom port", func(t *testing.T) {
os.Setenv("HWS_PORT", "8080")
defer os.Unsetenv("HWS_PORT")
_ = os.Setenv("HWS_PORT", "8080")
defer func() {
_ = os.Unsetenv("HWS_PORT")
}()
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
})
t.Run("GZIP enabled", func(t *testing.T) {
os.Setenv("HWS_GZIP", "true")
defer os.Unsetenv("HWS_GZIP")
_ = os.Setenv("HWS_GZIP", "true")
defer func() {
_ = os.Unsetenv("HWS_GZIP")
}()
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
})
t.Run("Custom timeouts", func(t *testing.T) {
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
os.Setenv("HWS_WRITE_TIMEOUT", "30")
os.Setenv("HWS_IDLE_TIMEOUT", "300")
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
_ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
_ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
defer func() {
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}()
config, err := hws.ConfigFromEnv()
require.NoError(t, err)
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
})
t.Run("All custom values", func(t *testing.T) {
os.Setenv("HWS_HOST", "0.0.0.0")
os.Setenv("HWS_PORT", "9000")
os.Setenv("HWS_GZIP", "true")
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
os.Setenv("HWS_WRITE_TIMEOUT", "15")
os.Setenv("HWS_IDLE_TIMEOUT", "180")
_ = os.Setenv("HWS_HOST", "0.0.0.0")
_ = os.Setenv("HWS_PORT", "9000")
_ = os.Setenv("HWS_GZIP", "true")
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
_ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
_ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
defer func() {
os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT")
_ = os.Unsetenv("HWS_HOST")
_ = os.Unsetenv("HWS_PORT")
_ = os.Unsetenv("HWS_GZIP")
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}()
config, err := hws.ConfigFromEnv()

View File

@@ -9,7 +9,7 @@ import (
"github.com/pkg/errors"
)
// Error to use with Server.ThrowError
// HWSError wraps an error with other information for use with HWS features
type HWSError struct {
StatusCode int // HTTP Status code
Message string // Error message
@@ -41,7 +41,7 @@ type ErrorPage interface {
}
// AddErrorPage registers a handler that returns an ErrorPage
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
return errors.New("Render method of the error page did not write anything to the response writer")
}
server.errorPage = pageFunc
s.errorPage = pageFunc
return nil
}
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
// the error with the level specified by the HWSError.
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
// and the request chain should be terminated.
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
err := s.throwError(w, r, error)
if err != nil {
s.LogError(error)
s.LogError(HWSError{
Message: "Error occured during throwError",
Error: errors.Wrap(err, "s.throwError"),
Level: ErrorERROR,
})
}
}
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
if error.StatusCode <= 0 {
return errors.New("HWSError.StatusCode cannot be 0.")
}
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
if r == nil {
return errors.New("Request cannot be nil")
}
if !server.IsReady() {
if !s.IsReady() {
return errors.New("ThrowError called before server started")
}
w.WriteHeader(error.StatusCode)
server.LogError(error)
if server.errorPage == nil {
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
s.LogError(error)
if s.errorPage == nil {
s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
return nil
}
if error.RenderErrorPage {
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
errPage, err := server.errorPage(error)
s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
errPage, err := s.errorPage(error)
if err != nil {
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
}
err = errPage.Render(r.Context(), w)
if err != nil {
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
s.LogError(HWSError{Message: "Failed to render error page", Error: err})
}
} else {
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
}
return nil
}
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusInternalServerError)
server.LogFatal(err)
}

View File

@@ -14,22 +14,26 @@ import (
"github.com/stretchr/testify/require"
)
type goodPage struct{}
type badPage struct{}
type (
goodPage struct{}
badPage struct{}
)
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
return goodPage{}, nil
}
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
return badPage{}, nil
}
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
return nil, errors.New("I'm an error")
}
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
w.Write([]byte("Test write to ResponseWriter"))
return nil
_, err := w.Write([]byte("Test write to ResponseWriter"))
return err
}
func (b badPage) Render(ctx context.Context, w io.Writer) error {
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
t.Run("Server not started", func(t *testing.T) {
err := server.ThrowError(rr, req, hws.HWSError{
buf.Reset()
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "Error",
Error: errors.New("Error"),
})
assert.Error(t, err)
// ThrowError logs errors internally when validation fails
output := buf.String()
assert.Contains(t, output, "ThrowError called before server started")
})
startTestServer(t, server)
defer server.Shutdown(t.Context())
tests := []struct {
name string
request *http.Request
error hws.HWSError
valid bool
name string
request *http.Request
error hws.HWSError
expectLogItem string
}{
{
name: "No HWSError.Status code",
request: nil,
error: hws.HWSError{},
valid: false,
name: "No HWSError.Status code",
request: nil,
error: hws.HWSError{},
expectLogItem: "HWSError.StatusCode cannot be 0",
},
{
name: "Negative HWSError.Status code",
request: nil,
error: hws.HWSError{StatusCode: -1},
valid: false,
name: "Negative HWSError.Status code",
request: nil,
error: hws.HWSError{StatusCode: -1},
expectLogItem: "HWSError.StatusCode cannot be 0",
},
{
name: "No HWSError.Message",
request: nil,
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
valid: false,
name: "No HWSError.Message",
request: nil,
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
expectLogItem: "HWSError.Message cannot be empty",
},
{
name: "No HWSError.Error",
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
},
valid: false,
expectLogItem: "HWSError.Error cannot be nil",
},
{
name: "No request provided",
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured",
Error: errors.New("Error"),
},
valid: false,
expectLogItem: "Request cannot be nil",
},
{
name: "Valid",
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured",
Error: errors.New("Error"),
},
valid: true,
expectLogItem: "An error occured",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf.Reset()
rr := httptest.NewRecorder()
err := server.ThrowError(rr, tt.request, tt.error)
if tt.valid {
assert.NoError(t, err)
} else {
t.Log(err)
assert.Error(t, err)
}
server.ThrowError(rr, tt.request, tt.error)
// ThrowError no longer returns errors; check logs instead
output := buf.String()
assert.Contains(t, output, tt.expectLogItem)
})
}
t.Run("Log level set correctly", func(t *testing.T) {
buf.Reset()
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
Level: hws.ErrorWARN,
})
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0])
_, err := buf.ReadString([]byte(" ")[0])
require.NoError(t, err)
loglvl, err := buf.ReadString([]byte(" ")[0])
assert.NoError(t, err)
if loglvl != "\x1b[33mWRN\x1b[0m " {
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
buf.Reset()
err = server.ThrowError(rr, req, hws.HWSError{
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0])
require.NoError(t, err)
loglvl, err = buf.ReadString([]byte(" ")[0])
assert.NoError(t, err)
if loglvl != "\x1b[31mERR\x1b[0m " {
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
})
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
// Must be run before adding the error page to the test server
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
RenderErrorPage: true,
})
assert.NoError(t, err)
body := rr.Body.String()
if body != "" {
assert.Error(t, nil)
}
assert.Empty(t, body, "Error page should not render when no error page is set")
})
t.Run("Error page renders", func(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
// Adding the error page will carry over to all future tests and cant be undone
server.AddErrorPage(goodRender)
err := server.ThrowError(rr, req, hws.HWSError{
err := server.AddErrorPage(goodRender)
require.NoError(t, err)
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
RenderErrorPage: true,
})
assert.NoError(t, err)
body := rr.Body.String()
if body == "" {
assert.Error(t, nil)
}
assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
})
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
// Error page already added to server
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
body := rr.Body.String()
if body != "" {
assert.Error(t, nil)
}
assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
})
server.Shutdown(t.Context())
err := server.Shutdown(t.Context())
require.NoError(t, err)
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
err = server.Start(t.Context())
require.NoError(t, err)
<-server.Ready()
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
err = server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
assert.NoError(t, err)
// Should not panic when no logger is present
assert.NotPanics(t, func() {
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "An error occured",
Error: errors.New("Error"),
})
}, "ThrowError should not panic when no logger is present")
err = server.Shutdown(t.Context())
require.NoError(t, err)
})
}

View File

@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return ConfigFromEnv()
}
}

View File

@@ -43,23 +43,12 @@ func (s *Server) LogError(err HWSError) {
}
}
func (server *Server) LogFatal(err error) {
if err == nil {
err = errors.New("LogFatal was called with a nil error")
}
if server.logger == nil {
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
return
}
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
}
// AddLogger adds a logger to the server to use for request logging.
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
func (s *Server) AddLogger(hlogger *hlog.Logger) error {
if hlogger == nil {
return errors.New("unable to add logger, no logger provided")
}
server.logger = &logger{
s.logger = &logger{
logger: hlogger,
}
return nil
@@ -68,7 +57,7 @@ func (server *Server) AddLogger(hlogger *hlog.Logger) error {
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
// Useful for ignoring requests to CSS files or favicons
func (server *Server) LoggerIgnorePaths(paths ...string) error {
func (s *Server) LoggerIgnorePaths(paths ...string) error {
for _, path := range paths {
u, err := url.Parse(path)
valid := err == nil &&
@@ -80,7 +69,7 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
return fmt.Errorf("invalid path: '%s'", path)
}
}
server.logger.ignoredPaths = prepareGlobs(paths)
s.logger.ignoredPaths = prepareGlobs(paths)
return nil
}

View File

@@ -169,7 +169,7 @@ func Test_LogFatal(t *testing.T) {
// Note: We cannot actually test Fatal() as it calls os.Exit()
// Testing this would require subprocess testing which is overly complex
// These tests document the expected behavior and verify the function signatures exist
t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) {
_, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
@@ -197,7 +197,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("http://example.com/path")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
assert.Contains(t, err.Error(), "invalid path")
})
t.Run("Invalid path with host", func(t *testing.T) {
@@ -207,7 +207,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("//example.com/path")
assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), "Invalid path")
assert.Contains(t, err.Error(), "invalid path")
}
})
@@ -217,7 +217,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path?query=value")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
assert.Contains(t, err.Error(), "invalid path")
})
t.Run("Invalid path with fragment", func(t *testing.T) {
@@ -226,7 +226,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path#fragment")
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path")
assert.Contains(t, err.Error(), "invalid path")
})
t.Run("Valid paths", func(t *testing.T) {

View File

@@ -5,35 +5,37 @@ import (
"net/http"
)
type Middleware func(h http.Handler) http.Handler
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
type (
Middleware func(h http.Handler) http.Handler
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
)
// Server.AddMiddleware registers all the middleware.
// AddMiddleware registers all the middleware.
// Middleware will be run in the order that they are provided.
// Can only be called once
func (server *Server) AddMiddleware(middleware ...Middleware) error {
if !server.routes {
func (s *Server) AddMiddleware(middleware ...Middleware) error {
if !s.routes {
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
}
if server.middleware {
if s.middleware {
return errors.New("Server.AddMiddleware already called")
}
// RUN LOGGING MIDDLEWARE FIRST
server.server.Handler = logging(server.server.Handler, server.logger)
s.server.Handler = logging(s.server.Handler, s.logger)
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
for i := len(middleware); i > 0; i-- {
server.server.Handler = middleware[i-1](server.server.Handler)
s.server.Handler = middleware[i-1](s.server.Handler)
}
// RUN GZIP
if server.GZIP {
server.server.Handler = addgzip(server.server.Handler)
if s.GZIP {
s.server.Handler = addgzip(s.server.Handler)
}
// RUN TIMER MIDDLEWARE LAST
server.server.Handler = startTimer(server.server.Handler)
s.server.Handler = startTimer(s.server.Handler)
server.middleware = true
s.middleware = true
return nil
}
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
// and returns a new request and optional HWSError.
// If a HWSError is returned, server.ThrowError will be called.
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
func (server *Server) NewMiddleware(
func (s *Server) NewMiddleware(
middlewareFunc MiddlewareFunc,
) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newReq, herr := middlewareFunc(w, r)
if herr != nil {
server.ThrowError(w, r, *herr)
s.ThrowError(w, r, *herr)
if herr.RenderErrorPage {
return
}

View File

@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
)
}
type contextKey string
func (c contextKey) String() string {
return "hws context key " + string(c)
}
var requestTimerCtxKey = contextKey("request-timer")
// Set the start time of the request
func setStart(ctx context.Context, time time.Time) context.Context {
return context.WithValue(ctx, "hws context key request-timer", time)
return context.WithValue(ctx, requestTimerCtxKey, time)
}
// Get the start time of the request
func getStartTime(ctx context.Context) (time.Time, error) {
start, ok := ctx.Value("hws context key request-timer").(time.Time)
start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
if !ok {
return time.Time{}, errors.New("Failed to get start time of request")
return time.Time{}, errors.New("failed to get start time of request")
}
return start, nil
}

View File

@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
}
_, exists := s.notifier.clients.getClient(nt.Target)
if !exists {
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return
}
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
clients, exists := s.notifier.clients.clientsIDMap[altID]
s.notifier.clients.lock.RUnlock()
if !exists {
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID)
err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return
}

View File

@@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server {
t.Helper()
cfg := &Config{
Host: "127.0.0.1",
Port: 0,
Host: "127.0.0.1",
Port: 0,
ShutdownDelay: 0, // No delay for tests
}
server, err := NewServer(cfg)
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
done := make(chan bool)
go func() {
for i := 0; i < 3; i++ {
for range 3 {
<-ticker.C
server.NotifySub(notify.Notification{
Target: client.sub.ID,
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
defer close(stop)
// Send 10 notifications quickly (buffer is 10)
for i := 0; i < 10; i++ {
for range 10 {
server.NotifySub(notify.Notification{
Target: client.sub.ID,
Message: "Burst message",
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
}
// Client should receive all 10
for i := 0; i < 10; i++ {
for i := range 10 {
select {
case <-notifications:
// Received
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
defer close(stop)
// Fill buffer completely (buffer is 10)
for i := 0; i < 10; i++ {
for range 10 {
server.NotifySub(notify.Notification{
Target: client.sub.ID,
Message: "Fill buffer",
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
Message: "Timeout message",
})
// Wait for timeout
time.Sleep(6 * time.Second)
// Wait for timeout (5s timeout + small buffer)
time.Sleep(5100 * time.Millisecond)
// Check failure count (should be 1)
fails := atomic.LoadInt32(&client.consecutiveFails)
require.Equal(t, int32(1), fails, "Should have 1 timeout")
// Now read all buffered messages
for i := 0; i < 10; i++ {
for range 10 {
<-notifications
}
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
defer close(stop)
// Fill buffer and never read to cause 5 consecutive timeouts
for i := 0; i < 20; i++ {
for range 20 {
server.NotifySub(notify.Notification{
Target: client.sub.ID,
Message: "Timeout message",
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
var wg sync.WaitGroup
clients := make([]*Client, 100)
for i := 0; i < 100; i++ {
for i := range 100 {
wg.Add(1)
go func(index int) {
defer wg.Done()
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
messageCount := 50
// Send from multiple goroutines
for i := 0; i < messageCount; i++ {
for i := range messageCount {
wg.Add(1)
go func(index int) {
defer wg.Done()
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
// This is expected behavior - we're testing thread safety, not guaranteed delivery
// Just verify we receive at least some messages without panicking or deadlocking
received := 0
timeout := time.After(2 * time.Second)
timeout := time.After(500 * time.Millisecond)
for received < messageCount {
select {
case <-notifications:
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
server := newTestServerWithNotifier(t)
// Create some clients
for i := 0; i < 10; i++ {
for i := range 10 {
client, _ := server.GetClient("", "")
// Set some to be old
if i%2 == 0 {
@@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) {
var wg sync.WaitGroup
// Create a few clients and read from them
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for range 5 {
wg.Go(func() {
client, _ := server.GetClient("", "")
notifications, stop := client.Listen()
defer close(stop)
// Actively read messages
timeout := time.After(2 * time.Second)
timeout := time.After(200 * time.Millisecond)
for {
select {
case <-notifications:
// Keep reading
// Keep reading
case <-timeout:
return
}
}
}()
})
}
// Send a few notifications
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 20; j++ {
wg.Go(func() {
for range 10 {
server.NotifyAll(notify.Notification{
Message: "Stress test",
})
time.Sleep(50 * time.Millisecond)
time.Sleep(10 * time.Millisecond)
}
}()
})
wg.Wait()
}
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
require.NotNil(t, stop)
// notifications should be receive-only
_, ok := interface{}(notifications).(<-chan notify.Notification)
_, ok := any(notifications).(<-chan notify.Notification)
require.True(t, ok, "notifications should be receive-only channel")
// stop should be closeable
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
defer close(stop)
// Send 10 messages without reading (buffer size is 10)
for i := 0; i < 10; i++ {
for range 10 {
server.NotifySub(notify.Notification{
Target: client.sub.ID,
Message: "Buffered",
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
time.Sleep(100 * time.Millisecond)
// Read all 10
for i := 0; i < 10; i++ {
for i := range 10 {
select {
case <-notifications:
// Success

View File

@@ -30,13 +30,13 @@ const (
MethodPATCH Method = "PATCH"
)
// Server.AddRoutes registers the page handlers for the server.
// AddRoutes registers the page handlers for the server.
// At least one route must be provided.
// If any route patterns (path + method) are defined multiple times, the first
// instance will be added and any additional conflicts will be discarded.
func (server *Server) AddRoutes(routes ...Route) error {
func (s *Server) AddRoutes(routes ...Route) error {
if len(routes) == 0 {
return errors.New("No routes provided")
return errors.New("no routes provided")
}
patterns := []string{}
mux := http.NewServeMux()
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
}
for _, method := range route.Methods {
if !validMethod(method) {
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
return fmt.Errorf("invalid method %s for path %s", method, route.Path)
}
if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
return fmt.Errorf("no handler provided for %s %s", method, route.Path)
}
pattern := fmt.Sprintf("%s %s", method, route.Path)
if slices.Contains(patterns, pattern) {
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
}
}
server.server.Handler = mux
server.routes = true
s.server.Handler = mux
s.routes = true
return nil
}

View File

@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
server := createTestServer(t, &buf)
err := server.AddRoutes()
assert.Error(t, err)
assert.Contains(t, err.Error(), "No routes provided")
assert.Contains(t, err.Error(), "no routes provided")
})
t.Run("Single valid route", func(t *testing.T) {
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: handler,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method")
assert.Contains(t, err.Error(), "invalid method")
})
t.Run("No handler provided", func(t *testing.T) {
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: nil,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "No handler provided")
assert.Contains(t, err.Error(), "no handler provided")
})
t.Run("All HTTP methods are valid", func(t *testing.T) {
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
Handler: handler,
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method")
assert.Contains(t, err.Error(), "invalid method")
})
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {

View File

@@ -26,14 +26,14 @@ type Server struct {
}
// Ready returns a channel that is closed when the server is started
func (server *Server) Ready() <-chan struct{} {
return server.ready
func (s *Server) Ready() <-chan struct{} {
return s.ready
}
// IsReady checks if the server is running
func (server *Server) IsReady() bool {
func (s *Server) IsReady() bool {
select {
case <-server.ready:
case <-s.ready:
return true
default:
return false
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
}
// Addr returns the server's network address
func (server *Server) Addr() string {
return server.server.Addr
func (s *Server) Addr() string {
return s.server.Addr
}
// Handler returns the server's HTTP handler for testing purposes
func (server *Server) Handler() http.Handler {
return server.server.Handler
func (s *Server) Handler() http.Handler {
return s.server.Handler
}
// NewServer returns a new hws.Server with the specified configuration.
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
valid := isValidHostname(config.Host)
if !valid {
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
}
httpServer := &http.Server{
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
return server, nil
}
func (server *Server) Start(ctx context.Context) error {
func (s *Server) Start(ctx context.Context) error {
if ctx == nil {
return errors.New("Context cannot be nil")
}
if !server.routes {
if !s.routes {
return errors.New("Server.AddRoutes must be run before starting the server")
}
if !server.middleware {
err := server.AddMiddleware()
if !s.middleware {
err := s.AddMiddleware()
if err != nil {
return errors.Wrap(err, "server.AddMiddleware")
}
}
server.startNotifier()
s.startNotifier()
go func() {
if server.logger == nil {
fmt.Printf("Listening for requests on %s", server.server.Addr)
if s.logger == nil {
fmt.Printf("Listening for requests on %s", s.server.Addr)
} else {
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
}
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if server.logger == nil {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if s.logger == nil {
fmt.Printf("Server encountered a fatal error: %s", err.Error())
} else {
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
}
}
}()
server.waitUntilReady(ctx)
s.waitUntilReady(ctx)
return nil
}
func (server *Server) Shutdown(ctx context.Context) error {
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down")
server.NotifyAll(notify.Notification{
func (s *Server) Shutdown(ctx context.Context) error {
if s.logger != nil {
s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
}
s.NotifyAll(notify.Notification{
Title: "Shutting down",
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay),
Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
Level: LevelShutdown,
})
<-time.NewTimer(server.shutdowndelay).C
if !server.IsReady() {
<-time.NewTimer(s.shutdowndelay).C
if !s.IsReady() {
return errors.New("Server isn't running")
}
if ctx == nil {
return errors.New("Context cannot be nil")
}
err := server.server.Shutdown(ctx)
err := s.server.Shutdown(ctx)
if err != nil {
return errors.Wrap(err, "Failed to shutdown the server gracefully")
}
server.closeNotifier()
server.ready = make(chan struct{})
s.closeNotifier()
s.ready = make(chan struct{})
return nil
}
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
return false
}
func (server *Server) waitUntilReady(ctx context.Context) error {
func (s *Server) waitUntilReady(ctx context.Context) error {
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
return ctx.Err()
case <-ticker.C:
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
resp, err := http.Get("http://" + s.server.Addr + "/healthz")
if err != nil {
continue // not accepting yet
}
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
closeOnce.Do(func() { close(server.ready) })
closeOnce.Do(func() { close(s.ready) })
return nil
}
}

View File

@@ -26,8 +26,9 @@ func randomPort() uint64 {
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1",
Port: randomPort(),
Host: "127.0.0.1",
Port: randomPort(),
ShutdownDelay: 0, // No delay for tests
})
require.NoError(t, err)

View File

@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return ConfigFromEnv()
}
}

View File

@@ -6,13 +6,15 @@ require (
git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.3.0
git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/jwt v0.10.1
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1
)
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect

View File

@@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=

View File

@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
return tm.ID
}
type TestTransaction struct {
}
type TestTransaction struct{}
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
return nil, nil
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
// Clear environment variables
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
os.Setenv("HWSAUTH_SECRET_KEY", "")
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
_ = os.Setenv("HWSAUTH_SECRET_KEY", "")
defer func() {
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
}()
_, err := ConfigFromEnv()
assert.Error(t, err)
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator(
cfg,
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator(
cfg,
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
db, _, err := createMockDB()
require.NoError(t, err)
defer db.Close()
defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator(
cfg,
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
// This test mainly checks that the function doesn't panic and has right call signature
// The actual JWT functionality is tested in jwt package itself
assert.NotPanics(t, func() {
auth.Login(w, r, user, rememberMe)
err := auth.Login(w, r, user, rememberMe)
require.NoError(t, err)
})
}

View File

@@ -24,7 +24,7 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
u.RawQuery == "" &&
u.Fragment == ""
if !valid {
return fmt.Errorf("Invalid path: '%s'", path)
return fmt.Errorf("invalid path: '%s'", path)
}
}
auth.ignoredPaths = prepareGlobs(paths)

View File

@@ -33,13 +33,17 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
if err != nil {
return errors.Wrap(err, "auth.getTokens")
}
err = aT.Revoke(jwt.DBTransaction(tx))
if err != nil {
return errors.Wrap(err, "aT.Revoke")
if aT != nil {
err = aT.Revoke(jwt.DBTransaction(tx))
if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
}
err = rT.Revoke(jwt.DBTransaction(tx))
if err != nil {
return errors.Wrap(err, "rT.Revoke")
if rT != nil {
err = rT.Revoke(jwt.DBTransaction(tx))
if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
}
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")

View File

@@ -16,12 +16,20 @@ import (
//
// Example:
//
// server.AddMiddleware(auth.Authenticate())
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate())
// server.AddMiddleware(auth.Authenticate(nil))
//
// If extraCheck is provided, it will run just before the user is added to the context,
// and the return will determine if the user will be added, or the request passed on
// without the user.
func (auth *Authenticator[T, TX]) Authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
}
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
func (auth *Authenticator[T, TX]) authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if globTest(r.URL.Path, auth.ignoredPaths) {
return r, nil
@@ -38,7 +46,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Error: errors.Wrap(err, "auth.beginTx"),
}
}
defer tx.Rollback()
defer func() {
_ = tx.Rollback()
}()
// Type assert to TX - safe because user's beginTx should return their TX type
txTyped, ok := tx.(TX)
if !ok {
@@ -64,10 +74,28 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Msg("Failed to authenticate user")
return r, nil
}
tx.Commit()
var check bool
if extraCheck != nil {
var err *hws.HWSError
check, err = extraCheck(ctx, model.model, txTyped, w, r)
if err != nil {
return nil, err
}
}
err = tx.Commit()
if err != nil {
return nil, &hws.HWSError{
Message: "Failed to commit transaction",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Commit"),
}
}
authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext)
return newReq, nil
if extraCheck == nil || check {
return newReq, nil
}
return r, nil
}
}

View File

@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
// }
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
type contextKey string
func (c contextKey) String() string {
return "hwsauth context key" + string(c)
}
var authenticatedModelContextKey = contextKey("authenticated-model")
// Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
return context.WithValue(ctx, authenticatedModelContextKey, m)
}
// Retrieve a user from the given context. Returns nil if not set
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
model = authenticatedModel[T]{}
}
}()
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
if !cok {
return authenticatedModel[T]{}, false
}

View File

@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := getAuthorizedModel[T](r.Context())
if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{
auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"),
Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized,
RenderErrorPage: true,
})
if err != nil {
auth.server.ThrowFatal(w, err)
}
return
}
next.ServeHTTP(w, r)
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model, ok := getAuthorizedModel[T](r.Context())
if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{
auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"),
Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized,
RenderErrorPage: true,
})
if err != nil {
auth.server.ThrowFatal(w, err)
}
return
}
isFresh := time.Now().Before(time.Unix(model.fresh, 0))

View File

@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
rememberMe := map[string]bool{
"session": false,
"exp": true,
}[aT.TTL]
}[rT.TTL]
// issue new tokens for the user
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
if err != nil {
@@ -55,13 +55,20 @@ func (auth *Authenticator[T, TX]) getTokens(
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies
atStr, rtStr := jwt.GetTokenCookies(r)
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
var aT *jwt.AccessToken
var rT *jwt.RefreshToken
var err error
if atStr != "" {
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
}
}
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
if rtStr != "" {
rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
}
}
return aT, rT, nil
}
@@ -72,13 +79,17 @@ func revokeTokenPair(
aT *jwt.AccessToken,
rT *jwt.RefreshToken,
) error {
err := aT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "aT.Revoke")
if aT != nil {
err := aT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
}
err = rT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "rT.Revoke")
if rT != nil {
err := rT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
}
return nil
}