Created a migration package to handle the migrations
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,6 +4,7 @@ query.sql
|
||||
.logs/
|
||||
server.log
|
||||
tmp/
|
||||
psmigrate
|
||||
projectreshoot
|
||||
static/css/output.css
|
||||
view/**/*_templ.go
|
||||
|
||||
6
Makefile
6
Makefile
@@ -1,5 +1,6 @@
|
||||
# Makefile
|
||||
.PHONY: build
|
||||
.PHONY: migrate
|
||||
|
||||
BINARY_NAME=projectreshoot
|
||||
|
||||
@@ -29,3 +30,8 @@ test:
|
||||
|
||||
clean:
|
||||
go clean
|
||||
|
||||
migrate:
|
||||
go mod tidy && \
|
||||
go generate && \
|
||||
go build -ldflags="-w -s" -o psmigrate ./migrate
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/tests"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeConn(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
sconn.Resume()
|
||||
})
|
||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
sconn.Pause(250 * time.Millisecond)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
})
|
||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
sconn.Resume()
|
||||
tx, err = sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
})
|
||||
}
|
||||
func TestSafeTX(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Commit releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Commit()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Rollback()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
|
||||
tx1, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx2, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx3, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx1.Commit()
|
||||
tx2.Commit()
|
||||
tx3.Commit()
|
||||
})
|
||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
defer sconn.releaseGlobalLock()
|
||||
_, err := sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
tx, err := sconn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
wg.Done()
|
||||
}()
|
||||
sconn.releaseGlobalLock()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
67
migrate/migrate.go
Normal file
67
migrate/migrate.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migrations
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 4 {
|
||||
fmt.Println("Usage: psmigrate <file_path> up-to|down-to <version>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
filePath := os.Args[1]
|
||||
direction := os.Args[2]
|
||||
versionStr := os.Args[3]
|
||||
|
||||
version, err := strconv.Atoi(versionStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid version number: %v", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create migration provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch direction {
|
||||
case "up-to":
|
||||
_, err = provider.UpTo(ctx, int64(version))
|
||||
case "down-to":
|
||||
_, err = provider.DownTo(ctx, int64(version))
|
||||
default:
|
||||
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Migration successful!")
|
||||
}
|
||||
@@ -21,4 +21,7 @@ END;
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||
DROP TABLE IF EXISTS jwtblacklist;
|
||||
DROP TABLE IF EXISTS users;
|
||||
-- +goose StatementEnd
|
||||
@@ -21,7 +21,7 @@ func findMigrations() (*fs.FS, error) {
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrations"))
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
|
||||
return &migrationsdir, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user