Merge pull request #12 from Haelnorr/dbbackups
Testing db migrations updates
This commit is contained in:
14
.github/workflows/deploy_staging.yaml
vendored
14
.github/workflows/deploy_staging.yaml
vendored
@@ -33,11 +33,15 @@ jobs:
|
|||||||
- name: Build the binary
|
- name: Build the binary
|
||||||
run: make build SUFFIX=-staging-$GITHUB_SHA
|
run: make build SUFFIX=-staging-$GITHUB_SHA
|
||||||
|
|
||||||
|
- name: Build the migration binary
|
||||||
|
run: make migrate SUFFIX=-staging-$GITHUB_SHA
|
||||||
|
|
||||||
- name: Deploy to Server
|
- name: Deploy to Server
|
||||||
env:
|
env:
|
||||||
USER: deploy
|
USER: deploy
|
||||||
HOST: projectreshoot.com
|
HOST: projectreshoot.com
|
||||||
DIR: /home/deploy/releases/staging
|
DIR: /home/deploy/releases/staging
|
||||||
|
MIG_DIR: /home/deploy/migration-bin
|
||||||
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
|
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
|
||||||
run: |
|
run: |
|
||||||
mkdir -p ~/.ssh
|
mkdir -p ~/.ssh
|
||||||
@@ -49,7 +53,13 @@ jobs:
|
|||||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||||
|
|
||||||
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||||
|
scp -i ~/.ssh/id_ed25519 prmigrate-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||||
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||||
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,6 +4,7 @@ query.sql
|
|||||||
.logs/
|
.logs/
|
||||||
server.log
|
server.log
|
||||||
tmp/
|
tmp/
|
||||||
|
prmigrate
|
||||||
projectreshoot
|
projectreshoot
|
||||||
static/css/output.css
|
static/css/output.css
|
||||||
view/**/*_templ.go
|
view/**/*_templ.go
|
||||||
|
|||||||
6
Makefile
6
Makefile
@@ -1,5 +1,6 @@
|
|||||||
# Makefile
|
# Makefile
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
|
.PHONY: migrate
|
||||||
|
|
||||||
BINARY_NAME=projectreshoot
|
BINARY_NAME=projectreshoot
|
||||||
|
|
||||||
@@ -29,3 +30,8 @@ test:
|
|||||||
|
|
||||||
clean:
|
clean:
|
||||||
go clean
|
go clean
|
||||||
|
|
||||||
|
migrate:
|
||||||
|
go mod tidy && \
|
||||||
|
go generate && \
|
||||||
|
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type Config struct {
|
|||||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||||
DBName string // Filename of the db (doesnt include file extension)
|
DBName string // Filename of the db - hardcoded and doubles as DB version
|
||||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
||||||
SecretKey string // Secret key for signing tokens
|
SecretKey string // Secret key for signing tokens
|
||||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||||
@@ -87,7 +87,7 @@ func GetConfig(args map[string]string) (*Config, error) {
|
|||||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
||||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
||||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
||||||
DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
|
DBName: "00001",
|
||||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
||||||
SecretKey: os.Getenv("SECRET_KEY"),
|
SecretKey: os.Getenv("SECRET_KEY"),
|
||||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
||||||
|
|||||||
205
db/connection.go
205
db/connection.go
@@ -1,10 +1,9 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"strconv"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
@@ -12,179 +11,47 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SafeConn struct {
|
|
||||||
db *sql.DB
|
|
||||||
readLockCount uint32
|
|
||||||
globalLockStatus uint32
|
|
||||||
globalLockRequested uint32
|
|
||||||
logger *zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
|
||||||
return &SafeConn{db: db, logger: logger}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extends sql.Tx for use with SafeConn
|
|
||||||
type SafeTX struct {
|
|
||||||
tx *sql.Tx
|
|
||||||
sc *SafeConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
|
||||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.globalLockStatus = 1
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) releaseGlobalLock() {
|
|
||||||
conn.globalLockStatus = 0
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) acquireReadLock() bool {
|
|
||||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.readLockCount += 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) releaseReadLock() {
|
|
||||||
conn.readLockCount -= 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Starts a new transaction based on the current context. Will cancel if
|
|
||||||
// the context is closed/cancelled/done
|
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
|
||||||
lockAcquired := make(chan struct{})
|
|
||||||
lockCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-lockCtx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if conn.acquireReadLock() {
|
|
||||||
close(lockAcquired)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-lockAcquired:
|
|
||||||
tx, err := conn.db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
conn.releaseReadLock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &SafeTX{tx: tx, sc: conn}, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
cancel()
|
|
||||||
return nil, errors.New("Transaction time out due to database lock")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the database inside the transaction
|
|
||||||
func (stx *SafeTX) Query(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (*sql.Rows, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot query without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec a statement on the database inside the transaction
|
|
||||||
func (stx *SafeTX) Exec(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (sql.Result, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot exec without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit the current transaction and release the read lock
|
|
||||||
func (stx *SafeTX) Commit() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot commit without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Commit()
|
|
||||||
stx.tx = nil
|
|
||||||
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort the current transaction, releasing the read lock
|
|
||||||
func (stx *SafeTX) Rollback() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot rollback without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Rollback()
|
|
||||||
stx.tx = nil
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a global lock, preventing all transactions
|
|
||||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
|
||||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
|
||||||
conn.globalLockRequested = 1
|
|
||||||
defer func() { conn.globalLockRequested = 0 }()
|
|
||||||
timeout := time.After(timeoutAfter)
|
|
||||||
attempt := 0
|
|
||||||
for {
|
|
||||||
if conn.acquireGlobalLock() {
|
|
||||||
conn.logger.Info().Msg("Global database lock acquired")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
|
||||||
return
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
attempt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release the global lock
|
|
||||||
func (conn *SafeConn) Resume() {
|
|
||||||
conn.releaseGlobalLock()
|
|
||||||
conn.logger.Info().Msg("Global database lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the database connection
|
|
||||||
func (conn *SafeConn) Close() error {
|
|
||||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
|
||||||
conn.acquireGlobalLock()
|
|
||||||
defer conn.releaseGlobalLock()
|
|
||||||
conn.logger.Debug().Msg("Closing database connection")
|
|
||||||
return conn.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a database connection handle for the DB
|
// Returns a database connection handle for the DB
|
||||||
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
|
func ConnectToDatabase(
|
||||||
|
dbName string,
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
) (*SafeConn, error) {
|
||||||
file := fmt.Sprintf("file:%s.db", dbName)
|
file := fmt.Sprintf("file:%s.db", dbName)
|
||||||
db, err := sql.Open("sqlite", file)
|
db, err := sql.Open("sqlite", file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "sql.Open")
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
}
|
}
|
||||||
|
version, err := strconv.Atoi(dbName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "strconv.Atoi")
|
||||||
|
}
|
||||||
|
err = checkDBVersion(db, version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkDBVersion")
|
||||||
|
}
|
||||||
conn := MakeSafe(db, logger)
|
conn := MakeSafe(db, logger)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check the database version
|
||||||
|
func checkDBVersion(db *sql.DB, expectVer int) error {
|
||||||
|
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||||
|
ORDER BY version_id DESC LIMIT 1`
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "checkDBVersion")
|
||||||
|
}
|
||||||
|
if rows.Next() {
|
||||||
|
var version int
|
||||||
|
err = rows.Scan(&version)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "rows.Scan")
|
||||||
|
}
|
||||||
|
if version != expectVer {
|
||||||
|
return errors.New("Version mismatch")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("No version found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
129
db/safeconn.go
Normal file
129
db/safeconn.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SafeConn struct {
|
||||||
|
db *sql.DB
|
||||||
|
readLockCount uint32
|
||||||
|
globalLockStatus uint32
|
||||||
|
globalLockRequested uint32
|
||||||
|
logger *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the provided db handle safe and attach a logger to it
|
||||||
|
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||||
|
return &SafeConn{db: db, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempts to acquire a global lock on the database connection
|
||||||
|
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||||
|
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.globalLockStatus = 1
|
||||||
|
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||||
|
Msg("Global lock acquired")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a global lock on the database connection
|
||||||
|
func (conn *SafeConn) releaseGlobalLock() {
|
||||||
|
conn.globalLockStatus = 0
|
||||||
|
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||||
|
Msg("Global lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
||||||
|
// at the same time
|
||||||
|
func (conn *SafeConn) acquireReadLock() bool {
|
||||||
|
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.readLockCount += 1
|
||||||
|
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||||
|
Msg("Read lock acquired")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release a read lock. Decrements read lock count by 1
|
||||||
|
func (conn *SafeConn) releaseReadLock() {
|
||||||
|
conn.readLockCount -= 1
|
||||||
|
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||||
|
Msg("Read lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starts a new transaction based on the current context. Will cancel if
|
||||||
|
// the context is closed/cancelled/done
|
||||||
|
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||||
|
lockAcquired := make(chan struct{})
|
||||||
|
lockCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-lockCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if conn.acquireReadLock() {
|
||||||
|
close(lockAcquired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-lockAcquired:
|
||||||
|
tx, err := conn.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
conn.releaseReadLock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &SafeTX{tx: tx, sc: conn}, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
cancel()
|
||||||
|
return nil, errors.New("Transaction time out due to database lock")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire a global lock, preventing all transactions
|
||||||
|
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||||
|
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||||
|
conn.globalLockRequested = 1
|
||||||
|
defer func() { conn.globalLockRequested = 0 }()
|
||||||
|
timeout := time.After(timeoutAfter)
|
||||||
|
attempt := 0
|
||||||
|
for {
|
||||||
|
if conn.acquireGlobalLock() {
|
||||||
|
conn.logger.Info().Msg("Global database lock acquired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||||
|
return
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
attempt++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the global lock
|
||||||
|
func (conn *SafeConn) Resume() {
|
||||||
|
conn.releaseGlobalLock()
|
||||||
|
conn.logger.Info().Msg("Global database lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the database connection
|
||||||
|
func (conn *SafeConn) Close() error {
|
||||||
|
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||||
|
conn.acquireGlobalLock()
|
||||||
|
defer conn.releaseGlobalLock()
|
||||||
|
conn.logger.Debug().Msg("Closing database connection")
|
||||||
|
return conn.db.Close()
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"projectreshoot/tests"
|
"projectreshoot/tests"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,8 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSafeConn(t *testing.T) {
|
func TestSafeConn(t *testing.T) {
|
||||||
|
cfg, err := tests.TestConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
conn, err := tests.SetupTestDB()
|
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn, err := tests.SetupTestDB(ver)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sconn := MakeSafe(conn, logger)
|
sconn := MakeSafe(conn, logger)
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
@@ -77,8 +82,12 @@ func TestSafeConn(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
func TestSafeTX(t *testing.T) {
|
func TestSafeTX(t *testing.T) {
|
||||||
|
cfg, err := tests.TestConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
conn, err := tests.SetupTestDB()
|
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn, err := tests.SetupTestDB(ver)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sconn := MakeSafe(conn, logger)
|
sconn := MakeSafe(conn, logger)
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
61
db/safetx.go
Normal file
61
db/safetx.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extends sql.Tx for use with SafeConn
|
||||||
|
type SafeTX struct {
|
||||||
|
tx *sql.Tx
|
||||||
|
sc *SafeConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query the database inside the transaction
|
||||||
|
func (stx *SafeTX) Query(
|
||||||
|
ctx context.Context,
|
||||||
|
query string,
|
||||||
|
args ...interface{},
|
||||||
|
) (*sql.Rows, error) {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return nil, errors.New("Cannot query without a transaction")
|
||||||
|
}
|
||||||
|
return stx.tx.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec a statement on the database inside the transaction
|
||||||
|
func (stx *SafeTX) Exec(
|
||||||
|
ctx context.Context,
|
||||||
|
query string,
|
||||||
|
args ...interface{},
|
||||||
|
) (sql.Result, error) {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return nil, errors.New("Cannot exec without a transaction")
|
||||||
|
}
|
||||||
|
return stx.tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit the current transaction and release the read lock
|
||||||
|
func (stx *SafeTX) Commit() error {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return errors.New("Cannot commit without a transaction")
|
||||||
|
}
|
||||||
|
err := stx.tx.Commit()
|
||||||
|
stx.tx = nil
|
||||||
|
|
||||||
|
stx.sc.releaseReadLock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort the current transaction, releasing the read lock
|
||||||
|
func (stx *SafeTX) Rollback() error {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return errors.New("Cannot rollback without a transaction")
|
||||||
|
}
|
||||||
|
err := stx.tx.Rollback()
|
||||||
|
stx.tx = nil
|
||||||
|
stx.sc.releaseReadLock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
108
deploy/db/backup.sh
Executable file
108
deploy/db/backup.sh
Executable file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Exit on error
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [[ -z "$1" ]]; then
|
||||||
|
echo "Usage: $0 <environment>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENVR="$1"
|
||||||
|
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||||
|
echo "Error: environment must be 'production' or 'staging'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||||
|
DATA_DIR="/home/deploy/data/$ENVR"
|
||||||
|
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||||
|
if [[ "$ENVR" == "production" ]]; then
|
||||||
|
SERVICE_NAME="projectreshoot"
|
||||||
|
declare -a PORTS=("3000" "3001" "3002")
|
||||||
|
else
|
||||||
|
SERVICE_NAME="$ENVR.projectreshoot"
|
||||||
|
declare -a PORTS=("3005" "3006" "3007")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Send SIGUSR2 to release maintenance mode
|
||||||
|
release_maintenance() {
|
||||||
|
echo "Releasing maintenance mode..."
|
||||||
|
for PORT in "${PORTS[@]}"; do
|
||||||
|
sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service"
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
shopt -s nullglob
|
||||||
|
DB_FILES=("$ACTIVE_DIR"/*.db)
|
||||||
|
DB_COUNT=${#DB_FILES[@]}
|
||||||
|
|
||||||
|
if [[ $DB_COUNT -gt 1 ]]; then
|
||||||
|
echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required."
|
||||||
|
exit 1
|
||||||
|
elif [[ $DB_COUNT -eq 0 ]]; then
|
||||||
|
echo "Error: No .db file found in $ACTIVE_DIR."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract the filename without extension
|
||||||
|
DB_FILE="${DB_FILES[0]}"
|
||||||
|
DB_VER=$(basename "$DB_FILE" .db)
|
||||||
|
|
||||||
|
# Send SIGUSR1 to trigger maintenance mode
|
||||||
|
for PORT in "${PORTS[@]}"; do
|
||||||
|
sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service"
|
||||||
|
done
|
||||||
|
trap release_maintenance EXIT
|
||||||
|
|
||||||
|
# Function to check logs for success or failure
|
||||||
|
check_logs() {
|
||||||
|
local port="$1"
|
||||||
|
local service="$SERVICE_NAME@$port.service"
|
||||||
|
|
||||||
|
echo "Waiting for $service to enter maintenance mode..."
|
||||||
|
|
||||||
|
# Check the last few lines first in case the message already appeared
|
||||||
|
if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then
|
||||||
|
echo "$service successfully entered maintenance mode."
|
||||||
|
return 0
|
||||||
|
elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then
|
||||||
|
echo "Error: $service failed to enter maintenance mode."
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If not found, continuously watch logs until we get a success or failure message
|
||||||
|
sudo journalctl -u "$service" -f --no-pager | while read -r line; do
|
||||||
|
if echo "$line" | grep -q "Global database lock acquired"; then
|
||||||
|
echo "$service successfully entered maintenance mode."
|
||||||
|
pkill -P $$ journalctl # Kill journalctl process once we have success
|
||||||
|
return 0
|
||||||
|
elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then
|
||||||
|
echo "Error: $service failed to enter maintenance mode."
|
||||||
|
pkill -P $$ journalctl # Kill journalctl process on failure
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check logs for each service
|
||||||
|
for PORT in "${PORTS[@]}"; do
|
||||||
|
check_logs "$PORT"
|
||||||
|
done
|
||||||
|
|
||||||
|
# Get current datetime in YYYY-MM-DD-HHMM format
|
||||||
|
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||||
|
|
||||||
|
# Define source and destination paths
|
||||||
|
SOURCE_DB="$DATA_DIR/$DB_VER.db"
|
||||||
|
BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db"
|
||||||
|
|
||||||
|
# Copy the database file
|
||||||
|
if [[ -f "$SOURCE_DB" ]]; then
|
||||||
|
cp "$SOURCE_DB" "$BACKUP_DB"
|
||||||
|
echo "Backup created: $BACKUP_DB"
|
||||||
|
else
|
||||||
|
echo "Error: Source database file $SOURCE_DB not found."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
77
deploy/db/migrate.sh
Executable file
77
deploy/db/migrate.sh
Executable file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
if [[ -z "$1" ]]; then
|
||||||
|
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
ENVR="$1"
|
||||||
|
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||||
|
echo "Error: environment must be 'production' or 'staging'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$2" ]]; then
|
||||||
|
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
TGT_VER="$2"
|
||||||
|
re='^[0-9]+$'
|
||||||
|
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||||
|
echo "Error: version not a number" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ -z "$3" ]; then
|
||||||
|
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
COMMIT_HASH=$3
|
||||||
|
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||||
|
BACKUP_OUTPUT=$(/bin/bash ${MIGRATION_BIN}/backup.sh "$ENVR" 2>&1)
|
||||||
|
echo "$BACKUP_OUTPUT"
|
||||||
|
if [[ $? -ne 0 ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*')
|
||||||
|
if [[ -z "$BACKUP_FILE" ]]; then
|
||||||
|
echo "Error: backup failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
FILE_NAME=${BACKUP_FILE##*/}
|
||||||
|
CUR_VER=${FILE_NAME%%-*}
|
||||||
|
if [[ $((+$TGT_VER)) == $((+$CUR_VER)) ]]; then
|
||||||
|
echo "Version same, skipping migration"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
if [[ $((+$TGT_VER)) > $((+$CUR_VER)) ]]; then
|
||||||
|
CMD="up-to"
|
||||||
|
fi
|
||||||
|
if [[ $((+$TGT_VER)) < $((+$CUR_VER)) ]]; then
|
||||||
|
CMD="down-to"
|
||||||
|
fi
|
||||||
|
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||||
|
|
||||||
|
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||||
|
DATA_DIR="/home/deploy/data/$ENVR"
|
||||||
|
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||||
|
UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db"
|
||||||
|
UPDATED_COPY="$DATA_DIR/${TGT_VER}.db"
|
||||||
|
UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db"
|
||||||
|
|
||||||
|
cp $BACKUP_FILE $UPDATED_BACKUP
|
||||||
|
failed_cleanup() {
|
||||||
|
rm $UPDATED_BACKUP
|
||||||
|
}
|
||||||
|
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
||||||
|
|
||||||
|
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
||||||
|
${MIGRATION_BIN}/prmigrate-$COMMIT_HASH $UPDATED_BACKUP $CMD $TGT_VER
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Migration failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Migration completed"
|
||||||
|
|
||||||
|
cp $UPDATED_BACKUP $UPDATED_COPY
|
||||||
|
ln -s $UPDATED_COPY $UPDATED_LINK
|
||||||
|
echo "Upgraded database linked and ready for deploy"
|
||||||
|
exit 0
|
||||||
27
deploy/db/migrationcleanup.sh
Executable file
27
deploy/db/migrationcleanup.sh
Executable file
@@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Exit on error
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [[ -z "$1" ]]; then
|
||||||
|
echo "Usage: $0 <environment> <version>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENVR="$1"
|
||||||
|
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||||
|
echo "Error: environment must be 'production' or 'staging'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$2" ]]; then
|
||||||
|
echo "Usage: $0 <environment> <version>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
TGT_VER="$2"
|
||||||
|
re='^[0-9]+$'
|
||||||
|
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||||
|
echo "Error: version not a number" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||||
|
find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} +
|
||||||
113
deploy/deploy.sh
Normal file
113
deploy/deploy.sh
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Exit on error
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Check if commit hash is passed as an argument
|
||||||
|
if [ -z "$1" ]; then
|
||||||
|
echo "Usage: $0 <commit-hash>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
COMMIT_HASH=$1
|
||||||
|
ENVR="$2"
|
||||||
|
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||||
|
echo "Error: environment must be 'production' or 'staging'."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
RELEASES_DIR="/home/deploy/releases/$ENVR"
|
||||||
|
DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot"
|
||||||
|
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||||
|
BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}"
|
||||||
|
declare -a PORTS=("3000" "3001" "3002")
|
||||||
|
if [[ "$ENVR" == "production" ]]; then
|
||||||
|
SERVICE_NAME="projectreshoot"
|
||||||
|
declare -a PORTS=("3000" "3001" "3002")
|
||||||
|
else
|
||||||
|
SERVICE_NAME="$ENVR.projectreshoot"
|
||||||
|
declare -a PORTS=("3005" "3006" "3007")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if the binary exists
|
||||||
|
if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then
|
||||||
|
echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*')
|
||||||
|
${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH
|
||||||
|
if [[ $? -ne 0 ]]; then
|
||||||
|
echo "Migration failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Keep a reference to the previous binary from the symlink
|
||||||
|
if [ -L "${DEPLOY_BIN}" ]; then
|
||||||
|
PREVIOUS=$(readlink -f $DEPLOY_BIN)
|
||||||
|
echo "Current binary is ${PREVIOUS}, saved for rollback."
|
||||||
|
else
|
||||||
|
echo "No symbolic link found, no previous binary to backup."
|
||||||
|
PREVIOUS=""
|
||||||
|
fi
|
||||||
|
|
||||||
|
rollback_deployment() {
|
||||||
|
if [ -n "$PREVIOUS" ]; then
|
||||||
|
echo "Rolling back to previous binary: ${PREVIOUS}"
|
||||||
|
ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}"
|
||||||
|
else
|
||||||
|
echo "No previous binary to roll back to."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# wait to restart the services
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
# Restart all services with the previous binary
|
||||||
|
for port in "${PORTS[@]}"; do
|
||||||
|
SERVICE="${SERVICE_NAME}@${port}.service"
|
||||||
|
echo "Restarting $SERVICE..."
|
||||||
|
sudo systemctl restart $SERVICE
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Rollback completed."
|
||||||
|
}
|
||||||
|
|
||||||
|
# Copy the binary to the deployment directory
|
||||||
|
echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..."
|
||||||
|
ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}"
|
||||||
|
|
||||||
|
WAIT_TIME=5
|
||||||
|
restart_service() {
|
||||||
|
local port=$1
|
||||||
|
local SERVICE="${SERVICE_NAME}@${port}.service"
|
||||||
|
echo "Restarting ${SERVICE}..."
|
||||||
|
|
||||||
|
# Restart the service
|
||||||
|
if ! sudo systemctl restart "$SERVICE"; then
|
||||||
|
echo "Error: Failed to restart ${SERVICE}. Rolling back deployment."
|
||||||
|
|
||||||
|
# Call the rollback function
|
||||||
|
rollback_deployment
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Wait a few seconds to allow the service to fully start
|
||||||
|
echo "Waiting for ${SERVICE} to fully start..."
|
||||||
|
sleep $WAIT_TIME
|
||||||
|
|
||||||
|
# Check the status of the service
|
||||||
|
if ! systemctl is-active --quiet "${SERVICE}"; then
|
||||||
|
echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment."
|
||||||
|
|
||||||
|
# Call the rollback function
|
||||||
|
rollback_deployment
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "${SERVICE}.service restarted successfully."
|
||||||
|
}
|
||||||
|
|
||||||
|
for port in "${PORTS[@]}"; do
|
||||||
|
restart_service $port
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Deployment completed successfully."
|
||||||
|
${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER
|
||||||
6
go.mod
6
go.mod
@@ -8,6 +8,7 @@ require (
|
|||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/pressly/goose/v3 v3.24.1
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
golang.org/x/crypto v0.33.0
|
golang.org/x/crypto v0.33.0
|
||||||
@@ -19,10 +20,13 @@ require (
|
|||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
|
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||||
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||||
golang.org/x/sync v0.11.0 // indirect
|
golang.org/x/sync v0.11.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.30.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
|
|||||||
12
go.sum
12
go.sum
@@ -28,23 +28,31 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
|||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||||
|
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
|
||||||
|
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||||
|
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||||
|
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||||
|
|||||||
8
main.go
8
main.go
@@ -94,6 +94,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
return errors.Wrap(err, "server.GetConfig")
|
return errors.Wrap(err, "server.GetConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the version of the database required
|
||||||
|
if args["dbver"] == "true" {
|
||||||
|
fmt.Printf("Database version: %s\n", config.DBName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var logfile *os.File = nil
|
var logfile *os.File = nil
|
||||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
if config.LogOutput == "both" || config.LogOutput == "file" {
|
||||||
logfile, err = logging.GetLogFile(config.LogDir)
|
logfile, err = logging.GetLogFile(config.LogDir)
|
||||||
@@ -186,6 +192,7 @@ func main() {
|
|||||||
host := flag.String("host", "", "Override host to listen on")
|
host := flag.String("host", "", "Override host to listen on")
|
||||||
port := flag.String("port", "", "Override port to listen on")
|
port := flag.String("port", "", "Override port to listen on")
|
||||||
test := flag.Bool("test", false, "Run test function instead of main program")
|
test := flag.Bool("test", false, "Run test function instead of main program")
|
||||||
|
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||||
loglevel := flag.String("loglevel", "", "Set log level")
|
loglevel := flag.String("loglevel", "", "Set log level")
|
||||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
@@ -195,6 +202,7 @@ func main() {
|
|||||||
"host": *host,
|
"host": *host,
|
||||||
"port": *port,
|
"port": *port,
|
||||||
"test": strconv.FormatBool(*test),
|
"test": strconv.FormatBool(*test),
|
||||||
|
"dbver": strconv.FormatBool(*dbver),
|
||||||
"loglevel": *loglevel,
|
"loglevel": *loglevel,
|
||||||
"logoutput": *logoutput,
|
"logoutput": *logoutput,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,16 +17,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthenticationMiddleware(t *testing.T) {
|
func TestAuthenticationMiddleware(t *testing.T) {
|
||||||
|
cfg, err := tests.TestConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
// Basic setup
|
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||||
conn, err := tests.SetupTestDB()
|
require.NoError(t, err)
|
||||||
|
conn, err := tests.SetupTestDB(ver)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sconn := db.MakeSafe(conn, logger)
|
sconn := db.MakeSafe(conn, logger)
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
// Handler to check outcome of Authentication middleware
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := contexts.GetUser(r.Context())
|
user := contexts.GetUser(r.Context())
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -14,16 +15,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPageLoginRequired(t *testing.T) {
|
func TestPageLoginRequired(t *testing.T) {
|
||||||
|
cfg, err := tests.TestConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
// Basic setup
|
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||||
conn, err := tests.SetupTestDB()
|
require.NoError(t, err)
|
||||||
|
conn, err := tests.SetupTestDB(ver)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sconn := db.MakeSafe(conn, logger)
|
sconn := db.MakeSafe(conn, logger)
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
// Handler to check outcome of Authentication middleware
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -14,16 +15,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestReauthRequired(t *testing.T) {
|
func TestReauthRequired(t *testing.T) {
|
||||||
|
cfg, err := tests.TestConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
// Basic setup
|
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||||
conn, err := tests.SetupTestDB()
|
require.NoError(t, err)
|
||||||
|
conn, err := tests.SetupTestDB(ver)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sconn := db.MakeSafe(conn, logger)
|
sconn := db.MakeSafe(conn, logger)
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
// Handler to check outcome of Authentication middleware
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|||||||
69
migrate/migrate.go
Normal file
69
migrate/migrate.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed migrations
|
||||||
|
var migrationsFS embed.FS
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) != 4 {
|
||||||
|
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := os.Args[1]
|
||||||
|
direction := os.Args[2]
|
||||||
|
versionStr := os.Args[3]
|
||||||
|
|
||||||
|
version, err := strconv.Atoi(versionStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Invalid version number: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
log.Fatalf("Database file does not exist: %v", filePath)
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlite", filePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||||
|
}
|
||||||
|
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create migration provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
switch direction {
|
||||||
|
case "up-to":
|
||||||
|
_, err = provider.UpTo(ctx, int64(version))
|
||||||
|
case "down-to":
|
||||||
|
_, err = provider.DownTo(ctx, int64(version))
|
||||||
|
default:
|
||||||
|
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Migration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Migration successful!")
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
|
-- +goose Up
|
||||||
|
-- +goose StatementBegin
|
||||||
PRAGMA foreign_keys=ON;
|
PRAGMA foreign_keys=ON;
|
||||||
BEGIN TRANSACTION;
|
|
||||||
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
||||||
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
||||||
exp INTEGER NOT NULL
|
exp INTEGER NOT NULL
|
||||||
@@ -16,4 +17,11 @@ AFTER INSERT ON jwtblacklist
|
|||||||
BEGIN
|
BEGIN
|
||||||
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
|
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
|
||||||
END;
|
END;
|
||||||
COMMIT;
|
-- +goose StatementEnd
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
-- +goose StatementBegin
|
||||||
|
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||||
|
DROP TABLE IF EXISTS jwtblacklist;
|
||||||
|
DROP TABLE IF EXISTS users;
|
||||||
|
-- +goose StatementEnd
|
||||||
@@ -1,63 +1,83 @@
|
|||||||
package tests
|
package tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/pressly/goose/v3"
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func findSQLFile(filename string) (string, error) {
|
func findMigrations() (*fs.FS, error) {
|
||||||
|
dir, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||||
|
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
|
||||||
|
return &migrationsdir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parent := filepath.Dir(dir)
|
||||||
|
if parent == dir { // Reached root
|
||||||
|
return nil, errors.New("Unable to locate migrations directory")
|
||||||
|
}
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findTestData() (string, error) {
|
||||||
dir, err := os.Getwd()
|
dir, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if _, err := os.Stat(filepath.Join(dir, filename)); err == nil {
|
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||||
return filepath.Join(dir, filename), nil
|
return filepath.Join(dir, "tests", "testdata.sql"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parent := filepath.Dir(dir)
|
parent := filepath.Dir(dir)
|
||||||
if parent == dir { // Reached root
|
if parent == dir { // Reached root
|
||||||
return "", errors.New(fmt.Sprintf("Unable to locate %s", filename))
|
return "", errors.New("Unable to locate test data")
|
||||||
}
|
}
|
||||||
dir = parent
|
dir = parent
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupTestDB initializes a test SQLite database with mock data
|
func SetupTestDB(version int64) (*sql.DB, error) {
|
||||||
func SetupTestDB() (*sql.DB, error) {
|
|
||||||
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "sql.Open")
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
}
|
}
|
||||||
// Setup the test database
|
|
||||||
schemaPath, err := findSQLFile("schema.sql")
|
migrations, err := findMigrations()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "findSchema")
|
return nil, errors.Wrap(err, "findMigrations")
|
||||||
|
}
|
||||||
|
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "goose.NewProvider")
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
if _, err := provider.UpTo(ctx, version); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "provider.UpTo")
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlBytes, err := os.ReadFile(schemaPath)
|
// NOTE: ==================================================
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "os.ReadFile")
|
|
||||||
}
|
|
||||||
schemaSQL := string(sqlBytes)
|
|
||||||
|
|
||||||
_, err = conn.Exec(schemaSQL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
// Load the test data
|
// Load the test data
|
||||||
dataPath, err := findSQLFile("testdata.sql")
|
dataPath, err := findTestData()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "findSchema")
|
return nil, errors.Wrap(err, "findSchema")
|
||||||
}
|
}
|
||||||
sqlBytes, err = os.ReadFile(dataPath)
|
sqlBytes, err := os.ReadFile(dataPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "os.ReadFile")
|
return nil, errors.Wrap(err, "os.ReadFile")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user