Compare commits
3 Commits
hws/v0.5.0
...
hwsauth/v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 05be28d7f3 | |||
| 8f7c87cef2 | |||
| 525b3b1396 |
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||
return func() (any, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,13 +6,15 @@ require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||
git.haelnorr.com/h/golib/hws v0.3.0
|
||||
git.haelnorr.com/h/golib/hws v0.5.0
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
|
||||
@@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
|
||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
|
||||
@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
|
||||
return tm.ID
|
||||
}
|
||||
|
||||
type TestTransaction struct {
|
||||
}
|
||||
type TestTransaction struct{}
|
||||
|
||||
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
|
||||
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
||||
// Clear environment variables
|
||||
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
||||
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer func() {
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
}()
|
||||
|
||||
_, err := ConfigFromEnv()
|
||||
assert.Error(t, err)
|
||||
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
// This test mainly checks that the function doesn't panic and has right call signature
|
||||
// The actual JWT functionality is tested in jwt package itself
|
||||
assert.NotPanics(t, func() {
|
||||
auth.Login(w, r, user, rememberMe)
|
||||
err := auth.Login(w, r, user, rememberMe)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
return fmt.Errorf("invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
auth.ignoredPaths = prepareGlobs(paths)
|
||||
|
||||
@@ -33,14 +33,18 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auth.getTokens")
|
||||
}
|
||||
if aT != nil {
|
||||
err = aT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
}
|
||||
if rT != nil {
|
||||
err = rT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
}
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
return nil
|
||||
|
||||
@@ -16,12 +16,20 @@ import (
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// server.AddMiddleware(auth.Authenticate())
|
||||
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware {
|
||||
return auth.server.NewMiddleware(auth.authenticate())
|
||||
// server.AddMiddleware(auth.Authenticate(nil))
|
||||
//
|
||||
// If extraCheck is provided, it will run just before the user is added to the context,
|
||||
// and the return will determine if the user will be added, or the request passed on
|
||||
// without the user.
|
||||
func (auth *Authenticator[T, TX]) Authenticate(
|
||||
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||
) hws.Middleware {
|
||||
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
|
||||
}
|
||||
|
||||
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
func (auth *Authenticator[T, TX]) authenticate(
|
||||
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||
) hws.MiddlewareFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
||||
if globTest(r.URL.Path, auth.ignoredPaths) {
|
||||
return r, nil
|
||||
@@ -38,7 +46,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
Error: errors.Wrap(err, "auth.beginTx"),
|
||||
}
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
// Type assert to TX - safe because user's beginTx should return their TX type
|
||||
txTyped, ok := tx.(TX)
|
||||
if !ok {
|
||||
@@ -64,11 +74,29 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
Msg("Failed to authenticate user")
|
||||
return r, nil
|
||||
}
|
||||
tx.Commit()
|
||||
var check bool
|
||||
if extraCheck != nil {
|
||||
var err *hws.HWSError
|
||||
check, err = extraCheck(ctx, model.model, txTyped, w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Failed to commit transaction",
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: errors.Wrap(err, "tx.Commit"),
|
||||
}
|
||||
}
|
||||
authContext := setAuthenticatedModel(r.Context(), model)
|
||||
newReq := r.WithContext(authContext)
|
||||
if extraCheck == nil || check {
|
||||
return newReq, nil
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
func globTest(testPath string, globs []glob.Glob) bool {
|
||||
|
||||
@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
|
||||
// }
|
||||
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "hwsauth context key" + string(c)
|
||||
}
|
||||
|
||||
var authenticatedModelContextKey = contextKey("authenticated-model")
|
||||
|
||||
// Return a new context with the user added in
|
||||
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
|
||||
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
||||
return context.WithValue(ctx, authenticatedModelContextKey, m)
|
||||
}
|
||||
|
||||
// Retrieve a user from the given context. Returns nil if not set
|
||||
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
|
||||
model = authenticatedModel[T]{}
|
||||
}
|
||||
}()
|
||||
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
|
||||
model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
|
||||
if !cok {
|
||||
return authenticatedModel[T]{}, false
|
||||
}
|
||||
|
||||
@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||
|
||||
@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
|
||||
rememberMe := map[string]bool{
|
||||
"session": false,
|
||||
"exp": true,
|
||||
}[aT.TTL]
|
||||
}[rT.TTL]
|
||||
// issue new tokens for the user
|
||||
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
||||
if err != nil {
|
||||
@@ -55,14 +55,21 @@ func (auth *Authenticator[T, TX]) getTokens(
|
||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||
// get the existing tokens from the cookies
|
||||
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
||||
var aT *jwt.AccessToken
|
||||
var rT *jwt.RefreshToken
|
||||
var err error
|
||||
if atStr != "" {
|
||||
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||
}
|
||||
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
||||
}
|
||||
if rtStr != "" {
|
||||
rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||
}
|
||||
}
|
||||
return aT, rT, nil
|
||||
}
|
||||
|
||||
@@ -72,13 +79,17 @@ func revokeTokenPair(
|
||||
aT *jwt.AccessToken,
|
||||
rT *jwt.RefreshToken,
|
||||
) error {
|
||||
if aT != nil {
|
||||
err := aT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
err = rT.Revoke(tx)
|
||||
}
|
||||
if rT != nil {
|
||||
err := rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user