Compare commits

...

2 Commits

Author SHA1 Message Date
05be28d7f3 fixed fatal bug after access token expires 2026-02-07 17:58:02 +11:00
8f7c87cef2 added extracheck to hwsauth 2026-02-07 16:42:08 +11:00
3 changed files with 58 additions and 24 deletions

View File

@@ -33,13 +33,17 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return errors.Wrap(err, "auth.getTokens") return errors.Wrap(err, "auth.getTokens")
} }
err = aT.Revoke(jwt.DBTransaction(tx)) if aT != nil {
if err != nil { err = aT.Revoke(jwt.DBTransaction(tx))
return errors.Wrap(err, "aT.Revoke") if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
} }
err = rT.Revoke(jwt.DBTransaction(tx)) if rT != nil {
if err != nil { err = rT.Revoke(jwt.DBTransaction(tx))
return errors.Wrap(err, "rT.Revoke") if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
} }
cookies.DeleteCookie(w, "access", "/") cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/") cookies.DeleteCookie(w, "refresh", "/")

View File

@@ -16,12 +16,20 @@ import (
// //
// Example: // Example:
// //
// server.AddMiddleware(auth.Authenticate()) // server.AddMiddleware(auth.Authenticate(nil))
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware { //
return auth.server.NewMiddleware(auth.authenticate()) // If extraCheck is provided, it will run just before the user is added to the context,
// and the return will determine if the user will be added, or the request passed on
// without the user.
func (auth *Authenticator[T, TX]) Authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
} }
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { func (auth *Authenticator[T, TX]) authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) { return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if globTest(r.URL.Path, auth.ignoredPaths) { if globTest(r.URL.Path, auth.ignoredPaths) {
return r, nil return r, nil
@@ -66,6 +74,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Msg("Failed to authenticate user") Msg("Failed to authenticate user")
return r, nil return r, nil
} }
var check bool
if extraCheck != nil {
var err *hws.HWSError
check, err = extraCheck(ctx, model.model, txTyped, w, r)
if err != nil {
return nil, err
}
}
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return nil, &hws.HWSError{ return nil, &hws.HWSError{
@@ -76,7 +92,10 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
} }
authContext := setAuthenticatedModel(r.Context(), model) authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext) newReq := r.WithContext(authContext)
return newReq, nil if extraCheck == nil || check {
return newReq, nil
}
return r, nil
} }
} }

View File

@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
rememberMe := map[string]bool{ rememberMe := map[string]bool{
"session": false, "session": false,
"exp": true, "exp": true,
}[aT.TTL] }[rT.TTL]
// issue new tokens for the user // issue new tokens for the user
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL) err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
if err != nil { if err != nil {
@@ -55,13 +55,20 @@ func (auth *Authenticator[T, TX]) getTokens(
) (*jwt.AccessToken, *jwt.RefreshToken, error) { ) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies // get the existing tokens from the cookies
atStr, rtStr := jwt.GetTokenCookies(r) atStr, rtStr := jwt.GetTokenCookies(r)
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr) var aT *jwt.AccessToken
if err != nil { var rT *jwt.RefreshToken
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess") var err error
if atStr != "" {
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
}
} }
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr) if rtStr != "" {
if err != nil { rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh") if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
}
} }
return aT, rT, nil return aT, rT, nil
} }
@@ -72,13 +79,17 @@ func revokeTokenPair(
aT *jwt.AccessToken, aT *jwt.AccessToken,
rT *jwt.RefreshToken, rT *jwt.RefreshToken,
) error { ) error {
err := aT.Revoke(tx) if aT != nil {
if err != nil { err := aT.Revoke(tx)
return errors.Wrap(err, "aT.Revoke") if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
} }
err = rT.Revoke(tx) if rT != nil {
if err != nil { err := rT.Revoke(tx)
return errors.Wrap(err, "rT.Revoke") if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
} }
return nil return nil
} }