Updated tests to use migrations for db init for consistency
This commit is contained in:
@@ -1,63 +1,83 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func findSQLFile(filename string) (string, error) {
|
||||
func findMigrations() (*fs.FS, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrations"))
|
||||
return &migrationsdir, nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return nil, errors.New("Unable to locate migrations directory")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
func findTestData() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, filename)); err == nil {
|
||||
return filepath.Join(dir, filename), nil
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
return filepath.Join(dir, "tests", "testdata.sql"), nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return "", errors.New(fmt.Sprintf("Unable to locate %s", filename))
|
||||
return "", errors.New("Unable to locate test data")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
func SetupTestDB(version int64) (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
// Setup the test database
|
||||
schemaPath, err := findSQLFile("schema.sql")
|
||||
|
||||
migrations, err := findMigrations()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
return nil, errors.Wrap(err, "findMigrations")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "goose.NewProvider")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if _, err := provider.UpTo(ctx, version); err != nil {
|
||||
return nil, errors.Wrap(err, "provider.UpTo")
|
||||
}
|
||||
|
||||
sqlBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
schemaSQL := string(sqlBytes)
|
||||
|
||||
_, err = conn.Exec(schemaSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
// NOTE: ==================================================
|
||||
// Load the test data
|
||||
dataPath, err := findSQLFile("testdata.sql")
|
||||
dataPath, err := findTestData()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
}
|
||||
sqlBytes, err = os.ReadFile(dataPath)
|
||||
sqlBytes, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user