package hwsauth

import (
	"context"
	"net/http"
	"time"

	"git.haelnorr.com/h/golib/hws"
	"github.com/gobwas/glob"
	"github.com/pkg/errors"
)

// Authenticate returns the main authentication middleware.
// This middleware validates JWT tokens, refreshes expired tokens, and adds
// the authenticated user to the request context.
//
// Example:
//
//	server.AddMiddleware(auth.Authenticate())
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware {
	return auth.server.NewMiddleware(auth.authenticate())
}

// authenticate builds the middleware function behind Authenticate.
//
// For each request the returned function:
//   - passes the request through untouched when its URL path matches one of
//     auth.ignoredPaths;
//   - opens a transaction via auth.beginTx, bounded by a 10 second timeout;
//   - asks auth.getAuthenticatedUser for the user model;
//   - on success commits the transaction and returns a copy of the request
//     whose context carries the model (see setAuthenticatedModel);
//   - on authentication failure rolls the transaction back, logs at debug
//     level, and returns the original request so it continues
//     unauthenticated.
//
// A non-nil *hws.HWSError is returned only for infrastructure failures:
// starting the transaction, a transaction type mismatch, a failed rollback,
// or a failed commit.
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
		if globTest(r.URL.Path, auth.ignoredPaths) {
			return r, nil
		}

		// The timeout context only scopes the transaction; the request that is
		// handed on keeps its own context so it is not cancelled on return.
		ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
		defer cancel()

		// Start the transaction
		tx, err := auth.beginTx(ctx)
		if err != nil {
			return nil, &hws.HWSError{
				Message:    "Unable to start transaction",
				StatusCode: http.StatusServiceUnavailable,
				Error:      errors.Wrap(err, "auth.beginTx"),
			}
		}
		// Safety net for every early return below. The result is deliberately
		// ignored: paths that care about the rollback outcome call Rollback
		// explicitly and inspect its error.
		defer func() { _ = tx.Rollback() }()

		// Type assert to TX - safe because user's beginTx should return their TX type
		txTyped, ok := tx.(TX)
		if !ok {
			return nil, &hws.HWSError{
				Message:    "Transaction type mismatch",
				StatusCode: http.StatusInternalServerError,
				// err is nil at this point, so wrapping it would yield a nil
				// error; build a fresh one that names the offending type.
				Error: errors.Errorf("TX type not ok: got %T", tx),
			}
		}

		model, err := auth.getAuthenticatedUser(txTyped, w, r)
		if err != nil {
			if rberr := tx.Rollback(); rberr != nil {
				return nil, &hws.HWSError{
					Message:    "Failed rolling back after error",
					StatusCode: http.StatusInternalServerError,
					// Report the rollback failure itself, not the
					// authentication error that triggered the rollback.
					Error: errors.Wrap(rberr, "tx.Rollback"),
				}
			}
			// Failing to authenticate is not a middleware error: log it and
			// let the request proceed without an authenticated model.
			auth.logger.Debug().
				Str("remote_addr", r.RemoteAddr).
				Err(err).
				Msg("Failed to authenticate user")
			return r, nil
		}

		if err = tx.Commit(); err != nil {
			return nil, &hws.HWSError{
				Message:    "Failed to commit transaction",
				StatusCode: http.StatusInternalServerError,
				Error:      errors.Wrap(err, "tx.Commit"),
			}
		}

		authContext := setAuthenticatedModel(r.Context(), model)
		newReq := r.WithContext(authContext)
		return newReq, nil
	}
}

// globTest reports whether testPath matches at least one of the given globs.
// It returns false when globs is empty.
func globTest(testPath string, globs []glob.Glob) bool {
	for _, g := range globs {
		if g.Match(testPath) {
			return true
		}
	}
	return false
}