Created a migration package to handle the migrations

This commit is contained in:
2025-02-20 23:01:39 +11:00
parent a575025f1f
commit 79838b4aae
6 changed files with 78 additions and 144 deletions

1
.gitignore vendored
View File

@@ -4,6 +4,7 @@ query.sql
.logs/
server.log
tmp/
psmigrate
projectreshoot
static/css/output.css
view/**/*_templ.go

View File

@@ -1,5 +1,6 @@
# Makefile
.PHONY: build
.PHONY: migrate
BINARY_NAME=projectreshoot
@@ -29,3 +30,8 @@ test:
clean:
go clean
migrate:
go mod tidy && \
go generate && \
go build -ldflags="-w -s" -o psmigrate ./migrate

View File

@@ -1,143 +0,0 @@
package db
import (
"context"
"projectreshoot/tests"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeConn(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
tx.Commit()
engaged.Wait()
assert.Equal(t, uint32(1), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
sconn.Resume()
})
t.Run("Lock abandons after timeout", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
sconn.Pause(250 * time.Millisecond)
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
tx.Commit()
})
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
_, err = sconn.Begin(ctx)
require.Error(t, err)
tx.Commit()
engaged.Wait()
_, err = sconn.Begin(ctx)
require.Error(t, err)
sconn.Resume()
tx, err = sconn.Begin(t.Context())
require.NoError(t, err)
tx.Commit()
})
}
func TestSafeTX(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Commit releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Commit()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Rollback releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Rollback()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
tx1, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx2, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx3, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx1.Commit()
tx2.Commit()
tx3.Commit()
})
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
defer sconn.releaseGlobalLock()
_, err := sconn.Begin(ctx)
require.Error(t, err)
})
t.Run("Lock acquires if lock released", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
var wg sync.WaitGroup
wg.Add(1)
go func() {
tx, err := sconn.Begin(ctx)
require.NoError(t, err)
tx.Commit()
wg.Done()
}()
sconn.releaseGlobalLock()
wg.Wait()
})
}

67
migrate/migrate.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
"os"
"strconv"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
//go:embed migrations
var migrationsFS embed.FS
func main() {
if len(os.Args) != 4 {
fmt.Println("Usage: psmigrate <file_path> up-to|down-to <version>")
os.Exit(1)
}
filePath := os.Args[1]
direction := os.Args[2]
versionStr := os.Args[3]
version, err := strconv.Atoi(versionStr)
if err != nil {
log.Fatalf("Invalid version number: %v", err)
}
db, err := sql.Open("sqlite", filePath)
if err != nil {
log.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
migrations, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
log.Fatalf("Failed to get migrations from embedded filesystem")
}
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
if err != nil {
log.Fatalf("Failed to create migration provider: %v", err)
}
ctx := context.Background()
switch direction {
case "up-to":
_, err = provider.UpTo(ctx, int64(version))
case "down-to":
_, err = provider.DownTo(ctx, int64(version))
default:
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
}
if err != nil {
log.Fatalf("Migration failed: %v", err)
}
fmt.Println("Migration successful!")
}

View File

@@ -21,4 +21,7 @@ END;
-- +goose Down
-- +goose StatementBegin
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
DROP TABLE IF EXISTS jwtblacklist;
DROP TABLE IF EXISTS users;
-- +goose StatementEnd

View File

@@ -21,7 +21,7 @@ func findMigrations() (*fs.FS, error) {
for {
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
migrationsdir := os.DirFS(filepath.Join(dir, "migrations"))
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
return &migrationsdir, nil
}