Compare commits

...

1 Commits

Author SHA1 Message Date
8f7c87cef2 added extracheck to hwsauth 2026-02-07 16:42:08 +11:00

View File

@@ -16,12 +16,20 @@ import (
// //
// Example: // Example:
// //
// server.AddMiddleware(auth.Authenticate()) // server.AddMiddleware(auth.Authenticate(nil))
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware { //
return auth.server.NewMiddleware(auth.authenticate()) // If extraCheck is provided, it will run just before the user is added to the context,
// and the return will determine if the user will be added, or the request passed on
// without the user.
func (auth *Authenticator[T, TX]) Authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
} }
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { func (auth *Authenticator[T, TX]) authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) { return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if globTest(r.URL.Path, auth.ignoredPaths) { if globTest(r.URL.Path, auth.ignoredPaths) {
return r, nil return r, nil
@@ -66,6 +74,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Msg("Failed to authenticate user") Msg("Failed to authenticate user")
return r, nil return r, nil
} }
var check bool
if extraCheck != nil {
var err *hws.HWSError
check, err = extraCheck(ctx, model.model, txTyped, w, r)
if err != nil {
return nil, err
}
}
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return nil, &hws.HWSError{ return nil, &hws.HWSError{
@@ -76,7 +92,10 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
} }
authContext := setAuthenticatedModel(r.Context(), model) authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext) newReq := r.WithContext(authContext)
return newReq, nil if extraCheck == nil || check {
return newReq, nil
}
return r, nil
} }
} }