diff --git a/hwsauth/go.mod b/hwsauth/go.mod index c76c6d1..80fa7fd 100644 --- a/hwsauth/go.mod +++ b/hwsauth/go.mod @@ -5,20 +5,24 @@ go 1.25.5 require ( git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/env v0.9.1 - git.haelnorr.com/h/golib/hws v0.2.0 - git.haelnorr.com/h/golib/jwt v0.10.0 + git.haelnorr.com/h/golib/hlog v0.10.4 + git.haelnorr.com/h/golib/hws v0.3.0 + git.haelnorr.com/h/golib/jwt v0.10.1 github.com/pkg/errors v0.9.1 - git.haelnorr.com/h/golib/hlog v0.9.1 + github.com/stretchr/testify v1.11.1 ) require ( - github.com/rs/zerolog v1.34.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/zerolog v1.34.0 // indirect golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apimachinery v0.35.0 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/utils v0.0.0-20260108192941-914a6e750570 // indirect diff --git a/hwsauth/go.sum b/hwsauth/go.sum index 2923954..029f665 100644 --- a/hwsauth/go.sum +++ b/hwsauth/go.sum @@ -4,10 +4,16 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/hlog v0.9.1 h1:9VmE/IQTfD8LAEyTbUCZLy/+8PbcHA1Kob/WQHRHKzc= git.haelnorr.com/h/golib/hlog v0.9.1/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk= +git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= +git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hws v0.2.0 h1:MR2Tu2qPaW+/oK8aXFJLRFaYZIHgKiex3t3zE41cu1U= git.haelnorr.com/h/golib/hws v0.2.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= +git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E= git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= +git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= +git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -41,6 +47,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= diff --git a/hwsauth/hwsauth_test.go b/hwsauth/hwsauth_test.go new file mode 100644 index 0000000..9a47338 --- /dev/null +++ b/hwsauth/hwsauth_test.go @@ -0,0 +1,440 @@ +package hwsauth + +import ( + "context" + "database/sql" + "io" + "net/http/httptest" + "os" + "testing" + + "git.haelnorr.com/h/golib/hlog" + "git.haelnorr.com/h/golib/hws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestModel struct { + ID int +} + +func (tm TestModel) GetID() int { + return tm.ID +} + +type TestTransaction struct { +} + +func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { + return nil, nil +} + +func (tt *TestTransaction) Query(query string, args ...any) (*sql.Rows, error) { + return nil, nil +} + +func (tt *TestTransaction) Commit() error { + return nil +} + +func (tt *TestTransaction) Rollback() error { + return nil +} + +type TestErrorPage struct{} + +func (tep TestErrorPage) Render(ctx context.Context, w io.Writer) error { + return nil +} + +func TestGetNil(t *testing.T) { + var zero TestModel + result := getNil[TestModel]() + assert.Equal(t, zero, result) +} + +func TestSetAndGetAuthenticatedModel(t *testing.T) { + ctx := context.Background() + model := TestModel{ID: 123} + authModel := authenticatedModel[TestModel]{ + model: model, + fresh: 1234567890, + } + + newCtx := setAuthenticatedModel(ctx, authModel) + + retrieved, ok := getAuthorizedModel[TestModel](newCtx) + assert.True(t, ok) + assert.Equal(t, model, retrieved.model) + assert.Equal(t, int64(1234567890), retrieved.fresh) +} + +func TestGetAuthorizedModel_NotSet(t *testing.T) { + ctx := context.Background() + + retrieved, ok := getAuthorizedModel[TestModel](ctx) + assert.False(t, ok) + var zero TestModel + assert.Equal(t, zero, retrieved.model) + assert.Equal(t, int64(0), retrieved.fresh) +} + +func TestCurrentModel(t *testing.T) { + auth := &Authenticator[TestModel, DBTransaction]{} + + t.Run("nil context", func(t *testing.T) { + var nilContext context.Context = nil + result := auth.CurrentModel(nilContext) + var zero TestModel + assert.Equal(t, zero, result) + }) + + t.Run("context without authenticated model", func(t *testing.T) { + ctx := context.Background() + result := auth.CurrentModel(ctx) + var zero TestModel + assert.Equal(t, zero, result) + }) + + t.Run("context with authenticated model", func(t *testing.T) { + ctx := context.Background() + model := TestModel{ID: 456} + authModel := authenticatedModel[TestModel]{ + model: model, + fresh: 1234567890, + } + ctx = setAuthenticatedModel(ctx, authModel) + + result := auth.CurrentModel(ctx) + assert.Equal(t, model, result) + assert.Equal(t, 456, result.GetID()) + }) +} + +func TestConfigFromEnv_MissingSecretKey(t *testing.T) { + // Clear environment variables + originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") + os.Setenv("HWSAUTH_SECRET_KEY", "") + defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) + + _, err := ConfigFromEnv() + assert.Error(t, err) + assert.Contains(t, err.Error(), "Envar not set: HWSAUTH_SECRET_KEY") +} + +func TestConfigFromEnv_SSLWithoutTrustedHost(t *testing.T) { + // Clear environment variables + t.Setenv("HWSAUTH_SECRET_KEY", "test-secret") + t.Setenv("HWSAUTH_SSL", "true") + t.Setenv("HWSAUTH_TRUSTED_HOST", "") + defer func() { + t.Setenv("HWSAUTH_SECRET_KEY", "") + t.Setenv("HWSAUTH_SSL", "") + t.Setenv("HWSAUTH_TRUSTED_HOST", "") + }() + + _, err := ConfigFromEnv() + assert.Error(t, err) + assert.Contains(t, err.Error(), "SSL is enabled and no HWS_TRUSTED_HOST set") +} + +func TestConfigFromEnv_ValidMinimalConfig(t *testing.T) { + // Set environment variables + t.Setenv("HWSAUTH_SECRET_KEY", "test-secret-key") + defer t.Setenv("HWSAUTH_SECRET_KEY", "") + + cfg, err := ConfigFromEnv() + assert.NoError(t, err) + assert.Equal(t, "test-secret-key", cfg.SecretKey) + assert.Equal(t, false, cfg.SSL) + assert.Equal(t, int64(5), cfg.AccessTokenExpiry) + assert.Equal(t, int64(1440), cfg.RefreshTokenExpiry) + assert.Equal(t, int64(5), cfg.TokenFreshTime) + assert.Equal(t, "/profile", cfg.LandingPage) + assert.Equal(t, "postgres", cfg.DatabaseType) + assert.Equal(t, "15", cfg.DatabaseVersion) + assert.Equal(t, "jwtblacklist", cfg.JWTTableName) +} + +func TestConfigFromEnv_ValidFullConfig(t *testing.T) { + // Set environment variables + t.Setenv("HWSAUTH_SECRET_KEY", "custom-secret") + t.Setenv("HWSAUTH_SSL", "true") + t.Setenv("HWSAUTH_TRUSTED_HOST", "example.com") + t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "15") + t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "2880") + t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "10") + t.Setenv("HWSAUTH_LANDING_PAGE", "/dashboard") + t.Setenv("HWSAUTH_DATABASE_TYPE", "mysql") + t.Setenv("HWSAUTH_DATABASE_VERSION", "8.0") + t.Setenv("HWSAUTH_JWT_TABLE_NAME", "custom_tokens") + defer func() { + t.Setenv("HWSAUTH_SECRET_KEY", "") + t.Setenv("HWSAUTH_SSL", "") + t.Setenv("HWSAUTH_TRUSTED_HOST", "") + t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "") + t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "") + t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "") + t.Setenv("HWSAUTH_LANDING_PAGE", "") + t.Setenv("HWSAUTH_DATABASE_TYPE", "") + t.Setenv("HWSAUTH_DATABASE_VERSION", "") + t.Setenv("HWSAUTH_JWT_TABLE_NAME", "") + }() + + cfg, err := ConfigFromEnv() + assert.NoError(t, err) + assert.Equal(t, "custom-secret", cfg.SecretKey) + assert.Equal(t, true, cfg.SSL) + assert.Equal(t, "example.com", cfg.TrustedHost) + assert.Equal(t, int64(15), cfg.AccessTokenExpiry) + assert.Equal(t, int64(2880), cfg.RefreshTokenExpiry) + assert.Equal(t, int64(10), cfg.TokenFreshTime) + assert.Equal(t, "/dashboard", cfg.LandingPage) + assert.Equal(t, "mysql", cfg.DatabaseType) + assert.Equal(t, "8.0", cfg.DatabaseVersion) + assert.Equal(t, "custom_tokens", cfg.JWTTableName) +} + +func TestNewAuthenticator_NilConfig(t *testing.T) { + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + nil, + load, + server, + beginTx, + logger, + errorPage, + ) + + assert.Error(t, err) + assert.Nil(t, auth) + assert.Contains(t, err.Error(), "Config is required") +} + +func TestNewAuthenticator_MissingSecretKey(t *testing.T) { + cfg := &Config{ + SecretKey: "", + } + + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + cfg, + load, + server, + beginTx, + logger, + errorPage, + ) + + assert.Error(t, err) + assert.Nil(t, auth) + assert.Contains(t, err.Error(), "SecretKey is required") +} + +func TestNewAuthenticator_NilLoadFunction(t *testing.T) { + cfg := &Config{ + SecretKey: "test-secret", + } + + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator[TestModel, DBTransaction]( + cfg, + nil, + server, + beginTx, + logger, + errorPage, + ) + + assert.Error(t, err) + assert.Nil(t, auth) + assert.Contains(t, err.Error(), "No function to load model supplied") +} + +func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) { + cfg := &Config{ + SecretKey: "test-secret", + SSL: true, + } + + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + cfg, + load, + server, + beginTx, + logger, + errorPage, + ) + + assert.Error(t, err) + assert.Nil(t, auth) + assert.Contains(t, err.Error(), "TrustedHost is required when SSL is enabled") +} + +func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) { + cfg := &Config{ + SecretKey: "test-secret", + TrustedHost: "example.com", + } + + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + cfg, + load, + server, + beginTx, + logger, + errorPage, + ) + + require.NoError(t, err) + require.NotNil(t, auth) + + assert.Equal(t, false, auth.SSL) + assert.Equal(t, "/profile", auth.LandingPage) +} + +func TestModelInterface(t *testing.T) { + t.Run("TestModel implements Model interface", func(t *testing.T) { + var _ Model = TestModel{} + }) + + t.Run("GetID method", func(t *testing.T) { + model := TestModel{ID: 789} + assert.Equal(t, 789, model.GetID()) + }) +} + +func TestGetAuthenticatedUser_NoTokens(t *testing.T) { + cfg := &Config{ + SecretKey: "test-secret", + TrustedHost: "example.com", + } + + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + cfg, + load, + server, + beginTx, + logger, + errorPage, + ) + require.NoError(t, err) + + tx := &TestTransaction{} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + model, err := auth.getAuthenticatedUser(tx, w, r) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "No token strings provided") + var zero TestModel + assert.Equal(t, zero, model.model) +} + +func TestLogin_BasicFunctionality(t *testing.T) { + cfg := &Config{ + SecretKey: "test-secret", + TrustedHost: "example.com", + } + + load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { + return TestModel{ID: id}, nil + } + server := &hws.Server{} + beginTx := func(ctx context.Context) (DBTransaction, error) { + return &TestTransaction{}, nil + } + logger := &hlog.Logger{} + errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { + return TestErrorPage{}, nil + } + + auth, err := NewAuthenticator( + cfg, + load, + server, + beginTx, + logger, + errorPage, + ) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + user := TestModel{ID: 123} + rememberMe := true + + // This test mainly checks that the function doesn't panic and has right call signature + // The actual JWT functionality is tested in jwt package itself + assert.NotPanics(t, func() { + auth.Login(w, r, user, rememberMe) + }) +} diff --git a/hwsauth/protectpage.go b/hwsauth/protectpage.go index 6b7f1a6..1807459 100644 --- a/hwsauth/protectpage.go +++ b/hwsauth/protectpage.go @@ -5,6 +5,7 @@ import ( "time" "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" ) // LoginReq returns a middleware that requires the user to be authenticated. @@ -18,23 +19,14 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, ok := getAuthorizedModel[T](r.Context()) if !ok { - page, err := auth.errorPage(http.StatusUnauthorized) + err := auth.server.ThrowError(w, r, hws.HWSError{ + Error: errors.New("Login required"), + Message: "Please login to view this page", + StatusCode: http.StatusUnauthorized, + RenderErrorPage: true, + }) if err != nil { - auth.server.ThrowError(w, r, hws.HWSError{ - Error: err, - Message: "Failed to get valid error page", - StatusCode: http.StatusInternalServerError, - RenderErrorPage: true, - }) - } - err = page.Render(r.Context(), w) - if err != nil { - auth.server.ThrowError(w, r, hws.HWSError{ - Error: err, - Message: "Failed to render error page", - StatusCode: http.StatusInternalServerError, - RenderErrorPage: true, - }) + auth.server.ThrowFatal(w, err) } return } @@ -74,23 +66,14 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { model, ok := getAuthorizedModel[T](r.Context()) if !ok { - page, err := auth.errorPage(http.StatusUnauthorized) + err := auth.server.ThrowError(w, r, hws.HWSError{ + Error: errors.New("Login required"), + Message: "Please login to view this page", + StatusCode: http.StatusUnauthorized, + RenderErrorPage: true, + }) if err != nil { - auth.server.ThrowError(w, r, hws.HWSError{ - Error: err, - Message: "Failed to get valid error page", - StatusCode: http.StatusInternalServerError, - RenderErrorPage: true, - }) - } - err = page.Render(r.Context(), w) - if err != nil { - auth.server.ThrowError(w, r, hws.HWSError{ - Error: err, - Message: "Failed to render error page", - StatusCode: http.StatusInternalServerError, - RenderErrorPage: true, - }) + auth.server.ThrowFatal(w, err) } return }