Compare commits

...

1 Commits

Author SHA1 Message Date
525b3b1396 updated to use new hws version 2026-02-03 19:11:59 +11:00
8 changed files with 49 additions and 26 deletions

View File

@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
} }
// ConfigFunc returns the ConfigFromEnv function for ezconf // ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) { func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (interface{}, error) { return func() (any, error) {
return ConfigFromEnv() return ConfigFromEnv()
} }
} }

View File

@@ -6,13 +6,15 @@ require (
git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.3.0 git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/jwt v0.10.1 git.haelnorr.com/h/golib/jwt v0.10.1
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
) )
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect

View File

@@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=

View File

@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
return tm.ID return tm.ID
} }
type TestTransaction struct { type TestTransaction struct{}
}
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
return nil, nil return nil, nil
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
func TestConfigFromEnv_MissingSecretKey(t *testing.T) { func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
// Clear environment variables // Clear environment variables
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
os.Setenv("HWSAUTH_SECRET_KEY", "") _ = os.Setenv("HWSAUTH_SECRET_KEY", "")
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) defer func() {
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
}()
_, err := ConfigFromEnv() _, err := ConfigFromEnv()
assert.Error(t, err) assert.Error(t, err)
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
// This test mainly checks that the function doesn't panic and has right call signature // This test mainly checks that the function doesn't panic and has right call signature
// The actual JWT functionality is tested in jwt package itself // The actual JWT functionality is tested in jwt package itself
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
auth.Login(w, r, user, rememberMe) err := auth.Login(w, r, user, rememberMe)
require.NoError(t, err)
}) })
} }

View File

@@ -24,7 +24,7 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
u.RawQuery == "" && u.RawQuery == "" &&
u.Fragment == "" u.Fragment == ""
if !valid { if !valid {
return fmt.Errorf("Invalid path: '%s'", path) return fmt.Errorf("invalid path: '%s'", path)
} }
} }
auth.ignoredPaths = prepareGlobs(paths) auth.ignoredPaths = prepareGlobs(paths)

View File

@@ -38,7 +38,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Error: errors.Wrap(err, "auth.beginTx"), Error: errors.Wrap(err, "auth.beginTx"),
} }
} }
defer tx.Rollback() defer func() {
_ = tx.Rollback()
}()
// Type assert to TX - safe because user's beginTx should return their TX type // Type assert to TX - safe because user's beginTx should return their TX type
txTyped, ok := tx.(TX) txTyped, ok := tx.(TX)
if !ok { if !ok {
@@ -64,7 +66,14 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Msg("Failed to authenticate user") Msg("Failed to authenticate user")
return r, nil return r, nil
} }
tx.Commit() err = tx.Commit()
if err != nil {
return nil, &hws.HWSError{
Message: "Failed to commit transaction",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Commit"),
}
}
authContext := setAuthenticatedModel(r.Context(), model) authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext) newReq := r.WithContext(authContext)
return newReq, nil return newReq, nil

View File

@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
// } // }
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error) type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
type contextKey string
func (c contextKey) String() string {
return "hwsauth context key" + string(c)
}
var authenticatedModelContextKey = contextKey("authenticated-model")
// Return a new context with the user added in // Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context { func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m) return context.WithValue(ctx, authenticatedModelContextKey, m)
} }
// Retrieve a user from the given context. Returns nil if not set // Retrieve a user from the given context. Returns nil if not set
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
model = authenticatedModel[T]{} model = authenticatedModel[T]{}
} }
}() }()
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T]) model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
if !cok { if !cok {
return authenticatedModel[T]{}, false return authenticatedModel[T]{}, false
} }

View File

@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := getAuthorizedModel[T](r.Context()) _, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model, ok := getAuthorizedModel[T](r.Context()) model, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
isFresh := time.Now().Before(time.Unix(model.fresh, 0)) isFresh := time.Now().Before(time.Unix(model.fresh, 0))