big ole refactor
This commit is contained in:
511
internal/db/migrate/migrate.go
Normal file
511
internal/db/migrate/migrate.go
Normal file
@@ -0,0 +1,511 @@
|
||||
// Package migrate provides functions for managing database migrations
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/db/migrations"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun/migrate"
|
||||
)
|
||||
|
||||
// RunMigrations executes database migrations
|
||||
func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error {
|
||||
conn := db.NewDB(cfg.DB)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
migrator := migrate.NewMigrator(conn.DB, migrations.Migrations)
|
||||
|
||||
// Initialize migration tables
|
||||
if err := migrator.Init(ctx); err != nil {
|
||||
return errors.Wrap(err, "migrator.Init")
|
||||
}
|
||||
|
||||
switch command {
|
||||
case "up":
|
||||
err := migrateUp(ctx, migrator, conn, cfg, countStr)
|
||||
if err != nil {
|
||||
// On error, automatically rollback the migrations that were just applied
|
||||
fmt.Println("[WARN] Migration failed, attempting automatic rollback...")
|
||||
// We need to figure out how many migrations were applied in this batch
|
||||
// For now, we'll skip automatic rollback since it's complex with the new count system
|
||||
// The user can manually rollback if needed
|
||||
return err
|
||||
}
|
||||
return err
|
||||
case "rollback":
|
||||
return migrateRollback(ctx, migrator, conn, cfg, countStr)
|
||||
case "status":
|
||||
return migrateStatus(ctx, migrator)
|
||||
default:
|
||||
return fmt.Errorf("unknown migration command: %s", command)
|
||||
}
|
||||
}
|
||||
|
||||
// migrateUp runs pending migrations
|
||||
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
|
||||
// Parse count parameter
|
||||
count, all, err := parseMigrationCount(countStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "parse migration count")
|
||||
}
|
||||
|
||||
fmt.Println("[INFO] Step 1/5: Validating migrations...")
|
||||
if err := validateMigrations(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("[INFO] Migration validation passed ✓")
|
||||
|
||||
fmt.Println("[INFO] Step 2/5: Checking for pending migrations...")
|
||||
// Check for pending migrations using MigrationsWithStatus (read-only)
|
||||
ms, err := migrator.MigrationsWithStatus(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get migration status")
|
||||
}
|
||||
|
||||
unapplied := ms.Unapplied()
|
||||
if len(unapplied) == 0 {
|
||||
fmt.Println("[INFO] No pending migrations")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select which migrations to apply
|
||||
toApply := selectMigrationsToApply(unapplied, count, all)
|
||||
if len(toApply) == 0 {
|
||||
fmt.Println("[INFO] No migrations to run")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Print what we're about to do
|
||||
if all {
|
||||
fmt.Printf("[INFO] Running all %d pending migration(s):\n", len(toApply))
|
||||
} else {
|
||||
fmt.Printf("[INFO] Running %d migration(s):\n", len(toApply))
|
||||
}
|
||||
for _, m := range toApply {
|
||||
fmt.Printf(" 📋 %s\n", m.Name)
|
||||
}
|
||||
|
||||
// Create backup unless --no-backup flag is set
|
||||
if !cfg.Flags.MigrateNoBackup {
|
||||
fmt.Println("[INFO] Step 3/5: Creating backup...")
|
||||
_, err := db.CreateBackup(ctx, cfg.DB, "migration")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create backup")
|
||||
}
|
||||
|
||||
// Clean old backups
|
||||
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
|
||||
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("[INFO] Step 3/5: Skipping backup (--no-backup flag set)")
|
||||
}
|
||||
|
||||
// Acquire migration lock
|
||||
fmt.Println("[INFO] Step 4/5: Acquiring migration lock...")
|
||||
if err := acquireMigrationLock(ctx, conn); err != nil {
|
||||
return errors.Wrap(err, "acquire migration lock")
|
||||
}
|
||||
defer releaseMigrationLock(ctx, conn)
|
||||
fmt.Println("[INFO] Migration lock acquired")
|
||||
|
||||
// Run migrations
|
||||
fmt.Println("[INFO] Step 5/5: Applying migrations...")
|
||||
group, err := executeUpMigrations(ctx, migrator, toApply)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "execute migrations")
|
||||
}
|
||||
|
||||
if group.IsZero() {
|
||||
fmt.Println("[INFO] No migrations to run")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("[INFO] Migrated to group %d\n", group.ID)
|
||||
for _, migration := range group.Migrations {
|
||||
fmt.Printf(" ✅ %s\n", migration.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateRollback rolls back migrations
|
||||
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error {
|
||||
// Parse count parameter
|
||||
count, all, err := parseMigrationCount(countStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "parse migration count")
|
||||
}
|
||||
|
||||
// Get all migrations with status
|
||||
ms, err := migrator.MigrationsWithStatus(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get migration status")
|
||||
}
|
||||
|
||||
applied := ms.Applied()
|
||||
if len(applied) == 0 {
|
||||
fmt.Println("[INFO] No migrations to rollback")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select which migrations to rollback
|
||||
toRollback := selectMigrationsToRollback(applied, count, all)
|
||||
if len(toRollback) == 0 {
|
||||
fmt.Println("[INFO] No migrations to rollback")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Print what we're about to do
|
||||
if all {
|
||||
fmt.Printf("[INFO] Rolling back all %d migration(s):\n", len(toRollback))
|
||||
} else {
|
||||
fmt.Printf("[INFO] Rolling back %d migration(s):\n", len(toRollback))
|
||||
}
|
||||
for _, m := range toRollback {
|
||||
fmt.Printf(" 📋 %s (group %d)\n", m.Name, m.GroupID)
|
||||
}
|
||||
|
||||
// Create backup unless --no-backup flag is set
|
||||
if !cfg.Flags.MigrateNoBackup {
|
||||
fmt.Println("[INFO] Creating backup before rollback...")
|
||||
_, err := db.CreateBackup(ctx, cfg.DB, "rollback")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create backup")
|
||||
}
|
||||
|
||||
// Clean old backups
|
||||
if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil {
|
||||
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("[INFO] Skipping backup (--no-backup flag set)")
|
||||
}
|
||||
|
||||
// Acquire migration lock
|
||||
fmt.Println("[INFO] Acquiring migration lock...")
|
||||
if err := acquireMigrationLock(ctx, conn); err != nil {
|
||||
return errors.Wrap(err, "acquire migration lock")
|
||||
}
|
||||
defer releaseMigrationLock(ctx, conn)
|
||||
fmt.Println("[INFO] Migration lock acquired")
|
||||
|
||||
// Rollback
|
||||
fmt.Println("[INFO] Executing rollback...")
|
||||
rolledBack, err := executeDownMigrations(ctx, migrator, toRollback)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "execute rollback")
|
||||
}
|
||||
|
||||
fmt.Printf("[INFO] Successfully rolled back %d migration(s)\n", len(rolledBack))
|
||||
for _, migration := range rolledBack {
|
||||
fmt.Printf(" ↩️ %s\n", migration.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateStatus shows migration status
|
||||
func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
|
||||
ms, err := migrator.MigrationsWithStatus(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get migration status")
|
||||
}
|
||||
|
||||
fmt.Println("╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ DATABASE MIGRATION STATUS ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tCOMMENT")
|
||||
_, _ = fmt.Fprintln(w, "----------\t---------------\t-----\t---------------------------")
|
||||
|
||||
appliedCount := 0
|
||||
for _, m := range ms {
|
||||
status := "⏳ Pending"
|
||||
group := "-"
|
||||
|
||||
if m.GroupID > 0 {
|
||||
status = "✅ Applied"
|
||||
appliedCount++
|
||||
group = fmt.Sprint(m.GroupID)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, m.Comment)
|
||||
}
|
||||
|
||||
_ = w.Flush()
|
||||
|
||||
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
|
||||
appliedCount, len(ms)-appliedCount)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMigrations ensures migrations compile before running
|
||||
func validateMigrations(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "go", "build",
|
||||
"-o", "/dev/null", "./cmd/oslstats/migrations")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR] Migration validation failed!")
|
||||
fmt.Println(string(output))
|
||||
return errors.Wrap(err, "migration build failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
|
||||
func acquireMigrationLock(ctx context.Context, conn *db.DB) error {
|
||||
const lockID = 1234567890 // Arbitrary unique ID for migration lock
|
||||
const timeoutSeconds = 300 // 5 minutes
|
||||
|
||||
// Set statement timeout for this session
|
||||
_, err := conn.ExecContext(ctx,
|
||||
fmt.Sprintf("SET statement_timeout = '%ds'", timeoutSeconds))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "set timeout")
|
||||
}
|
||||
|
||||
var acquired bool
|
||||
err = conn.NewRaw("SELECT pg_try_advisory_lock(?)", lockID).
|
||||
Scan(ctx, &acquired)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "pg_try_advisory_lock")
|
||||
}
|
||||
|
||||
if !acquired {
|
||||
return errors.New("migration already in progress (could not acquire lock)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseMigrationLock releases the migration lock
|
||||
func releaseMigrationLock(ctx context.Context, conn *db.DB) {
|
||||
const lockID = 1234567890
|
||||
|
||||
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
|
||||
if err != nil {
|
||||
fmt.Printf("[WARN] Failed to release migration lock: %v\n", err)
|
||||
} else {
|
||||
fmt.Println("[INFO] Migration lock released")
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMigration generates a new migration file
|
||||
func CreateMigration(name string) error {
|
||||
if name == "" {
|
||||
return errors.New("migration name cannot be empty")
|
||||
}
|
||||
|
||||
// Sanitize name (replace spaces with underscores, lowercase)
|
||||
name = strings.ToLower(strings.ReplaceAll(name, " ", "_"))
|
||||
|
||||
// Generate timestamp
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name)
|
||||
|
||||
// Template
|
||||
template := `package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Migrations.MustRegister(
|
||||
// UP migration
|
||||
func(ctx context.Context, conn *bun.DB) error {
|
||||
// Add your migration code here
|
||||
return nil
|
||||
},
|
||||
// DOWN migration
|
||||
func(ctx context.Context, conn *bun.DB) error {
|
||||
// Add your rollback code here
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
`
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(filename, []byte(template), 0o644); err != nil {
|
||||
return errors.Wrap(err, "write migration file")
|
||||
}
|
||||
|
||||
fmt.Printf("✅ Created migration: %s\n", filename)
|
||||
fmt.Println("📝 Next steps:")
|
||||
fmt.Println(" 1. Edit the file and implement the UP and DOWN functions")
|
||||
fmt.Println(" 2. Run: just migrate up")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseMigrationCount parses a migration count string
|
||||
// Returns: (count, all, error)
|
||||
// - "" (empty) → (1, false, nil) - default to 1
|
||||
// - "all" → (0, true, nil) - special case for all
|
||||
// - "5" → (5, false, nil) - specific count
|
||||
// - "invalid" → (0, false, error)
|
||||
func parseMigrationCount(value string) (int, bool, error) {
|
||||
// Default to 1 if empty
|
||||
if value == "" {
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
// Special case for "all"
|
||||
if value == "all" {
|
||||
return 0, true, nil
|
||||
}
|
||||
|
||||
// Parse as integer
|
||||
count, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, false, errors.New("migration count must be a positive integer or 'all'")
|
||||
}
|
||||
if count < 1 {
|
||||
return 0, false, errors.New("migration count must be a positive integer (1 or greater)")
|
||||
}
|
||||
|
||||
return count, false, nil
|
||||
}
|
||||
|
||||
// selectMigrationsToApply returns the subset of unapplied migrations to run
|
||||
func selectMigrationsToApply(unapplied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice {
|
||||
if all {
|
||||
return unapplied
|
||||
}
|
||||
|
||||
count = min(count, len(unapplied))
|
||||
return unapplied[:count]
|
||||
}
|
||||
|
||||
// selectMigrationsToRollback returns the subset of applied migrations to rollback
|
||||
// Returns migrations in reverse chronological order (most recent first)
|
||||
func selectMigrationsToRollback(applied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice {
|
||||
if len(applied) == 0 || all {
|
||||
return applied
|
||||
}
|
||||
count = min(count, len(applied))
|
||||
return applied[:count]
|
||||
}
|
||||
|
||||
// executeUpMigrations executes a subset of UP migrations
|
||||
func executeUpMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (*migrate.MigrationGroup, error) {
|
||||
if len(migrations) == 0 {
|
||||
return &migrate.MigrationGroup{}, nil
|
||||
}
|
||||
|
||||
// Get the next group ID
|
||||
ms, err := migrator.MigrationsWithStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get migration status")
|
||||
}
|
||||
|
||||
lastGroup := ms.LastGroup()
|
||||
groupID := int64(1)
|
||||
if lastGroup.ID > 0 {
|
||||
groupID = lastGroup.ID + 1
|
||||
}
|
||||
|
||||
// Create the migration group
|
||||
group := &migrate.MigrationGroup{
|
||||
ID: groupID,
|
||||
Migrations: make(migrate.MigrationSlice, 0, len(migrations)),
|
||||
}
|
||||
|
||||
// Execute each migration
|
||||
for i := range migrations {
|
||||
migration := &migrations[i]
|
||||
migration.GroupID = groupID
|
||||
|
||||
// Mark as applied before execution (Bun's default behavior)
|
||||
if err := migrator.MarkApplied(ctx, migration); err != nil {
|
||||
return group, errors.Wrap(err, "mark applied")
|
||||
}
|
||||
|
||||
// Add to group
|
||||
group.Migrations = append(group.Migrations, *migration)
|
||||
|
||||
// Execute the UP function
|
||||
if migration.Up != nil {
|
||||
if err := migration.Up(ctx, migrator, migration); err != nil {
|
||||
return group, errors.Wrap(err, fmt.Sprintf("migration %s failed", migration.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// executeDownMigrations executes a subset of DOWN migrations
|
||||
func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (migrate.MigrationSlice, error) {
|
||||
rolledBack := make(migrate.MigrationSlice, 0, len(migrations))
|
||||
|
||||
// Execute each migration in order (already reversed)
|
||||
for i := range migrations {
|
||||
migration := &migrations[i]
|
||||
|
||||
// Execute the DOWN function
|
||||
if migration.Down != nil {
|
||||
if err := migration.Down(ctx, migrator, migration); err != nil {
|
||||
return rolledBack, errors.Wrap(err, fmt.Sprintf("rollback %s failed", migration.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as unapplied after execution
|
||||
if err := migrator.MarkUnapplied(ctx, migration); err != nil {
|
||||
return rolledBack, errors.Wrap(err, "mark unapplied")
|
||||
}
|
||||
|
||||
rolledBack = append(rolledBack, *migration)
|
||||
}
|
||||
|
||||
return rolledBack, nil
|
||||
}
|
||||
|
||||
// ResetDatabase drops and recreates all tables (destructive)
|
||||
func ResetDatabase(ctx context.Context, cfg *config.Config) error {
|
||||
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
|
||||
fmt.Print("Type 'yes' to continue: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read input")
|
||||
}
|
||||
|
||||
response = strings.TrimSpace(response)
|
||||
if response != "yes" {
|
||||
fmt.Println("❌ Reset cancelled")
|
||||
return nil
|
||||
}
|
||||
conn := db.NewDB(cfg.DB)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
models := conn.RegisterModels()
|
||||
|
||||
for _, model := range models {
|
||||
if err := conn.ResetModel(ctx, model); err != nil {
|
||||
return errors.Wrap(err, "reset model")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("✅ Database reset complete")
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user