366 lines
9.9 KiB
Go
366 lines
9.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"text/tabwriter"
|
|
"time"
|
|
|
|
stderrors "errors"
|
|
|
|
"git.haelnorr.com/h/oslstats/cmd/oslstats/migrations"
|
|
"git.haelnorr.com/h/oslstats/internal/backup"
|
|
"git.haelnorr.com/h/oslstats/internal/config"
|
|
"git.haelnorr.com/h/oslstats/internal/db"
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
"github.com/uptrace/bun/migrate"
|
|
)
|
|
|
|
// runMigrations executes database migrations
|
|
func runMigrations(ctx context.Context, cfg *config.Config, command string) error {
|
|
conn, close := setupBun(cfg)
|
|
defer func() { _ = close() }()
|
|
|
|
migrator := migrate.NewMigrator(conn, migrations.Migrations)
|
|
|
|
// Initialize migration tables
|
|
if err := migrator.Init(ctx); err != nil {
|
|
return errors.Wrap(err, "migrator.Init")
|
|
}
|
|
|
|
switch command {
|
|
case "up":
|
|
err := migrateUp(ctx, migrator, conn, cfg)
|
|
if err != nil {
|
|
err2 := migrateRollback(ctx, migrator, conn, cfg)
|
|
if err2 != nil {
|
|
return stderrors.Join(errors.Wrap(err2, "error while rolling back after migration error"), err)
|
|
}
|
|
}
|
|
return err
|
|
case "rollback":
|
|
return migrateRollback(ctx, migrator, conn, cfg)
|
|
case "status":
|
|
return migrateStatus(ctx, migrator)
|
|
case "dry-run":
|
|
return migrateDryRun(ctx, migrator)
|
|
default:
|
|
return fmt.Errorf("unknown migration command: %s", command)
|
|
}
|
|
}
|
|
|
|
// migrateUp runs pending migrations
|
|
func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config) error {
|
|
fmt.Println("[INFO] Step 1/5: Validating migrations...")
|
|
if err := validateMigrations(ctx); err != nil {
|
|
return err
|
|
}
|
|
fmt.Println("[INFO] Migration validation passed ✓")
|
|
|
|
fmt.Println("[INFO] Step 2/5: Checking for pending migrations...")
|
|
// Check for pending migrations using MigrationsWithStatus (read-only)
|
|
ms, err := migrator.MigrationsWithStatus(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "get migration status")
|
|
}
|
|
|
|
unapplied := ms.Unapplied()
|
|
if len(unapplied) == 0 {
|
|
fmt.Println("[INFO] No pending migrations")
|
|
return nil
|
|
}
|
|
|
|
// Create backup unless --no-backup flag is set
|
|
if !cfg.Flags.MigrateNoBackup {
|
|
fmt.Println("[INFO] Step 3/5: Creating backup...")
|
|
_, err := backup.CreateBackup(ctx, cfg, "migration")
|
|
if err != nil {
|
|
return errors.Wrap(err, "create backup")
|
|
}
|
|
|
|
// Clean old backups
|
|
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil {
|
|
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
|
}
|
|
} else {
|
|
fmt.Println("[INFO] Step 3/5: Skipping backup (--no-backup flag set)")
|
|
}
|
|
|
|
// Acquire migration lock
|
|
fmt.Println("[INFO] Step 4/5: Acquiring migration lock...")
|
|
if err := acquireMigrationLock(ctx, conn); err != nil {
|
|
return errors.Wrap(err, "acquire migration lock")
|
|
}
|
|
defer releaseMigrationLock(ctx, conn)
|
|
fmt.Println("[INFO] Migration lock acquired")
|
|
|
|
// Run migrations
|
|
fmt.Println("[INFO] Step 5/5: Applying migrations...")
|
|
group, err := migrator.Migrate(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "migrate")
|
|
}
|
|
|
|
if group.IsZero() {
|
|
fmt.Println("[INFO] No migrations to run")
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("[INFO] Migrated to group %d\n", group.ID)
|
|
for _, migration := range group.Migrations {
|
|
fmt.Printf(" ✅ %s\n", migration.Name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// migrateRollback rolls back the last migration group
|
|
func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *bun.DB, cfg *config.Config) error {
|
|
// Create backup unless --no-backup flag is set
|
|
if !cfg.Flags.MigrateNoBackup {
|
|
fmt.Println("[INFO] Creating backup before rollback...")
|
|
_, err := backup.CreateBackup(ctx, cfg, "rollback")
|
|
if err != nil {
|
|
return errors.Wrap(err, "create backup")
|
|
}
|
|
|
|
// Clean old backups
|
|
if err := backup.CleanOldBackups(cfg, cfg.DB.BackupRetention); err != nil {
|
|
fmt.Printf("[WARN] Failed to clean old backups: %v\n", err)
|
|
}
|
|
} else {
|
|
fmt.Println("[INFO] Skipping backup (--no-backup flag set)")
|
|
}
|
|
|
|
// Acquire migration lock
|
|
fmt.Println("[INFO] Acquiring migration lock...")
|
|
if err := acquireMigrationLock(ctx, conn); err != nil {
|
|
return errors.Wrap(err, "acquire migration lock")
|
|
}
|
|
defer releaseMigrationLock(ctx, conn)
|
|
fmt.Println("[INFO] Migration lock acquired")
|
|
|
|
// Rollback
|
|
fmt.Println("[INFO] Rolling back last migration group...")
|
|
group, err := migrator.Rollback(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "rollback")
|
|
}
|
|
|
|
if group.IsZero() {
|
|
fmt.Println("[INFO] No migrations to rollback")
|
|
return nil
|
|
}
|
|
|
|
fmt.Printf("[INFO] Rolled back group %d\n", group.ID)
|
|
for _, migration := range group.Migrations {
|
|
fmt.Printf(" ↩️ %s\n", migration.Name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// migrateStatus shows migration status
|
|
func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error {
|
|
ms, err := migrator.MigrationsWithStatus(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "get migration status")
|
|
}
|
|
|
|
fmt.Println("╔══════════════════════════════════════════════════════════╗")
|
|
fmt.Println("║ DATABASE MIGRATION STATUS ║")
|
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
|
|
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
|
_, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tMIGRATED AT")
|
|
_, _ = fmt.Fprintln(w, "------\t---------\t-----\t-----------")
|
|
|
|
appliedCount := 0
|
|
for _, m := range ms {
|
|
status := "⏳ Pending"
|
|
migratedAt := "-"
|
|
group := "-"
|
|
|
|
if m.GroupID > 0 {
|
|
status = "✅ Applied"
|
|
appliedCount++
|
|
group = fmt.Sprint(m.GroupID)
|
|
if !m.MigratedAt.IsZero() {
|
|
migratedAt = m.MigratedAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
}
|
|
|
|
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, migratedAt)
|
|
}
|
|
|
|
_ = w.Flush()
|
|
|
|
fmt.Printf("\n📊 Summary: %d applied, %d pending\n\n",
|
|
appliedCount, len(ms)-appliedCount)
|
|
|
|
return nil
|
|
}
|
|
|
|
// migrateDryRun shows what migrations would run without applying them
|
|
func migrateDryRun(ctx context.Context, migrator *migrate.Migrator) error {
|
|
group, err := migrator.Migrate(ctx, migrate.WithNopMigration())
|
|
if err != nil {
|
|
return errors.Wrap(err, "dry-run")
|
|
}
|
|
|
|
if group.IsZero() {
|
|
fmt.Println("[INFO] No pending migrations")
|
|
return nil
|
|
}
|
|
|
|
fmt.Println("[INFO] Pending migrations (dry-run):")
|
|
for _, migration := range group.Migrations {
|
|
fmt.Printf(" 📋 %s\n", migration.Name)
|
|
}
|
|
fmt.Printf("[INFO] Would migrate to group %d\n", group.ID)
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateMigrations ensures migrations compile before running
|
|
func validateMigrations(ctx context.Context) error {
|
|
cmd := exec.CommandContext(ctx, "go", "build",
|
|
"-o", "/dev/null", "./cmd/oslstats/migrations")
|
|
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
fmt.Println("[ERROR] Migration validation failed!")
|
|
fmt.Println(string(output))
|
|
return errors.Wrap(err, "migration build failed")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock
|
|
func acquireMigrationLock(ctx context.Context, conn *bun.DB) error {
|
|
const lockID = 1234567890 // Arbitrary unique ID for migration lock
|
|
const timeoutSeconds = 300 // 5 minutes
|
|
|
|
// Set statement timeout for this session
|
|
_, err := conn.ExecContext(ctx,
|
|
fmt.Sprintf("SET statement_timeout = '%ds'", timeoutSeconds))
|
|
if err != nil {
|
|
return errors.Wrap(err, "set timeout")
|
|
}
|
|
|
|
var acquired bool
|
|
err = conn.NewRaw("SELECT pg_try_advisory_lock(?)", lockID).
|
|
Scan(ctx, &acquired)
|
|
if err != nil {
|
|
return errors.Wrap(err, "pg_try_advisory_lock")
|
|
}
|
|
|
|
if !acquired {
|
|
return errors.New("migration already in progress (could not acquire lock)")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// releaseMigrationLock releases the migration lock
|
|
func releaseMigrationLock(ctx context.Context, conn *bun.DB) {
|
|
const lockID = 1234567890
|
|
|
|
_, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx)
|
|
if err != nil {
|
|
fmt.Printf("[WARN] Failed to release migration lock: %v\n", err)
|
|
} else {
|
|
fmt.Println("[INFO] Migration lock released")
|
|
}
|
|
}
|
|
|
|
// createMigration generates a new migration file
|
|
func createMigration(name string) error {
|
|
if name == "" {
|
|
return errors.New("migration name cannot be empty")
|
|
}
|
|
|
|
// Sanitize name (replace spaces with underscores, lowercase)
|
|
name = strings.ToLower(strings.ReplaceAll(name, " ", "_"))
|
|
|
|
// Generate timestamp
|
|
timestamp := time.Now().Format("20060102150405")
|
|
filename := fmt.Sprintf("cmd/oslstats/migrations/%s_%s.go", timestamp, name)
|
|
|
|
// Template
|
|
template := `package migrations
|
|
|
|
import (
|
|
"context"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
func init() {
|
|
Migrations.MustRegister(
|
|
// UP migration
|
|
func(ctx context.Context, dbConn *bun.DB) error {
|
|
// Add your migration code here
|
|
return nil
|
|
},
|
|
// DOWN migration
|
|
func(ctx context.Context, dbConn *bun.DB) error {
|
|
// Add your rollback code here
|
|
return nil
|
|
},
|
|
)
|
|
}
|
|
`
|
|
|
|
// Write file
|
|
if err := os.WriteFile(filename, []byte(template), 0o644); err != nil {
|
|
return errors.Wrap(err, "write migration file")
|
|
}
|
|
|
|
fmt.Printf("✅ Created migration: %s\n", filename)
|
|
fmt.Println("📝 Next steps:")
|
|
fmt.Println(" 1. Edit the file and implement the UP and DOWN functions")
|
|
fmt.Println(" 2. Run: make migrate")
|
|
|
|
return nil
|
|
}
|
|
|
|
// resetDatabase drops and recreates all tables (destructive)
|
|
func resetDatabase(ctx context.Context, cfg *config.Config) error {
|
|
fmt.Println("⚠️ WARNING - This will DELETE ALL DATA in the database!")
|
|
fmt.Print("Type 'yes' to continue: ")
|
|
|
|
reader := bufio.NewReader(os.Stdin)
|
|
response, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
return errors.Wrap(err, "read input")
|
|
}
|
|
|
|
response = strings.TrimSpace(response)
|
|
if response != "yes" {
|
|
fmt.Println("❌ Reset cancelled")
|
|
return nil
|
|
}
|
|
conn, close := setupBun(cfg)
|
|
defer func() { _ = close() }()
|
|
|
|
models := []any{
|
|
(*db.User)(nil),
|
|
(*db.DiscordToken)(nil),
|
|
}
|
|
|
|
for _, model := range models {
|
|
if err := conn.ResetModel(ctx, model); err != nil {
|
|
return errors.Wrap(err, "reset model")
|
|
}
|
|
}
|
|
|
|
fmt.Println("✅ Database reset complete")
|
|
return nil
|
|
}
|