117 lines
2.6 KiB
Go
117 lines
2.6 KiB
Go
package tests
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/pressly/goose/v3"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
func findMigrations() (*fs.FS, error) {
|
|
dir, err := os.Getwd()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for {
|
|
if _, err := os.Stat(filepath.Join(dir, "Makefile")); err == nil {
|
|
migrationsdir := os.DirFS(filepath.Join(dir, "cmd", "migrate", "migrations"))
|
|
return &migrationsdir, nil
|
|
}
|
|
|
|
parent := filepath.Dir(dir)
|
|
if parent == dir { // Reached root
|
|
return nil, errors.New("Unable to locate migrations directory")
|
|
}
|
|
dir = parent
|
|
}
|
|
}
|
|
|
|
// findTestData returns the path of the SQL fixture file used by the tests.
//
// It walks upward from the current working directory until it finds a
// directory containing a file named "Makefile" (treated as the repository
// root) and returns <root>/pkg/tests/testdata.sql. The returned path is not
// checked for existence.
//
// An error is returned if the working directory cannot be determined, or if
// the filesystem root is reached without finding a Makefile.
func findTestData() (string, error) {
	start, err := os.Getwd()
	if err != nil {
		return "", err
	}

	dir := start
	for {
		_, statErr := os.Stat(filepath.Join(dir, "Makefile"))
		if statErr == nil {
			return filepath.Join(dir, "pkg", "tests", "testdata.sql"), nil
		}

		up := filepath.Dir(dir)
		if up == dir {
			// Dir of the filesystem root is the root itself: the search is exhausted.
			return "", errors.New("Unable to locate test data")
		}
		dir = up
	}
}
|
|
|
|
// migrateTestDB applies the project's goose migrations to wconn up to the
// given version, using goose's SQLite3 dialect and Provider.UpTo.
//
// The migrations are read from the directory located by findMigrations. Each
// failure is wrapped with the name of the call that produced it.
func migrateTestDB(wconn *sql.DB, version int64) error {
	migrationsFS, err := findMigrations()
	if err != nil {
		return errors.Wrap(err, "findMigrations")
	}

	// NOTE(review): goose's Provider.Close appears to close the *sql.DB handed to
	// NewProvider, which would break the caller's wconn, so the provider is
	// deliberately not closed here — confirm against the goose version in use.
	provider, err := goose.NewProvider(goose.DialectSQLite3, wconn, *migrationsFS)
	if err != nil {
		return errors.Wrap(err, "goose.NewProvider")
	}

	// Only success/failure matters; the per-migration results are discarded.
	// errors.Wrap returns nil for a nil error, so this is nil on success.
	_, err = provider.UpTo(context.Background(), version)
	return errors.Wrap(err, "provider.UpTo")
}
|
|
|
|
// loadTestData reads the SQL fixture file located by findTestData and executes
// its entire contents against wconn in a single Exec call.
//
// Each failure is wrapped with the name of the call that produced it, so the
// error chain points at the step that actually failed.
func loadTestData(wconn *sql.DB) error {
	dataPath, err := findTestData()
	if err != nil {
		return errors.Wrap(err, "findTestData")
	}

	sqlBytes, err := os.ReadFile(dataPath)
	if err != nil {
		return errors.Wrap(err, "os.ReadFile")
	}

	// The whole file is sent as one string. NOTE(review): if the fixture holds
	// several statements, this relies on the driver executing all of them in a
	// single Exec — confirm for modernc.org/sqlite.
	if _, err := wconn.Exec(string(sqlBytes)); err != nil {
		return errors.Wrap(err, "wconn.Exec")
	}
	return nil
}
|
|
|
|
// SetupTestDB prepares the test database and returns two handles to it: the
// first is a read-write handle and the second is a read-only handle.
//
// Both handles open the same shared-cache in-memory SQLite database
// ("file::memory:?cache=shared"). Before the read-only handle (mode=ro) is
// opened, the read-write handle is used to:
//   - apply the goose migrations up to version;
//   - load the SQL fixture data.
//
// If any step fails after the read-write handle is opened, that handle is
// closed before the error is returned. This prevents a failed setup from
// leaking a connection. A leaked connection would also keep a half-initialised
// shared in-memory database alive for later callers.
//
// NOTE(review): every call uses the same DSN, so callers in one process
// presumably share a single database while any handle to it remains open.
// Confirm that tests close both returned handles when done.
func SetupTestDB(version int64) (*sql.DB, *sql.DB, error) {
	// NOTE(review): _journal_mode and _synchronous look like mattn/go-sqlite3 DSN
	// parameters that modernc.org/sqlite may not honour (it uses _pragma=...).
	// SQLite also does not offer WAL for in-memory databases — confirm.
	opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
	file := fmt.Sprintf("file::memory:?cache=shared&%s", opts)
	wconn, err := sql.Open("sqlite", file)
	if err != nil {
		return nil, nil, errors.Wrap(err, "sql.Open")
	}

	// fail releases wconn and returns err wrapped with msg. Closing is
	// best-effort: the setup error is the one worth reporting to the caller.
	fail := func(err error, msg string) (*sql.DB, *sql.DB, error) {
		_ = wconn.Close()
		return nil, nil, errors.Wrap(err, msg)
	}

	if err := migrateTestDB(wconn, version); err != nil {
		return fail(err, "migrateTestDB")
	}
	if err := loadTestData(wconn); err != nil {
		return fail(err, "loadTestData")
	}

	opts = "_synchronous=NORMAL&mode=ro"
	file = fmt.Sprintf("file::memory:?cache=shared&%s", opts)
	rconn, err := sql.Open("sqlite", file)
	if err != nil {
		return fail(err, "sql.Open")
	}
	return wconn, rconn, nil
}
|