From 79838b4aae12c327fbae0e45efd14368e08d3d08 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 20 Feb 2025 23:01:39 +1100 Subject: [PATCH] Created a migration package to handle the migrations --- .gitignore | 1 + Makefile | 6 + db/connection_test.go | 143 ------------------ migrate/migrate.go | 67 ++++++++ .../migrations}/00001_init.sql | 3 + tests/database.go | 2 +- 6 files changed, 78 insertions(+), 144 deletions(-) delete mode 100644 db/connection_test.go create mode 100644 migrate/migrate.go rename {migrations => migrate/migrations}/00001_init.sql (85%) diff --git a/.gitignore b/.gitignore index 5b22f92..c2f2f8c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ query.sql .logs/ server.log tmp/ +psmigrate projectreshoot static/css/output.css view/**/*_templ.go diff --git a/Makefile b/Makefile index e23833d..f3535a3 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ # Makefile .PHONY: build +.PHONY: migrate BINARY_NAME=projectreshoot @@ -29,3 +30,8 @@ test: clean: go clean + +migrate: + go mod tidy && \ + go generate && \ + go build -ldflags="-w -s" -o psmigrate ./migrate diff --git a/db/connection_test.go b/db/connection_test.go deleted file mode 100644 index b1816d7..0000000 --- a/db/connection_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package db - -import ( - "context" - "projectreshoot/tests" - "strconv" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSafeConn(t *testing.T) { - cfg, err := tests.TestConfig() - require.NoError(t, err) - logger := tests.NilLogger() - ver, err := strconv.ParseInt(cfg.DBName, 10, 0) - require.NoError(t, err) - conn, err := tests.SetupTestDB(ver) - require.NoError(t, err) - sconn := MakeSafe(conn, logger) - defer sconn.Close() - - t.Run("Global lock waits for read locks to finish", func(t *testing.T) { - tx, err := sconn.Begin(t.Context()) - require.NoError(t, err) - var requested sync.WaitGroup - var engaged sync.WaitGroup - requested.Add(1) - engaged.Add(1) - go func() { - requested.Done() - sconn.Pause(5 * time.Second) - engaged.Done() - }() - requested.Wait() - assert.Equal(t, uint32(0), sconn.globalLockStatus) - assert.Equal(t, uint32(1), sconn.globalLockRequested) - tx.Commit() - engaged.Wait() - assert.Equal(t, uint32(1), sconn.globalLockStatus) - assert.Equal(t, uint32(0), sconn.globalLockRequested) - sconn.Resume() - }) - t.Run("Lock abandons after timeout", func(t *testing.T) { - tx, err := sconn.Begin(t.Context()) - require.NoError(t, err) - sconn.Pause(250 * time.Millisecond) - assert.Equal(t, uint32(0), sconn.globalLockStatus) - assert.Equal(t, uint32(0), sconn.globalLockRequested) - tx.Commit() - }) - t.Run("Pause blocks transactions and resume allows", func(t *testing.T) { - tx, err := sconn.Begin(t.Context()) - require.NoError(t, err) - var requested sync.WaitGroup - var engaged sync.WaitGroup - requested.Add(1) - engaged.Add(1) - go func() { - requested.Done() - sconn.Pause(5 * time.Second) - engaged.Done() - }() - requested.Wait() - assert.Equal(t, uint32(0), sconn.globalLockStatus) - assert.Equal(t, uint32(1), sconn.globalLockRequested) - ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) - defer cancel() - _, err = sconn.Begin(ctx) - require.Error(t, err) - tx.Commit() - engaged.Wait() - _, err = sconn.Begin(ctx) - require.Error(t, err) - sconn.Resume() - tx, err = sconn.Begin(t.Context()) - require.NoError(t, err) - tx.Commit() - }) -} -func TestSafeTX(t *testing.T) { - cfg, err := tests.TestConfig() - require.NoError(t, err) - logger := tests.NilLogger() - ver, err := strconv.ParseInt(cfg.DBName, 10, 0) - require.NoError(t, err) - conn, err := tests.SetupTestDB(ver) - require.NoError(t, err) - sconn := MakeSafe(conn, logger) - defer sconn.Close() - - t.Run("Commit releases lock", func(t *testing.T) { - tx, err := sconn.Begin(t.Context()) - require.NoError(t, err) - assert.Equal(t, uint32(1), sconn.readLockCount) - tx.Commit() - assert.Equal(t, uint32(0), sconn.readLockCount) - }) - t.Run("Rollback releases lock", func(t *testing.T) { - tx, err := sconn.Begin(t.Context()) - require.NoError(t, err) - assert.Equal(t, uint32(1), sconn.readLockCount) - tx.Rollback() - assert.Equal(t, uint32(0), sconn.readLockCount) - }) - t.Run("Multiple TX can gain read lock", func(t *testing.T) { - tx1, err := sconn.Begin(t.Context()) - require.NoError(t, err) - tx2, err := sconn.Begin(t.Context()) - require.NoError(t, err) - tx3, err := sconn.Begin(t.Context()) - require.NoError(t, err) - tx1.Commit() - tx2.Commit() - tx3.Commit() - }) - t.Run("Lock acquiring times out after timeout", func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) - defer cancel() - sconn.acquireGlobalLock() - defer sconn.releaseGlobalLock() - _, err := sconn.Begin(ctx) - require.Error(t, err) - }) - t.Run("Lock acquires if lock released", func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) - defer cancel() - sconn.acquireGlobalLock() - var wg sync.WaitGroup - wg.Add(1) - go func() { - tx, err := sconn.Begin(ctx) - require.NoError(t, err) - tx.Commit() - wg.Done() - }() - sconn.releaseGlobalLock() - wg.Wait() - }) -} diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 0000000..4f1b3a7 --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log" + "os" + "strconv" + + "github.com/pressly/goose/v3" + + _ "modernc.org/sqlite" +) + +//go:embed migrations +var migrationsFS embed.FS + +func main() { + if len(os.Args) != 4 { + fmt.Println("Usage: psmigrate up-to|down-to ") + os.Exit(1) + } + + filePath := os.Args[1] + direction := os.Args[2] + versionStr := os.Args[3] + + version, err := strconv.Atoi(versionStr) + if err != nil { + log.Fatalf("Invalid version number: %v", err) + } + + db, err := sql.Open("sqlite", filePath) + if err != nil { + log.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + migrations, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + log.Fatalf("Failed to get migrations from embedded filesystem") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations) + if err != nil { + log.Fatalf("Failed to create migration provider: %v", err) + } + + ctx := context.Background() + + switch direction { + case "up-to": + _, err = provider.UpTo(ctx, int64(version)) + case "down-to": + _, err = provider.DownTo(ctx, int64(version)) + default: + log.Fatalf("Invalid direction: use 'up-to' or 'down-to'") + } + + if err != nil { + log.Fatalf("Migration failed: %v", err) + } + + fmt.Println("Migration successful!") +} diff --git a/migrations/00001_init.sql b/migrate/migrations/00001_init.sql similarity index 85% rename from migrations/00001_init.sql rename to migrate/migrations/00001_init.sql index a3bab95..1ada972 100644 --- a/migrations/00001_init.sql +++ b/migrate/migrations/00001_init.sql @@ -21,4 +21,7 @@ END; -- +goose Down -- +goose StatementBegin +DROP TRIGGER IF EXISTS cleanup_expired_tokens; +DROP TABLE IF EXISTS jwtblacklist; +DROP TABLE IF EXISTS users; -- +goose StatementEnd diff --git a/tests/database.go b/tests/database.go index 6157fbc..0010636 100644 --- a/tests/database.go +++ b/tests/database.go @@ -21,7 +21,7 @@ func findMigrations() (*fs.FS, error) { for { if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { - migrationsdir := os.DirFS(filepath.Join(dir, "migrations")) + migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations")) return &migrationsdir, nil }