diff --git a/hwsauth/ezconf.go b/hwsauth/ezconf.go index 39ca8ff..0feff48 100644 --- a/hwsauth/ezconf.go +++ b/hwsauth/ezconf.go @@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string { } // ConfigFunc returns the ConfigFromEnv function for ezconf -func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) { - return func() (interface{}, error) { +func (e EZConfIntegration) ConfigFunc() func() (any, error) { + return func() (any, error) { return ConfigFromEnv() } } diff --git a/hwsauth/go.mod b/hwsauth/go.mod index 3cb438d..dc6a07a 100644 --- a/hwsauth/go.mod +++ b/hwsauth/go.mod @@ -6,13 +6,15 @@ require ( git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.3.0 + git.haelnorr.com/h/golib/hws v0.5.0 git.haelnorr.com/h/golib/jwt v0.10.1 github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.11.1 ) +require git.haelnorr.com/h/golib/notify v0.1.0 // indirect + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect diff --git a/hwsauth/go.sum b/hwsauth/go.sum index a60e947..b9bff77 100644 --- a/hwsauth/go.sum +++ b/hwsauth/go.sum @@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= -git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c= +git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= +git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= +git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/hwsauth/hwsauth_test.go b/hwsauth/hwsauth_test.go index 4de5945..4a3b0b2 100644 --- a/hwsauth/hwsauth_test.go +++ b/hwsauth/hwsauth_test.go @@ -23,8 +23,7 @@ func (tm TestModel) GetID() int { return tm.ID } -type TestTransaction struct { -} +type TestTransaction struct{} func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { return nil, nil @@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) { func TestConfigFromEnv_MissingSecretKey(t *testing.T) { // Clear environment variables originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") - os.Setenv("HWSAUTH_SECRET_KEY", "") - defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) + _ = os.Setenv("HWSAUTH_SECRET_KEY", "") + defer func() { + _ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) + }() _, err := ConfigFromEnv() assert.Error(t, err) @@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) { db, _, err := createMockDB() require.NoError(t, err) - defer db.Close() + defer func() { + _ = db.Close() + }() auth, err := NewAuthenticator( cfg, @@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) { db, _, err := createMockDB() require.NoError(t, err) - defer db.Close() + defer func() { + _ = db.Close() + }() auth, err := NewAuthenticator( cfg, @@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) { db, _, err := createMockDB() require.NoError(t, err) - defer db.Close() + defer func() { + _ = db.Close() + }() auth, err := NewAuthenticator( cfg, @@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) { // This test mainly checks that the function doesn't panic and has right call signature // The actual JWT functionality is tested in jwt package itself assert.NotPanics(t, func() { - auth.Login(w, r, user, rememberMe) + err := auth.Login(w, r, user, rememberMe) + require.NoError(t, err) }) } diff --git a/hwsauth/ignorepaths.go b/hwsauth/ignorepaths.go index e01b8bc..9a53e35 100644 --- a/hwsauth/ignorepaths.go +++ b/hwsauth/ignorepaths.go @@ -24,7 +24,7 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error { u.RawQuery == "" && u.Fragment == "" if !valid { - return fmt.Errorf("Invalid path: '%s'", path) + return fmt.Errorf("invalid path: '%s'", path) } } auth.ignoredPaths = prepareGlobs(paths) diff --git a/hwsauth/middleware.go b/hwsauth/middleware.go index c576654..6c6c222 100644 --- a/hwsauth/middleware.go +++ b/hwsauth/middleware.go @@ -38,7 +38,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { Error: errors.Wrap(err, "auth.beginTx"), } } - defer tx.Rollback() + defer func() { + _ = tx.Rollback() + }() // Type assert to TX - safe because user's beginTx should return their TX type txTyped, ok := tx.(TX) if !ok { @@ -64,7 +66,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { Msg("Failed to authenticate user") return r, nil } - tx.Commit() + err = tx.Commit() + if err != nil { + return nil, &hws.HWSError{ + Message: "Failed to commit transaction", + StatusCode: http.StatusInternalServerError, + Error: errors.Wrap(err, "tx.Commit"), + } + } authContext := setAuthenticatedModel(r.Context(), model) newReq := r.WithContext(authContext) return newReq, nil diff --git a/hwsauth/model.go b/hwsauth/model.go index d37c761..134fff8 100644 --- a/hwsauth/model.go +++ b/hwsauth/model.go @@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T // } type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error) +type contextKey string + +func (c contextKey) String() string { + return "hwsauth context key" + string(c) +} + +var authenticatedModelContextKey = contextKey("authenticated-model") + // Return a new context with the user added in func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context { - return context.WithValue(ctx, "hwsauth context key authenticated-model", m) + return context.WithValue(ctx, authenticatedModelContextKey, m) } // Retrieve a user from the given context. Returns nil if not set @@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[ model = authenticatedModel[T]{} } }() - model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T]) + model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T]) if !cok { return authenticatedModel[T]{}, false } diff --git a/hwsauth/protectpage.go b/hwsauth/protectpage.go index 1807459..8868d35 100644 --- a/hwsauth/protectpage.go +++ b/hwsauth/protectpage.go @@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, ok := getAuthorizedModel[T](r.Context()) if !ok { - err := auth.server.ThrowError(w, r, hws.HWSError{ + auth.server.ThrowError(w, r, hws.HWSError{ Error: errors.New("Login required"), Message: "Please login to view this page", StatusCode: http.StatusUnauthorized, RenderErrorPage: true, }) - if err != nil { - auth.server.ThrowFatal(w, err) - } return } next.ServeHTTP(w, r) @@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { model, ok := getAuthorizedModel[T](r.Context()) if !ok { - err := auth.server.ThrowError(w, r, hws.HWSError{ + auth.server.ThrowError(w, r, hws.HWSError{ Error: errors.New("Login required"), Message: "Please login to view this page", StatusCode: http.StatusUnauthorized, RenderErrorPage: true, }) - if err != nil { - auth.server.ThrowFatal(w, err) - } return } isFresh := time.Now().Before(time.Unix(model.fresh, 0))