diff --git a/hwsauth/middleware.go b/hwsauth/middleware.go index 6c6c222..b5d468b 100644 --- a/hwsauth/middleware.go +++ b/hwsauth/middleware.go @@ -16,12 +16,20 @@ import ( // // Example: // -// server.AddMiddleware(auth.Authenticate()) -func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware { - return auth.server.NewMiddleware(auth.authenticate()) +// server.AddMiddleware(auth.Authenticate(nil)) +// +// If extraCheck is provided, it will run just before the user is added to the context, +// and the return will determine if the user will be added, or the request passed on +// without the user. +func (auth *Authenticator[T, TX]) Authenticate( + extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError), +) hws.Middleware { + return auth.server.NewMiddleware(auth.authenticate(extraCheck)) } -func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { +func (auth *Authenticator[T, TX]) authenticate( + extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError), +) hws.MiddlewareFunc { return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) { if globTest(r.URL.Path, auth.ignoredPaths) { return r, nil @@ -66,6 +74,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { Msg("Failed to authenticate user") return r, nil } + var check bool + if extraCheck != nil { + var err *hws.HWSError + check, err = extraCheck(ctx, model.model, txTyped, w, r) + if err != nil { + return nil, err + } + } err = tx.Commit() if err != nil { return nil, &hws.HWSError{ @@ -76,7 +92,10 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { } authContext := setAuthenticatedModel(r.Context(), model) newReq := r.WithContext(authContext) - return newReq, nil + if extraCheck == nil || check { + return newReq, nil + } + return r, nil } }