Merge pull request #12 from Haelnorr/dbbackups
Testing db migrations updates
This commit is contained in:
14
.github/workflows/deploy_staging.yaml
vendored
14
.github/workflows/deploy_staging.yaml
vendored
@@ -33,11 +33,15 @@ jobs:
|
||||
- name: Build the binary
|
||||
run: make build SUFFIX=-staging-$GITHUB_SHA
|
||||
|
||||
- name: Build the migration binary
|
||||
run: make migrate SUFFIX=-staging-$GITHUB_SHA
|
||||
|
||||
- name: Deploy to Server
|
||||
env:
|
||||
USER: deploy
|
||||
HOST: projectreshoot.com
|
||||
DIR: /home/deploy/releases/staging
|
||||
MIG_DIR: /home/deploy/migration-bin
|
||||
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
@@ -49,7 +53,13 @@ jobs:
|
||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 prmigrate-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,6 +4,7 @@ query.sql
|
||||
.logs/
|
||||
server.log
|
||||
tmp/
|
||||
prmigrate
|
||||
projectreshoot
|
||||
static/css/output.css
|
||||
view/**/*_templ.go
|
||||
|
||||
6
Makefile
6
Makefile
@@ -1,5 +1,6 @@
|
||||
# Makefile
|
||||
.PHONY: build
|
||||
.PHONY: migrate
|
||||
|
||||
BINARY_NAME=projectreshoot
|
||||
|
||||
@@ -29,3 +30,8 @@ test:
|
||||
|
||||
clean:
|
||||
go clean
|
||||
|
||||
migrate:
|
||||
go mod tidy && \
|
||||
go generate && \
|
||||
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate
|
||||
|
||||
@@ -21,7 +21,7 @@ type Config struct {
|
||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||
DBName string // Filename of the db (doesnt include file extension)
|
||||
DBName string // Filename of the db - hardcoded and doubles as DB version
|
||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
||||
SecretKey string // Secret key for signing tokens
|
||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||
@@ -87,7 +87,7 @@ func GetConfig(args map[string]string) (*Config, error) {
|
||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
||||
DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
|
||||
DBName: "00001",
|
||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
||||
SecretKey: os.Getenv("SECRET_KEY"),
|
||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
||||
|
||||
205
db/connection.go
205
db/connection.go
@@ -1,10 +1,9 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -12,179 +11,47 @@ import (
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type SafeConn struct {
|
||||
db *sql.DB
|
||||
readLockCount uint32
|
||||
globalLockStatus uint32
|
||||
globalLockRequested uint32
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||
return &SafeConn{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
conn.globalLockStatus = 1
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *SafeConn) releaseGlobalLock() {
|
||||
conn.globalLockStatus = 0
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock released")
|
||||
}
|
||||
|
||||
func (conn *SafeConn) acquireReadLock() bool {
|
||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||
return false
|
||||
}
|
||||
conn.readLockCount += 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *SafeConn) releaseReadLock() {
|
||||
conn.readLockCount -= 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock released")
|
||||
}
|
||||
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||
lockAcquired := make(chan struct{})
|
||||
lockCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-lockCtx.Done():
|
||||
return
|
||||
default:
|
||||
if conn.acquireReadLock() {
|
||||
close(lockAcquired)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lockAcquired:
|
||||
tx, err := conn.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
conn.releaseReadLock()
|
||||
return nil, err
|
||||
}
|
||||
return &SafeTX{tx: tx, sc: conn}, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return nil, errors.New("Transaction time out due to database lock")
|
||||
}
|
||||
}
|
||||
|
||||
// Query the database inside the transaction
|
||||
func (stx *SafeTX) Query(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Rows, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
return stx.tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Exec a statement on the database inside the transaction
|
||||
func (stx *SafeTX) Exec(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (sql.Result, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot exec without a transaction")
|
||||
}
|
||||
return stx.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Commit the current transaction and release the read lock
|
||||
func (stx *SafeTX) Commit() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot commit without a transaction")
|
||||
}
|
||||
err := stx.tx.Commit()
|
||||
stx.tx = nil
|
||||
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Abort the current transaction, releasing the read lock
|
||||
func (stx *SafeTX) Rollback() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot rollback without a transaction")
|
||||
}
|
||||
err := stx.tx.Rollback()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Acquire a global lock, preventing all transactions
|
||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||
conn.globalLockRequested = 1
|
||||
defer func() { conn.globalLockRequested = 0 }()
|
||||
timeout := time.After(timeoutAfter)
|
||||
attempt := 0
|
||||
for {
|
||||
if conn.acquireGlobalLock() {
|
||||
conn.logger.Info().Msg("Global database lock acquired")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release the global lock
|
||||
func (conn *SafeConn) Resume() {
|
||||
conn.releaseGlobalLock()
|
||||
conn.logger.Info().Msg("Global database lock released")
|
||||
}
|
||||
|
||||
// Close the database connection
|
||||
func (conn *SafeConn) Close() error {
|
||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||
conn.acquireGlobalLock()
|
||||
defer conn.releaseGlobalLock()
|
||||
conn.logger.Debug().Msg("Closing database connection")
|
||||
return conn.db.Close()
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
|
||||
func ConnectToDatabase(
|
||||
dbName string,
|
||||
logger *zerolog.Logger,
|
||||
) (*SafeConn, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
version, err := strconv.Atoi(dbName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "strconv.Atoi")
|
||||
}
|
||||
err = checkDBVersion(db, version)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
conn := MakeSafe(db, logger)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check the database version
|
||||
func checkDBVersion(db *sql.DB, expectVer int) error {
|
||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||
ORDER BY version_id DESC LIMIT 1`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
if rows.Next() {
|
||||
var version int
|
||||
err = rows.Scan(&version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if version != expectVer {
|
||||
return errors.New("Version mismatch")
|
||||
}
|
||||
} else {
|
||||
return errors.New("No version found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
129
db/safeconn.go
Normal file
129
db/safeconn.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type SafeConn struct {
|
||||
db *sql.DB
|
||||
readLockCount uint32
|
||||
globalLockStatus uint32
|
||||
globalLockRequested uint32
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
// Make the provided db handle safe and attach a logger to it
|
||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||
return &SafeConn{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// Attempts to acquire a global lock on the database connection
|
||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
conn.globalLockStatus = 1
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Releases a global lock on the database connection
|
||||
func (conn *SafeConn) releaseGlobalLock() {
|
||||
conn.globalLockStatus = 0
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock released")
|
||||
}
|
||||
|
||||
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
||||
// at the same time
|
||||
func (conn *SafeConn) acquireReadLock() bool {
|
||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||
return false
|
||||
}
|
||||
conn.readLockCount += 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Release a read lock. Decrements read lock count by 1
|
||||
func (conn *SafeConn) releaseReadLock() {
|
||||
conn.readLockCount -= 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock released")
|
||||
}
|
||||
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||
lockAcquired := make(chan struct{})
|
||||
lockCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-lockCtx.Done():
|
||||
return
|
||||
default:
|
||||
if conn.acquireReadLock() {
|
||||
close(lockAcquired)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lockAcquired:
|
||||
tx, err := conn.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
conn.releaseReadLock()
|
||||
return nil, err
|
||||
}
|
||||
return &SafeTX{tx: tx, sc: conn}, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return nil, errors.New("Transaction time out due to database lock")
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire a global lock, preventing all transactions
|
||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||
conn.globalLockRequested = 1
|
||||
defer func() { conn.globalLockRequested = 0 }()
|
||||
timeout := time.After(timeoutAfter)
|
||||
attempt := 0
|
||||
for {
|
||||
if conn.acquireGlobalLock() {
|
||||
conn.logger.Info().Msg("Global database lock acquired")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release the global lock
|
||||
func (conn *SafeConn) Resume() {
|
||||
conn.releaseGlobalLock()
|
||||
conn.logger.Info().Msg("Global database lock released")
|
||||
}
|
||||
|
||||
// Close the database connection
|
||||
func (conn *SafeConn) Close() error {
|
||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||
conn.acquireGlobalLock()
|
||||
defer conn.releaseGlobalLock()
|
||||
conn.logger.Debug().Msg("Closing database connection")
|
||||
return conn.db.Close()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/tests"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,8 +13,12 @@ import (
|
||||
)
|
||||
|
||||
func TestSafeConn(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
@@ -77,8 +82,12 @@ func TestSafeConn(t *testing.T) {
|
||||
})
|
||||
}
|
||||
func TestSafeTX(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
61
db/safetx.go
Normal file
61
db/safetx.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
// Query the database inside the transaction
|
||||
func (stx *SafeTX) Query(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Rows, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
return stx.tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Exec a statement on the database inside the transaction
|
||||
func (stx *SafeTX) Exec(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (sql.Result, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot exec without a transaction")
|
||||
}
|
||||
return stx.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Commit the current transaction and release the read lock
|
||||
func (stx *SafeTX) Commit() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot commit without a transaction")
|
||||
}
|
||||
err := stx.tx.Commit()
|
||||
stx.tx = nil
|
||||
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Abort the current transaction, releasing the read lock
|
||||
func (stx *SafeTX) Rollback() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot rollback without a transaction")
|
||||
}
|
||||
err := stx.tx.Rollback()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
108
deploy/db/backup.sh
Executable file
108
deploy/db/backup.sh
Executable file
@@ -0,0 +1,108 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
DATA_DIR="/home/deploy/data/$ENVR"
|
||||
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||
if [[ "$ENVR" == "production" ]]; then
|
||||
SERVICE_NAME="projectreshoot"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
else
|
||||
SERVICE_NAME="$ENVR.projectreshoot"
|
||||
declare -a PORTS=("3005" "3006" "3007")
|
||||
fi
|
||||
|
||||
# Send SIGUSR2 to release maintenance mode
|
||||
release_maintenance() {
|
||||
echo "Releasing maintenance mode..."
|
||||
for PORT in "${PORTS[@]}"; do
|
||||
sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service"
|
||||
done
|
||||
}
|
||||
|
||||
shopt -s nullglob
|
||||
DB_FILES=("$ACTIVE_DIR"/*.db)
|
||||
DB_COUNT=${#DB_FILES[@]}
|
||||
|
||||
if [[ $DB_COUNT -gt 1 ]]; then
|
||||
echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required."
|
||||
exit 1
|
||||
elif [[ $DB_COUNT -eq 0 ]]; then
|
||||
echo "Error: No .db file found in $ACTIVE_DIR."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract the filename without extension
|
||||
DB_FILE="${DB_FILES[0]}"
|
||||
DB_VER=$(basename "$DB_FILE" .db)
|
||||
|
||||
# Send SIGUSR1 to trigger maintenance mode
|
||||
for PORT in "${PORTS[@]}"; do
|
||||
sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service"
|
||||
done
|
||||
trap release_maintenance EXIT
|
||||
|
||||
# Function to check logs for success or failure
|
||||
check_logs() {
|
||||
local port="$1"
|
||||
local service="$SERVICE_NAME@$port.service"
|
||||
|
||||
echo "Waiting for $service to enter maintenance mode..."
|
||||
|
||||
# Check the last few lines first in case the message already appeared
|
||||
if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then
|
||||
echo "$service successfully entered maintenance mode."
|
||||
return 0
|
||||
elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then
|
||||
echo "Error: $service failed to enter maintenance mode."
|
||||
return 1
|
||||
fi
|
||||
|
||||
# If not found, continuously watch logs until we get a success or failure message
|
||||
sudo journalctl -u "$service" -f --no-pager | while read -r line; do
|
||||
if echo "$line" | grep -q "Global database lock acquired"; then
|
||||
echo "$service successfully entered maintenance mode."
|
||||
pkill -P $$ journalctl # Kill journalctl process once we have success
|
||||
return 0
|
||||
elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then
|
||||
echo "Error: $service failed to enter maintenance mode."
|
||||
pkill -P $$ journalctl # Kill journalctl process on failure
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# Check logs for each service
|
||||
for PORT in "${PORTS[@]}"; do
|
||||
check_logs "$PORT"
|
||||
done
|
||||
|
||||
# Get current datetime in YYYY-MM-DD-HHMM format
|
||||
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||
|
||||
# Define source and destination paths
|
||||
SOURCE_DB="$DATA_DIR/$DB_VER.db"
|
||||
BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db"
|
||||
|
||||
# Copy the database file
|
||||
if [[ -f "$SOURCE_DB" ]]; then
|
||||
cp "$SOURCE_DB" "$BACKUP_DB"
|
||||
echo "Backup created: $BACKUP_DB"
|
||||
else
|
||||
echo "Error: Source database file $SOURCE_DB not found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
77
deploy/db/migrate.sh
Executable file
77
deploy/db/migrate.sh
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$2" ]]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
TGT_VER="$2"
|
||||
re='^[0-9]+$'
|
||||
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||
echo "Error: version not a number" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ -z "$3" ]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
COMMIT_HASH=$3
|
||||
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||
BACKUP_OUTPUT=$(/bin/bash ${MIGRATION_BIN}/backup.sh "$ENVR" 2>&1)
|
||||
echo "$BACKUP_OUTPUT"
|
||||
if [[ $? -ne 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*')
|
||||
if [[ -z "$BACKUP_FILE" ]]; then
|
||||
echo "Error: backup failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
FILE_NAME=${BACKUP_FILE##*/}
|
||||
CUR_VER=${FILE_NAME%%-*}
|
||||
if [[ $((+$TGT_VER)) == $((+$CUR_VER)) ]]; then
|
||||
echo "Version same, skipping migration"
|
||||
exit 0
|
||||
fi
|
||||
if [[ $((+$TGT_VER)) > $((+$CUR_VER)) ]]; then
|
||||
CMD="up-to"
|
||||
fi
|
||||
if [[ $((+$TGT_VER)) < $((+$CUR_VER)) ]]; then
|
||||
CMD="down-to"
|
||||
fi
|
||||
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
DATA_DIR="/home/deploy/data/$ENVR"
|
||||
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||
UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db"
|
||||
UPDATED_COPY="$DATA_DIR/${TGT_VER}.db"
|
||||
UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db"
|
||||
|
||||
cp $BACKUP_FILE $UPDATED_BACKUP
|
||||
failed_cleanup() {
|
||||
rm $UPDATED_BACKUP
|
||||
}
|
||||
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
||||
|
||||
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
||||
${MIGRATION_BIN}/prmigrate-$COMMIT_HASH $UPDATED_BACKUP $CMD $TGT_VER
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Migration failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "Migration completed"
|
||||
|
||||
cp $UPDATED_BACKUP $UPDATED_COPY
|
||||
ln -s $UPDATED_COPY $UPDATED_LINK
|
||||
echo "Upgraded database linked and ready for deploy"
|
||||
exit 0
|
||||
27
deploy/db/migrationcleanup.sh
Executable file
27
deploy/db/migrationcleanup.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment> <version>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$2" ]]; then
|
||||
echo "Usage: $0 <environment> <version>"
|
||||
exit 1
|
||||
fi
|
||||
TGT_VER="$2"
|
||||
re='^[0-9]+$'
|
||||
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||
echo "Error: version not a number" >&2
|
||||
exit 1
|
||||
fi
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} +
|
||||
113
deploy/deploy.sh
Normal file
113
deploy/deploy.sh
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
# Check if commit hash is passed as an argument
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
COMMIT_HASH=$1
|
||||
ENVR="$2"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RELEASES_DIR="/home/deploy/releases/$ENVR"
|
||||
DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot"
|
||||
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||
BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
if [[ "$ENVR" == "production" ]]; then
|
||||
SERVICE_NAME="projectreshoot"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
else
|
||||
SERVICE_NAME="$ENVR.projectreshoot"
|
||||
declare -a PORTS=("3005" "3006" "3007")
|
||||
fi
|
||||
|
||||
# Check if the binary exists
|
||||
if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then
|
||||
echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}"
|
||||
exit 1
|
||||
fi
|
||||
DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*')
|
||||
${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "Migration failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Keep a reference to the previous binary from the symlink
|
||||
if [ -L "${DEPLOY_BIN}" ]; then
|
||||
PREVIOUS=$(readlink -f $DEPLOY_BIN)
|
||||
echo "Current binary is ${PREVIOUS}, saved for rollback."
|
||||
else
|
||||
echo "No symbolic link found, no previous binary to backup."
|
||||
PREVIOUS=""
|
||||
fi
|
||||
|
||||
rollback_deployment() {
|
||||
if [ -n "$PREVIOUS" ]; then
|
||||
echo "Rolling back to previous binary: ${PREVIOUS}"
|
||||
ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}"
|
||||
else
|
||||
echo "No previous binary to roll back to."
|
||||
fi
|
||||
|
||||
# wait to restart the services
|
||||
sleep 10
|
||||
|
||||
# Restart all services with the previous binary
|
||||
for port in "${PORTS[@]}"; do
|
||||
SERVICE="${SERVICE_NAME}@${port}.service"
|
||||
echo "Restarting $SERVICE..."
|
||||
sudo systemctl restart $SERVICE
|
||||
done
|
||||
|
||||
echo "Rollback completed."
|
||||
}
|
||||
|
||||
# Copy the binary to the deployment directory
|
||||
echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..."
|
||||
ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}"
|
||||
|
||||
WAIT_TIME=5
|
||||
restart_service() {
|
||||
local port=$1
|
||||
local SERVICE="${SERVICE_NAME}@${port}.service"
|
||||
echo "Restarting ${SERVICE}..."
|
||||
|
||||
# Restart the service
|
||||
if ! sudo systemctl restart "$SERVICE"; then
|
||||
echo "Error: Failed to restart ${SERVICE}. Rolling back deployment."
|
||||
|
||||
# Call the rollback function
|
||||
rollback_deployment
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Wait a few seconds to allow the service to fully start
|
||||
echo "Waiting for ${SERVICE} to fully start..."
|
||||
sleep $WAIT_TIME
|
||||
|
||||
# Check the status of the service
|
||||
if ! systemctl is-active --quiet "${SERVICE}"; then
|
||||
echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment."
|
||||
|
||||
# Call the rollback function
|
||||
rollback_deployment
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${SERVICE}.service restarted successfully."
|
||||
}
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
restart_service $port
|
||||
done
|
||||
|
||||
echo "Deployment completed successfully."
|
||||
${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER
|
||||
6
go.mod
6
go.mod
@@ -8,6 +8,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.24.1
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
@@ -19,10 +20,13 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -28,23 +28,31 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
|
||||
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
|
||||
8
main.go
8
main.go
@@ -94,6 +94,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "server.GetConfig")
|
||||
}
|
||||
|
||||
// Return the version of the database required
|
||||
if args["dbver"] == "true" {
|
||||
fmt.Printf("Database version: %s\n", config.DBName)
|
||||
return nil
|
||||
}
|
||||
|
||||
var logfile *os.File = nil
|
||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
||||
logfile, err = logging.GetLogFile(config.LogDir)
|
||||
@@ -186,6 +192,7 @@ func main() {
|
||||
host := flag.String("host", "", "Override host to listen on")
|
||||
port := flag.String("port", "", "Override port to listen on")
|
||||
test := flag.Bool("test", false, "Run test function instead of main program")
|
||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||
loglevel := flag.String("loglevel", "", "Set log level")
|
||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||
flag.Parse()
|
||||
@@ -195,6 +202,7 @@ func main() {
|
||||
"host": *host,
|
||||
"port": *port,
|
||||
"test": strconv.FormatBool(*test),
|
||||
"dbver": strconv.FormatBool(*dbver),
|
||||
"loglevel": *loglevel,
|
||||
"logoutput": *logoutput,
|
||||
}
|
||||
|
||||
@@ -17,16 +17,16 @@ import (
|
||||
)
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := contexts.GetUser(r.Context())
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -14,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
func TestPageLoginRequired(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -14,16 +15,16 @@ import (
|
||||
)
|
||||
|
||||
func TestReauthRequired(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
69
migrate/migrate.go
Normal file
69
migrate/migrate.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migrations
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 4 {
|
||||
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
filePath := os.Args[1]
|
||||
direction := os.Args[2]
|
||||
versionStr := os.Args[3]
|
||||
|
||||
version, err := strconv.Atoi(versionStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid version number: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
log.Fatalf("Database file does not exist: %v", filePath)
|
||||
}
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create migration provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch direction {
|
||||
case "up-to":
|
||||
_, err = provider.UpTo(ctx, int64(version))
|
||||
case "down-to":
|
||||
_, err = provider.DownTo(ctx, int64(version))
|
||||
default:
|
||||
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Migration successful!")
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
PRAGMA foreign_keys=ON;
|
||||
BEGIN TRANSACTION;
|
||||
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
||||
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
||||
exp INTEGER NOT NULL
|
||||
@@ -16,4 +17,11 @@ AFTER INSERT ON jwtblacklist
|
||||
BEGIN
|
||||
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
|
||||
END;
|
||||
COMMIT;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||
DROP TABLE IF EXISTS jwtblacklist;
|
||||
DROP TABLE IF EXISTS users;
|
||||
-- +goose StatementEnd
|
||||
@@ -1,63 +1,83 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func findSQLFile(filename string) (string, error) {
|
||||
func findMigrations() (*fs.FS, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
|
||||
return &migrationsdir, nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return nil, errors.New("Unable to locate migrations directory")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
func findTestData() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, filename)); err == nil {
|
||||
return filepath.Join(dir, filename), nil
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
return filepath.Join(dir, "tests", "testdata.sql"), nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return "", errors.New(fmt.Sprintf("Unable to locate %s", filename))
|
||||
return "", errors.New("Unable to locate test data")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
func SetupTestDB(version int64) (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
// Setup the test database
|
||||
schemaPath, err := findSQLFile("schema.sql")
|
||||
|
||||
migrations, err := findMigrations()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
return nil, errors.Wrap(err, "findMigrations")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "goose.NewProvider")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if _, err := provider.UpTo(ctx, version); err != nil {
|
||||
return nil, errors.Wrap(err, "provider.UpTo")
|
||||
}
|
||||
|
||||
sqlBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
schemaSQL := string(sqlBytes)
|
||||
|
||||
_, err = conn.Exec(schemaSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
// NOTE: ==================================================
|
||||
// Load the test data
|
||||
dataPath, err := findSQLFile("testdata.sql")
|
||||
dataPath, err := findTestData()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
}
|
||||
sqlBytes, err = os.ReadFile(dataPath)
|
||||
sqlBytes, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user