Files
projectreshoot/tests/database.go

92 lines
1.9 KiB
Go

package tests
import (
"context"
"database/sql"
"io/fs"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
// findMigrations locates the project's migrations directory.
//
// Starting at the current working directory, it climbs one parent at a time
// until it reaches a directory containing a file named main.go (treated as the
// project root). It then returns a pointer to an fs.FS rooted at that
// directory's "migrations" subdirectory. os.DirFS only records the path, so
// the directory's existence is not verified here.
//
// An error is returned if the working directory cannot be determined, or if
// the filesystem root is reached without finding main.go.
func findMigrations() (*fs.FS, error) {
	cwd, err := os.Getwd()
	if err != nil {
		return nil, err
	}
	for current := cwd; ; {
		marker := filepath.Join(current, "main.go")
		if _, statErr := os.Stat(marker); statErr == nil {
			var migrationsFS fs.FS = os.DirFS(filepath.Join(current, "migrations"))
			return &migrationsFS, nil
		}
		up := filepath.Dir(current)
		// filepath.Dir of the root is the root itself: nothing left to climb.
		if up == current {
			return nil, errors.New("Unable to locate migrations directory")
		}
		current = up
	}
}
// findTestData returns the path of the SQL seed file used by the test helpers.
//
// It begins in the current working directory and moves toward the filesystem
// root, stopping at the first directory that contains main.go (treated as the
// project root). It then returns <root>/tests/testdata.sql. The file itself is
// not checked for existence.
//
// An error is returned if the working directory cannot be determined, or if no
// ancestor directory contains main.go.
func findTestData() (string, error) {
	current, err := os.Getwd()
	if err != nil {
		return "", err
	}
	for {
		_, statErr := os.Stat(filepath.Join(current, "main.go"))
		if statErr == nil {
			break
		}
		up := filepath.Dir(current)
		// Reaching a fixed point of filepath.Dir means we are at the root.
		if up == current {
			return "", errors.New("Unable to locate test data")
		}
		current = up
	}
	return filepath.Join(current, "tests", "testdata.sql"), nil
}
// SetupTestDB opens an in-memory SQLite database, applies the project's goose
// migrations up to and including version, and then executes the SQL in the
// file returned by findTestData to seed it.
//
// The DSN "file::memory:?cache=shared" selects SQLite's shared-cache
// in-memory database, so all connections in the returned pool see the same
// data.
// NOTE(review): the database name is fixed. Separate SetupTestDB calls that
// are open at the same time in one process therefore appear to share a single
// database (later calls would re-run the seed SQL against existing rows).
// Confirm tests close the returned *sql.DB, or never overlap, before relying
// on isolation.
//
// On success the caller owns the returned *sql.DB and should Close it. If any
// step fails after the database has been opened, the handle is closed before
// the error is returned, so nothing leaks.
func SetupTestDB(version int64) (_ *sql.DB, err error) {
	conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
	if err != nil {
		return nil, errors.Wrap(err, "sql.Open")
	}
	// Every failure below returns a nil *sql.DB, leaving the caller no way to
	// close the pool, so release it here. The named result err reflects the
	// value actually being returned.
	defer func() {
		if err != nil {
			_ = conn.Close() // best-effort cleanup; the setup error is what matters
		}
	}()

	migrations, err := findMigrations()
	if err != nil {
		return nil, errors.Wrap(err, "findMigrations")
	}
	provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
	if err != nil {
		return nil, errors.Wrap(err, "goose.NewProvider")
	}
	ctx := context.Background()
	if _, err = provider.UpTo(ctx, version); err != nil {
		return nil, errors.Wrap(err, "provider.UpTo")
	}

	// Seed the freshly migrated schema with the fixture data.
	dataPath, err := findTestData()
	if err != nil {
		return nil, errors.Wrap(err, "findTestData")
	}
	sqlBytes, err := os.ReadFile(dataPath)
	if err != nil {
		return nil, errors.Wrap(err, "os.ReadFile")
	}
	if _, err = conn.ExecContext(ctx, string(sqlBytes)); err != nil {
		return nil, errors.Wrap(err, "conn.ExecContext")
	}
	return conn, nil
}