// Package migrate provides functions for managing database migrations package migrate import ( "bufio" "context" "fmt" "os" "os/exec" "strconv" "strings" "text/tabwriter" "time" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db/migrations" "github.com/pkg/errors" "github.com/uptrace/bun/migrate" ) // RunMigrations executes database migrations func RunMigrations(ctx context.Context, cfg *config.Config, command string, countStr string) error { conn := db.NewDB(cfg.DB) defer func() { _ = conn.Close() }() migrator := migrate.NewMigrator(conn.DB, migrations.Migrations) // Initialize migration tables if err := migrator.Init(ctx); err != nil { return errors.Wrap(err, "migrator.Init") } switch command { case "up": err := migrateUp(ctx, migrator, conn, cfg, countStr) if err != nil { // On error, automatically rollback the migrations that were just applied fmt.Println("[WARN] Migration failed, attempting automatic rollback...") // We need to figure out how many migrations were applied in this batch // For now, we'll skip automatic rollback since it's complex with the new count system // The user can manually rollback if needed return err } return err case "rollback": return migrateRollback(ctx, migrator, conn, cfg, countStr) case "status": return migrateStatus(ctx, migrator) default: return fmt.Errorf("unknown migration command: %s", command) } } // migrateUp runs pending migrations func migrateUp(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error { // Parse count parameter count, all, err := parseMigrationCount(countStr) if err != nil { return errors.Wrap(err, "parse migration count") } fmt.Println("[INFO] Step 1/5: Validating migrations...") if err := validateMigrations(ctx); err != nil { return err } fmt.Println("[INFO] Migration validation passed āœ“") fmt.Println("[INFO] Step 2/5: Checking for pending migrations...") // Check for pending migrations using MigrationsWithStatus (read-only) ms, err := migrator.MigrationsWithStatus(ctx) if err != nil { return errors.Wrap(err, "get migration status") } unapplied := ms.Unapplied() if len(unapplied) == 0 { fmt.Println("[INFO] No pending migrations") return nil } // Select which migrations to apply toApply := selectMigrationsToApply(unapplied, count, all) if len(toApply) == 0 { fmt.Println("[INFO] No migrations to run") return nil } // Print what we're about to do if all { fmt.Printf("[INFO] Running all %d pending migration(s):\n", len(toApply)) } else { fmt.Printf("[INFO] Running %d migration(s):\n", len(toApply)) } for _, m := range toApply { fmt.Printf(" šŸ“‹ %s\n", m.Name) } // Create backup unless --no-backup flag is set if !cfg.Flags.MigrateNoBackup { fmt.Println("[INFO] Step 3/5: Creating backup...") _, err := db.CreateBackup(ctx, cfg.DB, "migration") if err != nil { return errors.Wrap(err, "create backup") } // Clean old backups if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil { fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) } } else { fmt.Println("[INFO] Step 3/5: Skipping backup (--no-backup flag set)") } // Acquire migration lock fmt.Println("[INFO] Step 4/5: Acquiring migration lock...") if err := acquireMigrationLock(ctx, conn); err != nil { return errors.Wrap(err, "acquire migration lock") } defer releaseMigrationLock(ctx, conn) fmt.Println("[INFO] Migration lock acquired") // Run migrations fmt.Println("[INFO] Step 5/5: Applying migrations...") group, err := executeUpMigrations(ctx, migrator, toApply) if err != nil { return errors.Wrap(err, "execute migrations") } if group.IsZero() { fmt.Println("[INFO] No migrations to run") return nil } fmt.Printf("[INFO] Migrated to group %d\n", group.ID) for _, migration := range group.Migrations { fmt.Printf(" āœ… %s\n", migration.Name) } return nil } // migrateRollback rolls back migrations func migrateRollback(ctx context.Context, migrator *migrate.Migrator, conn *db.DB, cfg *config.Config, countStr string) error { // Parse count parameter count, all, err := parseMigrationCount(countStr) if err != nil { return errors.Wrap(err, "parse migration count") } // Get all migrations with status ms, err := migrator.MigrationsWithStatus(ctx) if err != nil { return errors.Wrap(err, "get migration status") } applied := ms.Applied() if len(applied) == 0 { fmt.Println("[INFO] No migrations to rollback") return nil } // Select which migrations to rollback toRollback := selectMigrationsToRollback(applied, count, all) if len(toRollback) == 0 { fmt.Println("[INFO] No migrations to rollback") return nil } // Print what we're about to do if all { fmt.Printf("[INFO] Rolling back all %d migration(s):\n", len(toRollback)) } else { fmt.Printf("[INFO] Rolling back %d migration(s):\n", len(toRollback)) } for _, m := range toRollback { fmt.Printf(" šŸ“‹ %s (group %d)\n", m.Name, m.GroupID) } // Create backup unless --no-backup flag is set if !cfg.Flags.MigrateNoBackup { fmt.Println("[INFO] Creating backup before rollback...") _, err := db.CreateBackup(ctx, cfg.DB, "rollback") if err != nil { return errors.Wrap(err, "create backup") } // Clean old backups if err := db.CleanOldBackups(cfg.DB, cfg.DB.BackupRetention); err != nil { fmt.Printf("[WARN] Failed to clean old backups: %v\n", err) } } else { fmt.Println("[INFO] Skipping backup (--no-backup flag set)") } // Acquire migration lock fmt.Println("[INFO] Acquiring migration lock...") if err := acquireMigrationLock(ctx, conn); err != nil { return errors.Wrap(err, "acquire migration lock") } defer releaseMigrationLock(ctx, conn) fmt.Println("[INFO] Migration lock acquired") // Rollback fmt.Println("[INFO] Executing rollback...") rolledBack, err := executeDownMigrations(ctx, migrator, toRollback) if err != nil { return errors.Wrap(err, "execute rollback") } fmt.Printf("[INFO] Successfully rolled back %d migration(s)\n", len(rolledBack)) for _, migration := range rolledBack { fmt.Printf(" ā†©ļø %s\n", migration.Name) } return nil } // migrateStatus shows migration status func migrateStatus(ctx context.Context, migrator *migrate.Migrator) error { ms, err := migrator.MigrationsWithStatus(ctx) if err != nil { return errors.Wrap(err, "get migration status") } fmt.Println("╔══════════════════════════════════════════════════════════╗") fmt.Println("ā•‘ DATABASE MIGRATION STATUS ā•‘") fmt.Println("ā•šā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•") w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', 0) _, _ = fmt.Fprintln(w, "STATUS\tMIGRATION\tGROUP\tCOMMENT") _, _ = fmt.Fprintln(w, "----------\t---------------\t-----\t---------------------------") appliedCount := 0 for _, m := range ms { status := "ā³ Pending" group := "-" if m.GroupID > 0 { status = "āœ… Applied" appliedCount++ group = fmt.Sprint(m.GroupID) } _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", status, m.Name, group, m.Comment) } _ = w.Flush() fmt.Printf("\nšŸ“Š Summary: %d applied, %d pending\n\n", appliedCount, len(ms)-appliedCount) return nil } // validateMigrations ensures migrations compile before running func validateMigrations(ctx context.Context) error { cmd := exec.CommandContext(ctx, "go", "build", "-o", "/dev/null", "./internal/db/migrations") output, err := cmd.CombinedOutput() if err != nil { fmt.Println("[ERROR] Migration validation failed!") fmt.Println(string(output)) return errors.Wrap(err, "migration build failed") } return nil } // acquireMigrationLock prevents concurrent migrations using PostgreSQL advisory lock func acquireMigrationLock(ctx context.Context, conn *db.DB) error { const lockID = 1234567890 // Arbitrary unique ID for migration lock const timeoutSeconds = 300 // 5 minutes // Set statement timeout for this session _, err := conn.ExecContext(ctx, fmt.Sprintf("SET statement_timeout = '%ds'", timeoutSeconds)) if err != nil { return errors.Wrap(err, "set timeout") } var acquired bool err = conn.NewRaw("SELECT pg_try_advisory_lock(?)", lockID). Scan(ctx, &acquired) if err != nil { return errors.Wrap(err, "pg_try_advisory_lock") } if !acquired { return errors.New("migration already in progress (could not acquire lock)") } return nil } // releaseMigrationLock releases the migration lock func releaseMigrationLock(ctx context.Context, conn *db.DB) { const lockID = 1234567890 _, err := conn.NewRaw("SELECT pg_advisory_unlock(?)", lockID).Exec(ctx) if err != nil { fmt.Printf("[WARN] Failed to release migration lock: %v\n", err) } else { fmt.Println("[INFO] Migration lock released") } } // CreateMigration generates a new migration file func CreateMigration(name string) error { if name == "" { return errors.New("migration name cannot be empty") } // Sanitize name (replace spaces with underscores, lowercase) name = strings.ToLower(strings.ReplaceAll(name, " ", "_")) // Generate timestamp timestamp := time.Now().Format("20060102150405") filename := fmt.Sprintf("internal/db/migrations/%s_%s.go", timestamp, name) // Template template := `package migrations import ( "context" "github.com/uptrace/bun" ) func init() { Migrations.MustRegister( // UP migration func(ctx context.Context, conn *bun.DB) error { // Add your migration code here return nil }, // DOWN migration func(ctx context.Context, conn *bun.DB) error { // Add your rollback code here return nil }, ) } ` // Write file if err := os.WriteFile(filename, []byte(template), 0o644); err != nil { return errors.Wrap(err, "write migration file") } fmt.Printf("āœ… Created migration: %s\n", filename) fmt.Println("šŸ“ Next steps:") fmt.Println(" 1. Edit the file and implement the UP and DOWN functions") fmt.Println(" 2. Run: just migrate up") return nil } // parseMigrationCount parses a migration count string // Returns: (count, all, error) // - "" (empty) → (1, false, nil) - default to 1 // - "all" → (0, true, nil) - special case for all // - "5" → (5, false, nil) - specific count // - "invalid" → (0, false, error) func parseMigrationCount(value string) (int, bool, error) { // Default to 1 if empty if value == "" { return 1, false, nil } // Special case for "all" if value == "all" { return 0, true, nil } // Parse as integer count, err := strconv.Atoi(value) if err != nil { return 0, false, errors.New("migration count must be a positive integer or 'all'") } if count < 1 { return 0, false, errors.New("migration count must be a positive integer (1 or greater)") } return count, false, nil } // selectMigrationsToApply returns the subset of unapplied migrations to run func selectMigrationsToApply(unapplied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice { if all { return unapplied } count = min(count, len(unapplied)) return unapplied[:count] } // selectMigrationsToRollback returns the subset of applied migrations to rollback // Returns migrations in reverse chronological order (most recent first) func selectMigrationsToRollback(applied migrate.MigrationSlice, count int, all bool) migrate.MigrationSlice { if len(applied) == 0 || all { return applied } count = min(count, len(applied)) return applied[:count] } // executeUpMigrations executes a subset of UP migrations func executeUpMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (*migrate.MigrationGroup, error) { if len(migrations) == 0 { return &migrate.MigrationGroup{}, nil } // Get the next group ID ms, err := migrator.MigrationsWithStatus(ctx) if err != nil { return nil, errors.Wrap(err, "get migration status") } lastGroup := ms.LastGroup() groupID := int64(1) if lastGroup.ID > 0 { groupID = lastGroup.ID + 1 } // Create the migration group group := &migrate.MigrationGroup{ ID: groupID, Migrations: make(migrate.MigrationSlice, 0, len(migrations)), } // Execute each migration for i := range migrations { migration := &migrations[i] migration.GroupID = groupID // Mark as applied before execution (Bun's default behavior) if err := migrator.MarkApplied(ctx, migration); err != nil { return group, errors.Wrap(err, "mark applied") } // Add to group group.Migrations = append(group.Migrations, *migration) // Execute the UP function if migration.Up != nil { if err := migration.Up(ctx, migrator, migration); err != nil { return group, errors.Wrap(err, fmt.Sprintf("migration %s failed", migration.Name)) } } } return group, nil } // executeDownMigrations executes a subset of DOWN migrations func executeDownMigrations(ctx context.Context, migrator *migrate.Migrator, migrations migrate.MigrationSlice) (migrate.MigrationSlice, error) { rolledBack := make(migrate.MigrationSlice, 0, len(migrations)) // Execute each migration in order (already reversed) for i := range migrations { migration := &migrations[i] // Execute the DOWN function if migration.Down != nil { if err := migration.Down(ctx, migrator, migration); err != nil { return rolledBack, errors.Wrap(err, fmt.Sprintf("rollback %s failed", migration.Name)) } } // Mark as unapplied after execution if err := migrator.MarkUnapplied(ctx, migration); err != nil { return rolledBack, errors.Wrap(err, "mark unapplied") } rolledBack = append(rolledBack, *migration) } return rolledBack, nil } // ResetDatabase drops and recreates all tables (destructive) func ResetDatabase(ctx context.Context, cfg *config.Config) error { fmt.Println("āš ļø WARNING - This will DELETE ALL DATA in the database!") fmt.Print("Type 'yes' to continue: ") reader := bufio.NewReader(os.Stdin) response, err := reader.ReadString('\n') if err != nil { return errors.Wrap(err, "read input") } response = strings.TrimSpace(response) if response != "yes" { fmt.Println("āŒ Reset cancelled") return nil } conn := db.NewDB(cfg.DB) defer func() { _ = conn.Close() }() models := conn.RegisterModels() for _, model := range models { if err := conn.ResetModel(ctx, model); err != nil { return errors.Wrap(err, "reset model") } } fmt.Println("āœ… Database reset complete") return nil }