Created a migration package to handle the migrations
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,6 +4,7 @@ query.sql
|
|||||||
.logs/
|
.logs/
|
||||||
server.log
|
server.log
|
||||||
tmp/
|
tmp/
|
||||||
|
psmigrate
|
||||||
projectreshoot
|
projectreshoot
|
||||||
static/css/output.css
|
static/css/output.css
|
||||||
view/**/*_templ.go
|
view/**/*_templ.go
|
||||||
|
|||||||
6
Makefile
6
Makefile
@@ -1,5 +1,6 @@
|
|||||||
# Makefile
|
# Makefile
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
|
.PHONY: migrate
|
||||||
|
|
||||||
BINARY_NAME=projectreshoot
|
BINARY_NAME=projectreshoot
|
||||||
|
|
||||||
@@ -29,3 +30,8 @@ test:
|
|||||||
|
|
||||||
clean:
|
clean:
|
||||||
go clean
|
go clean
|
||||||
|
|
||||||
|
migrate:
|
||||||
|
go mod tidy && \
|
||||||
|
go generate && \
|
||||||
|
go build -ldflags="-w -s" -o psmigrate ./migrate
|
||||||
|
|||||||
@@ -1,143 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSafeConn(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
sconn.Resume()
|
|
||||||
})
|
|
||||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn.Pause(250 * time.Millisecond)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
sconn.Resume()
|
|
||||||
tx, err = sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func TestSafeTX(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Commit releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Commit()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Rollback()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
|
|
||||||
tx1, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx2, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx3, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx1.Commit()
|
|
||||||
tx2.Commit()
|
|
||||||
tx3.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
defer sconn.releaseGlobalLock()
|
|
||||||
_, err := sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
})
|
|
||||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
tx, err := sconn.Begin(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
sconn.releaseGlobalLock()
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
67
migrate/migrate.go
Normal file
67
migrate/migrate.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed migrations
|
||||||
|
var migrationsFS embed.FS
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) != 4 {
|
||||||
|
fmt.Println("Usage: psmigrate <file_path> up-to|down-to <version>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := os.Args[1]
|
||||||
|
direction := os.Args[2]
|
||||||
|
versionStr := os.Args[3]
|
||||||
|
|
||||||
|
version, err := strconv.Atoi(versionStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Invalid version number: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", filePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||||
|
}
|
||||||
|
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create migration provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
switch direction {
|
||||||
|
case "up-to":
|
||||||
|
_, err = provider.UpTo(ctx, int64(version))
|
||||||
|
case "down-to":
|
||||||
|
_, err = provider.DownTo(ctx, int64(version))
|
||||||
|
default:
|
||||||
|
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Migration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Migration successful!")
|
||||||
|
}
|
||||||
@@ -21,4 +21,7 @@ END;
|
|||||||
|
|
||||||
-- +goose Down
|
-- +goose Down
|
||||||
-- +goose StatementBegin
|
-- +goose StatementBegin
|
||||||
|
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||||
|
DROP TABLE IF EXISTS jwtblacklist;
|
||||||
|
DROP TABLE IF EXISTS users;
|
||||||
-- +goose StatementEnd
|
-- +goose StatementEnd
|
||||||
@@ -21,7 +21,7 @@ func findMigrations() (*fs.FS, error) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrations"))
|
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
|
||||||
return &migrationsdir, nil
|
return &migrationsdir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user