diff --git a/hwsauth/authenticator.go b/hwsauth/authenticator.go index 5fddf5c..3d1a5dd 100644 --- a/hwsauth/authenticator.go +++ b/hwsauth/authenticator.go @@ -1,6 +1,11 @@ package hwsauth import ( + "context" + "database/sql" + "os" + "time" + "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/jwt" @@ -30,6 +35,7 @@ func NewAuthenticator[T Model, TX DBTransaction]( beginTx BeginTX, logger *hlog.Logger, errorPage hws.ErrorPageFunc, + db *sql.DB, ) (*Authenticator[T, TX], error) { if load == nil { return nil, errors.New("No function to load model supplied") @@ -55,7 +61,10 @@ func NewAuthenticator[T Model, TX DBTransaction]( return nil, errors.New("SecretKey is required") } if cfg.SSL && cfg.TrustedHost == "" { - return nil, errors.New("TrustedHost is required when SSL is enabled") + cfg.SSL = false // Disable SSL if TrustedHost is not configured + } + if cfg.TrustedHost == "" { + cfg.TrustedHost = "localhost" // Default TrustedHost for JWT } if cfg.AccessTokenExpiry == 0 { cfg.AccessTokenExpiry = 5 @@ -69,12 +78,35 @@ func NewAuthenticator[T Model, TX DBTransaction]( if cfg.LandingPage == "" { cfg.LandingPage = "/profile" } + if cfg.DatabaseType == "" { + cfg.DatabaseType = "postgres" + } + if cfg.DatabaseVersion == "" { + cfg.DatabaseVersion = "15" + } + + if db == nil { + return nil, errors.New("No Database provided") + } + + // Test database connectivity + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, errors.Wrap(err, "database connection test failed") + } // Configure JWT table tableConfig := jwt.DefaultTableConfig() if cfg.JWTTableName != "" { tableConfig.TableName = cfg.JWTTableName } + // Disable auto-creation for tests + // Check for test environment or mock database + if os.Getenv("GO_TEST") == "1" { + tableConfig.AutoCreate = false + tableConfig.EnableAutoCleanup = false + } // Create token generator tokenGen, err := jwt.CreateGenerator(jwt.GeneratorConfig{ @@ -87,6 +119,7 @@ func NewAuthenticator[T Model, TX DBTransaction]( Type: cfg.DatabaseType, Version: cfg.DatabaseVersion, }, + DB: db, TableConfig: tableConfig, }, beginTx) if err != nil { diff --git a/hwsauth/go.mod b/hwsauth/go.mod index 80fa7fd..2142c93 100644 --- a/hwsauth/go.mod +++ b/hwsauth/go.mod @@ -8,6 +8,7 @@ require ( git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hws v0.3.0 git.haelnorr.com/h/golib/jwt v0.10.1 + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.11.1 ) diff --git a/hwsauth/go.sum b/hwsauth/go.sum index 029f665..bc0fdef 100644 --- a/hwsauth/go.sum +++ b/hwsauth/go.sum @@ -2,16 +2,10 @@ git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDO git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= -git.haelnorr.com/h/golib/hlog v0.9.1 h1:9VmE/IQTfD8LAEyTbUCZLy/+8PbcHA1Kob/WQHRHKzc= -git.haelnorr.com/h/golib/hlog v0.9.1/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.2.0 h1:MR2Tu2qPaW+/oK8aXFJLRFaYZIHgKiex3t3zE41cu1U= -git.haelnorr.com/h/golib/hws v0.2.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E= -git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= @@ -26,6 +20,7 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= diff --git a/hwsauth/hwsauth_test.go b/hwsauth/hwsauth_test.go index 9a47338..4de5945 100644 --- a/hwsauth/hwsauth_test.go +++ b/hwsauth/hwsauth_test.go @@ -10,6 +10,7 @@ import ( "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hws" + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -47,6 +48,28 @@ func (tep TestErrorPage) Render(ctx context.Context, w io.Writer) error { return nil } +// createMockDB creates a mock SQL database for testing +func createMockDB() (*sql.DB, sqlmock.Sqlmock, error) { + db, mock, err := sqlmock.New() + if err != nil { + return nil, nil, err + } + + // Expect a ping to succeed for database connectivity test + mock.ExpectPing() + + // Expect table existence check (returns a row = table exists) + mock.ExpectQuery(`SELECT 1 FROM information_schema\.tables WHERE table_schema = 'public' AND table_name = \$1`). + WithArgs("jwtblacklist"). + WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) + + // Expect cleanup function creation + mock.ExpectExec(`CREATE OR REPLACE FUNCTION cleanup_jwtblacklist\(\) RETURNS void AS \$\$ BEGIN DELETE FROM jwtblacklist WHERE exp < EXTRACT\(EPOCH FROM NOW\(\)\); END; \$\$ LANGUAGE plpgsql;`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + return db, mock, nil +} + func TestGetNil(t *testing.T) { var zero TestModel result := getNil[TestModel]() @@ -209,12 +232,13 @@ func TestNewAuthenticator_NilConfig(t *testing.T) { } auth, err := NewAuthenticator( - nil, + nil, // cfg load, server, beginTx, logger, errorPage, + nil, // db ) assert.Error(t, err) @@ -246,6 +270,7 @@ func TestNewAuthenticator_MissingSecretKey(t *testing.T) { beginTx, logger, errorPage, + nil, // db - will fail before db check since SecretKey is missing ) assert.Error(t, err) @@ -274,6 +299,7 @@ func TestNewAuthenticator_NilLoadFunction(t *testing.T) { beginTx, logger, errorPage, + nil, // db ) assert.Error(t, err) @@ -299,6 +325,10 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) { return TestErrorPage{}, nil } + db, _, err := createMockDB() + require.NoError(t, err) + defer db.Close() + auth, err := NewAuthenticator( cfg, load, @@ -306,17 +336,19 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) { beginTx, logger, errorPage, + db, ) - assert.Error(t, err) - assert.Nil(t, auth) - assert.Contains(t, err.Error(), "TrustedHost is required when SSL is enabled") + require.NoError(t, err) + require.NotNil(t, auth) + + assert.Equal(t, false, auth.SSL) + assert.Equal(t, "/profile", auth.LandingPage) } -func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) { +func TestNewAuthenticator_NilDatabase(t *testing.T) { cfg := &Config{ - SecretKey: "test-secret", - TrustedHost: "example.com", + SecretKey: "test-secret", } load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) { @@ -338,13 +370,12 @@ func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) { beginTx, logger, errorPage, + nil, // db ) - require.NoError(t, err) - require.NotNil(t, auth) - - assert.Equal(t, false, auth.SSL) - assert.Equal(t, "/profile", auth.LandingPage) + assert.Error(t, err) + assert.Nil(t, auth) + assert.Contains(t, err.Error(), "No Database provided") } func TestModelInterface(t *testing.T) { @@ -376,6 +407,10 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) { return TestErrorPage{}, nil } + db, _, err := createMockDB() + require.NoError(t, err) + defer db.Close() + auth, err := NewAuthenticator( cfg, load, @@ -383,6 +418,7 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) { beginTx, logger, errorPage, + db, ) require.NoError(t, err) @@ -416,6 +452,10 @@ func TestLogin_BasicFunctionality(t *testing.T) { return TestErrorPage{}, nil } + db, _, err := createMockDB() + require.NoError(t, err) + defer db.Close() + auth, err := NewAuthenticator( cfg, load, @@ -423,6 +463,7 @@ func TestLogin_BasicFunctionality(t *testing.T) { beginTx, logger, errorPage, + db, ) require.NoError(t, err)