From 4b64dccbda6a0eeb93d3361435fee66eaa796c9b Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 20 Feb 2025 20:26:26 +1100 Subject: [PATCH] Updated tests to use migrations for db init for consistency --- db/connection_test.go | 13 +++++- go.mod | 4 ++ go.sum | 8 ++++ middleware/authentication_test.go | 10 ++--- middleware/pageprotection_test.go | 11 ++--- middleware/reauthentication_test.go | 11 ++--- schema.sql | 19 --------- tests/database.go | 64 +++++++++++++++++++---------- testdata.sql => tests/testdata.sql | 0 9 files changed, 82 insertions(+), 58 deletions(-) delete mode 100644 schema.sql rename testdata.sql => tests/testdata.sql (100%) diff --git a/db/connection_test.go b/db/connection_test.go index 9723526..b1816d7 100644 --- a/db/connection_test.go +++ b/db/connection_test.go @@ -3,6 +3,7 @@ package db import ( "context" "projectreshoot/tests" + "strconv" "sync" "testing" "time" @@ -12,8 +13,12 @@ import ( ) func TestSafeConn(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := MakeSafe(conn, logger) defer sconn.Close() @@ -77,8 +82,12 @@ func TestSafeConn(t *testing.T) { }) } func TestSafeTX(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := MakeSafe(conn, logger) defer sconn.Close() diff --git a/go.mod b/go.mod index 9b690b1..16638f8 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 + github.com/pressly/goose/v3 v3.24.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 @@ -19,9 +20,12 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/go.sum b/go.sum index a402861..b4bb05c 100644 --- a/go.sum +++ b/go.sum @@ -28,19 +28,27 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY= +github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 6ce807c..bb3dceb 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -17,16 +17,16 @@ import ( ) func TestAuthenticationMiddleware(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index c6efcba..6150f72 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" "sync/atomic" "testing" @@ -14,16 +15,16 @@ import ( ) func TestPageLoginRequired(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index a1f2083..bfb40e8 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" "sync/atomic" "testing" @@ -14,16 +15,16 @@ import ( ) func TestReauthRequired(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/schema.sql b/schema.sql deleted file mode 100644 index 986d312..0000000 --- a/schema.sql +++ /dev/null @@ -1,19 +0,0 @@ -PRAGMA foreign_keys=ON; -BEGIN TRANSACTION; -CREATE TABLE IF NOT EXISTS jwtblacklist ( -jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'), -exp INTEGER NOT NULL -) STRICT; -CREATE TABLE IF NOT EXISTS "users" ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - password_hash TEXT DEFAULT "", - created_at INTEGER DEFAULT (unixepoch()), - bio TEXT DEFAULT "" -) STRICT; -CREATE TRIGGER cleanup_expired_tokens -AFTER INSERT ON jwtblacklist -BEGIN -DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now'); -END; -COMMIT; diff --git a/tests/database.go b/tests/database.go index 549db2b..6157fbc 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,63 +1,83 @@ package tests import ( + "context" "database/sql" - "fmt" + "io/fs" "os" "path/filepath" "github.com/pkg/errors" + "github.com/pressly/goose/v3" _ "modernc.org/sqlite" ) -func findSQLFile(filename string) (string, error) { +func findMigrations() (*fs.FS, error) { + dir, err := os.Getwd() + if err != nil { + return nil, err + } + + for { + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + migrationsdir := os.DirFS(filepath.Join(dir, "migrations")) + return &migrationsdir, nil + } + + parent := filepath.Dir(dir) + if parent == dir { // Reached root + return nil, errors.New("Unable to locate migrations directory") + } + dir = parent + } +} + +func findTestData() (string, error) { dir, err := os.Getwd() if err != nil { return "", err } for { - if _, err := os.Stat(filepath.Join(dir, filename)); err == nil { - return filepath.Join(dir, filename), nil + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + return filepath.Join(dir, "tests", "testdata.sql"), nil } parent := filepath.Dir(dir) if parent == dir { // Reached root - return "", errors.New(fmt.Sprintf("Unable to locate %s", filename)) + return "", errors.New("Unable to locate test data") } dir = parent } } -// SetupTestDB initializes a test SQLite database with mock data -func SetupTestDB() (*sql.DB, error) { +func SetupTestDB(version int64) (*sql.DB, error) { conn, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, errors.Wrap(err, "sql.Open") } - // Setup the test database - schemaPath, err := findSQLFile("schema.sql") + + migrations, err := findMigrations() if err != nil { - return nil, errors.Wrap(err, "findSchema") + return nil, errors.Wrap(err, "findMigrations") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations) + if err != nil { + return nil, errors.Wrap(err, "goose.NewProvider") + } + ctx := context.Background() + if _, err := provider.UpTo(ctx, version); err != nil { + return nil, errors.Wrap(err, "provider.UpTo") } - sqlBytes, err := os.ReadFile(schemaPath) - if err != nil { - return nil, errors.Wrap(err, "os.ReadFile") - } - schemaSQL := string(sqlBytes) - - _, err = conn.Exec(schemaSQL) - if err != nil { - return nil, errors.Wrap(err, "tx.Exec") - } + // NOTE: ================================================== // Load the test data - dataPath, err := findSQLFile("testdata.sql") + dataPath, err := findTestData() if err != nil { return nil, errors.Wrap(err, "findSchema") } - sqlBytes, err = os.ReadFile(dataPath) + sqlBytes, err := os.ReadFile(dataPath) if err != nil { return nil, errors.Wrap(err, "os.ReadFile") } diff --git a/testdata.sql b/tests/testdata.sql similarity index 100% rename from testdata.sql rename to tests/testdata.sql