package hwsauth import ( "context" "database/sql" "io" "net/http/httptest" "os" "testing" "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hws" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type TestModel struct { ID int } func (tm TestModel) GetID() int { return tm.ID } type TestTransaction struct{} func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { return nil, nil } func (tt *TestTransaction) Query(query string, args ...any) (*sql.Rows, error) { return nil, nil } func (tt *TestTransaction) Commit() error { return nil } func (tt *TestTransaction) Rollback() error { return nil } type TestErrorPage struct{} func (tep TestErrorPage) Render(ctx context.Context, w io.Writer) error { return nil } // createMockDB creates a mock SQL database for testing func createMockDB() (*sql.DB, sqlmock.Sqlmock, error) { db, mock, err := sqlmock.New() if err != nil { return nil, nil, err } // Expect a ping to succeed for database connectivity test mock.ExpectPing() // Expect table existence check (returns a row = table exists) mock.ExpectQuery(`SELECT 1 FROM information_schema\.tables WHERE table_schema = 'public' AND table_name = \$1`). WithArgs("jwtblacklist"). WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) // Expect cleanup function creation mock.ExpectExec(`CREATE OR REPLACE FUNCTION cleanup_jwtblacklist\(\) RETURNS void AS \$\$ BEGIN DELETE FROM jwtblacklist WHERE exp < EXTRACT\(EPOCH FROM NOW\(\)\); END; \$\$ LANGUAGE plpgsql;`). WillReturnResult(sqlmock.NewResult(0, 0)) return db, mock, nil } func TestGetNil(t *testing.T) { var zero TestModel result := getNil[TestModel]() assert.Equal(t, zero, result) } func TestSetAndGetAuthenticatedModel(t *testing.T) { ctx := context.Background() model := TestModel{ID: 123} authModel := authenticatedModel[TestModel]{ model: model, fresh: 1234567890, } newCtx := setAuthenticatedModel(ctx, authModel) retrieved, ok := getAuthorizedModel[TestModel](newCtx) assert.True(t, ok) assert.Equal(t, model, retrieved.model) assert.Equal(t, int64(1234567890), retrieved.fresh) } func TestGetAuthorizedModel_NotSet(t *testing.T) { ctx := context.Background() retrieved, ok := getAuthorizedModel[TestModel](ctx) assert.False(t, ok) var zero TestModel assert.Equal(t, zero, retrieved.model) assert.Equal(t, int64(0), retrieved.fresh) } func TestCurrentModel(t *testing.T) { auth := &Authenticator[TestModel, DBTransaction]{} t.Run("nil context", func(t *testing.T) { var nilContext context.Context = nil result := auth.CurrentModel(nilContext) var zero TestModel assert.Equal(t, zero, result) }) t.Run("context without authenticated model", func(t *testing.T) { ctx := context.Background() result := auth.CurrentModel(ctx) var zero TestModel assert.Equal(t, zero, result) }) t.Run("context with authenticated model", func(t *testing.T) { ctx := context.Background() model := TestModel{ID: 456} authModel := authenticatedModel[TestModel]{ model: model, fresh: 1234567890, } ctx = setAuthenticatedModel(ctx, authModel) result := auth.CurrentModel(ctx) assert.Equal(t, model, result) assert.Equal(t, 456, result.GetID()) }) } func TestConfigFromEnv_MissingSecretKey(t *testing.T) { // Clear environment variables originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") _ = os.Setenv("HWSAUTH_SECRET_KEY", "") defer func() { _ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) }() _, err := ConfigFromEnv() assert.Error(t, err) assert.Contains(t, err.Error(), "Envar not set: HWSAUTH_SECRET_KEY") } func TestConfigFromEnv_SSLWithoutTrustedHost(t *testing.T) { // Clear environment variables t.Setenv("HWSAUTH_SECRET_KEY", "test-secret") t.Setenv("HWSAUTH_SSL", "true") t.Setenv("HWSAUTH_TRUSTED_HOST", "") defer func() { t.Setenv("HWSAUTH_SECRET_KEY", "") t.Setenv("HWSAUTH_SSL", "") t.Setenv("HWSAUTH_TRUSTED_HOST", "") }() _, err := ConfigFromEnv() assert.Error(t, err) assert.Contains(t, err.Error(), "SSL is enabled and no HWS_TRUSTED_HOST set") } func TestConfigFromEnv_ValidMinimalConfig(t *testing.T) { // Set environment variables t.Setenv("HWSAUTH_SECRET_KEY", "test-secret-key") defer t.Setenv("HWSAUTH_SECRET_KEY", "") cfg, err := ConfigFromEnv() assert.NoError(t, err) assert.Equal(t, "test-secret-key", cfg.SecretKey) assert.Equal(t, false, cfg.SSL) assert.Equal(t, int64(5), cfg.AccessTokenExpiry) assert.Equal(t, int64(1440), cfg.RefreshTokenExpiry) assert.Equal(t, int64(5), cfg.TokenFreshTime) assert.Equal(t, "/profile", cfg.LandingPage) assert.Equal(t, "postgres", cfg.DatabaseType) assert.Equal(t, "15", cfg.DatabaseVersion) assert.Equal(t, "jwtblacklist", cfg.JWTTableName) } func TestConfigFromEnv_ValidFullConfig(t *testing.T) { // Set environment variables t.Setenv("HWSAUTH_SECRET_KEY", "custom-secret") t.Setenv("HWSAUTH_SSL", "true") t.Setenv("HWSAUTH_TRUSTED_HOST", "example.com") t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "15") t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "2880") t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "10") t.Setenv("HWSAUTH_LANDING_PAGE", "/dashboard") t.Setenv("HWSAUTH_DATABASE_TYPE", "mysql") t.Setenv("HWSAUTH_DATABASE_VERSION", "8.0") t.Setenv("HWSAUTH_JWT_TABLE_NAME", "custom_tokens") defer func() { t.Setenv("HWSAUTH_SECRET_KEY", "") t.Setenv("HWSAUTH_SSL", "") t.Setenv("HWSAUTH_TRUSTED_HOST", "") t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "") t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "") t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "") t.Setenv("HWSAUTH_LANDING_PAGE", "") t.Setenv("HWSAUTH_DATABASE_TYPE", "") t.Setenv("HWSAUTH_DATABASE_VERSION", "") t.Setenv("HWSAUTH_JWT_TABLE_NAME", "") }() cfg, err := ConfigFromEnv() assert.NoError(t, err) assert.Equal(t, "custom-secret", cfg.SecretKey) assert.Equal(t, true, cfg.SSL) assert.Equal(t, "example.com", cfg.TrustedHost) assert.Equal(t, int64(15), cfg.AccessTokenExpiry) assert.Equal(t, int64(2880), cfg.RefreshTokenExpiry) assert.Equal(t, int64(10), cfg.TokenFreshTime) assert.Equal(t, "/dashboard", cfg.LandingPage) assert.Equal(t, "mysql", cfg.DatabaseType) assert.Equal(t, "8.0", cfg.DatabaseVersion) assert.Equal(t, "custom_tokens", cfg.JWTTableName) } func TestNewAuthenticator_NilConfig(t *testing.T) { load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } auth, err := NewAuthenticator( nil, // cfg load, server, beginTx, logger, errorPage, nil, // db ) assert.Error(t, err) assert.Nil(t, auth) assert.Contains(t, err.Error(), "Config is required") } func TestNewAuthenticator_MissingSecretKey(t *testing.T) { cfg := &Config{ SecretKey: "", } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } auth, err := NewAuthenticator( cfg, load, server, beginTx, logger, errorPage, nil, // db - will fail before db check since SecretKey is missing ) assert.Error(t, err) assert.Nil(t, auth) assert.Contains(t, err.Error(), "SecretKey is required") } func TestNewAuthenticator_NilLoadFunction(t *testing.T) { cfg := &Config{ SecretKey: "test-secret", } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } auth, err := NewAuthenticator[TestModel, DBTransaction]( cfg, nil, server, beginTx, logger, errorPage, nil, // db ) assert.Error(t, err) assert.Nil(t, auth) assert.Contains(t, err.Error(), "No function to load model supplied") } func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) { cfg := &Config{ SecretKey: "test-secret", SSL: true, } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } db, _, err := createMockDB() require.NoError(t, err) defer func() { _ = db.Close() }() auth, err := NewAuthenticator( cfg, load, server, beginTx, logger, errorPage, db, ) require.NoError(t, err) require.NotNil(t, auth) assert.Equal(t, false, auth.SSL) assert.Equal(t, "/profile", auth.LandingPage) } func TestNewAuthenticator_NilDatabase(t *testing.T) { cfg := &Config{ SecretKey: "test-secret", } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } auth, err := NewAuthenticator( cfg, load, server, beginTx, logger, errorPage, nil, // db ) assert.Error(t, err) assert.Nil(t, auth) assert.Contains(t, err.Error(), "No Database provided") } func TestModelInterface(t *testing.T) { t.Run("TestModel implements Model interface", func(t *testing.T) { var _ Model = TestModel{} }) t.Run("GetID method", func(t *testing.T) { model := TestModel{ID: 789} assert.Equal(t, 789, model.GetID()) }) } func TestGetAuthenticatedUser_NoTokens(t *testing.T) { cfg := &Config{ SecretKey: "test-secret", TrustedHost: "example.com", } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } db, _, err := createMockDB() require.NoError(t, err) defer func() { _ = db.Close() }() auth, err := NewAuthenticator( cfg, load, server, beginTx, logger, errorPage, db, ) require.NoError(t, err) tx := &TestTransaction{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) model, err := auth.getAuthenticatedUser(tx, w, r) assert.Error(t, err) assert.Contains(t, err.Error(), "No token strings provided") var zero TestModel assert.Equal(t, zero, model.model) } func TestLogin_BasicFunctionality(t *testing.T) { cfg := &Config{ SecretKey: "test-secret", TrustedHost: "example.com", } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { return TestModel{ID: id}, nil } server := &hws.Server{} beginTx := func(ctx context.Context) (DBTransaction, error) { return &TestTransaction{}, nil } logger := &hlog.Logger{} errorPage := func(error hws.HWSError) (hws.ErrorPage, error) { return TestErrorPage{}, nil } db, _, err := createMockDB() require.NoError(t, err) defer func() { _ = db.Close() }() auth, err := NewAuthenticator( cfg, load, server, beginTx, logger, errorPage, db, ) require.NoError(t, err) w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/", nil) user := TestModel{ID: 123} rememberMe := true // This test mainly checks that the function doesn't panic and has right call signature // The actual JWT functionality is tested in jwt package itself assert.NotPanics(t, func() { err := auth.Login(w, r, user, rememberMe) require.NoError(t, err) }) }