Merge pull request #12 from Haelnorr/dbbackups

Testing db migrations updates
This commit is contained in:
2025-02-22 11:42:42 +11:00
committed by GitHub
22 changed files with 744 additions and 217 deletions

View File

@@ -33,11 +33,15 @@ jobs:
- name: Build the binary
run: make build SUFFIX=-staging-$GITHUB_SHA
- name: Build the migration binary
run: make migrate SUFFIX=-staging-$GITHUB_SHA
- name: Deploy to Server
env:
USER: deploy
HOST: projectreshoot.com
DIR: /home/deploy/releases/staging
MIG_DIR: /home/deploy/migration-bin
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
run: |
mkdir -p ~/.ssh
@@ -49,7 +53,13 @@ jobs:
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
scp -i ~/.ssh/id_ed25519 prmigrate-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging

1
.gitignore vendored
View File

@@ -4,6 +4,7 @@ query.sql
.logs/
server.log
tmp/
prmigrate
projectreshoot
static/css/output.css
view/**/*_templ.go

View File

@@ -1,5 +1,6 @@
# Makefile
.PHONY: build
.PHONY: migrate
BINARY_NAME=projectreshoot
@@ -29,3 +30,8 @@ test:
clean:
go clean
migrate:
go mod tidy && \
go generate && \
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate

View File

@@ -21,7 +21,7 @@ type Config struct {
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
WriteTimeout time.Duration // Timeout for writing requests in seconds
IdleTimeout time.Duration // Timeout for idle connections in seconds
DBName string // Filename of the db (doesnt include file extension)
DBName string // Filename of the db - hardcoded and doubles as DB version
DBLockTimeout time.Duration // Timeout for acquiring database lock
SecretKey string // Secret key for signing tokens
AccessTokenExpiry int64 // Access token expiry in minutes
@@ -87,7 +87,7 @@ func GetConfig(args map[string]string) (*Config, error) {
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
DBName: "00001",
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
SecretKey: os.Getenv("SECRET_KEY"),
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),

View File

@@ -1,10 +1,9 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
"strconv"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@@ -12,179 +11,47 @@ import (
_ "modernc.org/sqlite"
)
type SafeConn struct {
db *sql.DB
readLockCount uint32
globalLockStatus uint32
globalLockRequested uint32
logger *zerolog.Logger
}
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{db: db, logger: logger}
}
// Extends sql.Tx for use with SafeConn
type SafeTX struct {
tx *sql.Tx
sc *SafeConn
}
func (conn *SafeConn) acquireGlobalLock() bool {
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false
}
conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true
}
func (conn *SafeConn) releaseGlobalLock() {
conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
}
func (conn *SafeConn) acquireReadLock() bool {
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
return false
}
conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true
}
func (conn *SafeConn) releaseReadLock() {
conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
}
// Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-lockCtx.Done():
return
default:
if conn.acquireReadLock() {
close(lockAcquired)
}
}
}()
select {
case <-lockAcquired:
tx, err := conn.db.BeginTx(ctx, nil)
if err != nil {
conn.releaseReadLock()
return nil, err
}
return &SafeTX{tx: tx, sc: conn}, nil
case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock")
}
}
// Query the database inside the transaction
func (stx *SafeTX) Query(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Rows, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
return stx.tx.QueryContext(ctx, query, args...)
}
// Exec a statement on the database inside the transaction
func (stx *SafeTX) Exec(
ctx context.Context,
query string,
args ...interface{},
) (sql.Result, error) {
if stx.tx == nil {
return nil, errors.New("Cannot exec without a transaction")
}
return stx.tx.ExecContext(ctx, query, args...)
}
// Commit the current transaction and release the read lock
func (stx *SafeTX) Commit() error {
if stx.tx == nil {
return errors.New("Cannot commit without a transaction")
}
err := stx.tx.Commit()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Abort the current transaction, releasing the read lock
func (stx *SafeTX) Rollback() error {
if stx.tx == nil {
return errors.New("Cannot rollback without a transaction")
}
err := stx.tx.Rollback()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Acquire a global lock, preventing all transactions
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
conn.logger.Info().Msg("Attempting to acquire global database lock")
conn.globalLockRequested = 1
defer func() { conn.globalLockRequested = 0 }()
timeout := time.After(timeoutAfter)
attempt := 0
for {
if conn.acquireGlobalLock() {
conn.logger.Info().Msg("Global database lock acquired")
return
}
select {
case <-timeout:
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
return
case <-time.After(100 * time.Millisecond):
attempt++
}
}
}
// Release the global lock
func (conn *SafeConn) Resume() {
conn.releaseGlobalLock()
conn.logger.Info().Msg("Global database lock released")
}
// Close the database connection
func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.db.Close()
}
// Returns a database connection handle for the DB
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
func ConnectToDatabase(
dbName string,
logger *zerolog.Logger,
) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
version, err := strconv.Atoi(dbName)
if err != nil {
return nil, errors.Wrap(err, "strconv.Atoi")
}
err = checkDBVersion(db, version)
if err != nil {
return nil, errors.Wrap(err, "checkDBVersion")
}
conn := MakeSafe(db, logger)
return conn, nil
}
// checkDBVersion verifies that the newest applied goose migration recorded in
// the database matches the schema version this binary was built against.
// Returns an error on query failure, when no applied migration exists, or on
// a version mismatch.
func checkDBVersion(db *sql.DB, expectVer int) error {
	query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
    ORDER BY version_id DESC LIMIT 1`
	// QueryRow replaces the previous db.Query loop, which leaked the *sql.Rows
	// (never closed) and never checked rows.Err.
	var version int
	err := db.QueryRow(query).Scan(&version)
	if err == sql.ErrNoRows {
		return errors.New("No version found")
	}
	if err != nil {
		return errors.Wrap(err, "db.QueryRow")
	}
	if version != expectVer {
		return errors.New("Version mismatch")
	}
	return nil
}

129
db/safeconn.go Normal file
View File

@@ -0,0 +1,129 @@
package db
import (
"context"
"database/sql"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
// SafeConn wraps *sql.DB with a cooperative locking scheme: transactions take
// a shared "read" lock via Begin, while Pause/Resume take an exclusive global
// lock so the underlying database file can be quiesced (e.g. for backups).
//
// NOTE(review): these counters are read and written from multiple goroutines
// (Begin spawns a goroutine that calls acquireReadLock while Pause may be
// looping on acquireGlobalLock concurrently) with no sync/atomic or mutex —
// this is a data race; verify with `go test -race`.
type SafeConn struct {
	db                  *sql.DB
	readLockCount       uint32 // number of in-flight transactions holding the shared lock
	globalLockStatus    uint32 // 1 while the exclusive (global) lock is held
	globalLockRequested uint32 // 1 while Pause is waiting; blocks new read locks
	logger              *zerolog.Logger
}
// MakeSafe wraps the provided db handle in a SafeConn and attaches the given
// logger, which is used for lock-state diagnostics.
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
	sc := new(SafeConn)
	sc.db = db
	sc.logger = logger
	return sc
}
// Attempts to acquire a global lock on the database connection. Succeeds only
// when no read locks are held and the global lock is free; returns true on
// success, false otherwise (caller is expected to retry).
//
// NOTE(review): the check and the set are not one atomic operation, so two
// concurrent callers could both observe "free" and both believe they hold
// the lock.
func (conn *SafeConn) acquireGlobalLock() bool {
	if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
		return false
	}
	conn.globalLockStatus = 1
	conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
		Msg("Global lock acquired")
	return true
}
// releaseGlobalLock clears the exclusive lock flag so transactions can begin
// acquiring read locks again, and logs the transition at debug level.
func (conn *SafeConn) releaseGlobalLock() {
	conn.globalLockStatus = 0
	conn.logger.Debug().
		Uint32("global_lock_status", conn.globalLockStatus).
		Msg("Global lock released")
}
// Acquire a read lock on the connection. Multiple read locks can be acquired
// at the same time. Fails (returns false) while the global lock is held or
// merely requested — refusing new readers while a Pause is pending lets
// existing transactions drain instead of starving the global lock.
//
// NOTE(review): check-then-increment is not atomic; concurrent callers race.
func (conn *SafeConn) acquireReadLock() bool {
	if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
		return false
	}
	conn.readLockCount += 1
	conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
		Msg("Read lock acquired")
	return true
}
// Release a read lock. Decrements the read lock count by 1.
//
// Guards against underflow: readLockCount is unsigned, so an unbalanced
// release would wrap it to ~4 billion and permanently block
// acquireGlobalLock (which requires the count to reach zero).
func (conn *SafeConn) releaseReadLock() {
	if conn.readLockCount == 0 {
		conn.logger.Warn().Msg("releaseReadLock called with no read locks held")
		return
	}
	conn.readLockCount -= 1
	conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
		Msg("Read lock released")
}
// Begin starts a new transaction based on the current context, blocking until
// a read lock can be acquired or the context is closed/cancelled/done. It
// returns an error if ctx expires while the database is paused.
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
	lockAcquired := make(chan struct{})
	lockCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Retry until the lock is granted or the context ends. The previous
	// version attempted acquireReadLock exactly once, so any Begin issued
	// while the global lock was held or requested hung until ctx expired,
	// even after the lock had been released.
	go func() {
		for {
			select {
			case <-lockCtx.Done():
				return
			default:
			}
			if conn.acquireReadLock() {
				close(lockAcquired)
				return
			}
			select {
			case <-lockCtx.Done():
				return
			case <-time.After(10 * time.Millisecond):
			}
		}
	}()
	select {
	case <-lockAcquired:
		tx, err := conn.db.BeginTx(ctx, nil)
		if err != nil {
			conn.releaseReadLock()
			return nil, err
		}
		return &SafeTX{tx: tx, sc: conn}, nil
	case <-ctx.Done():
		// If the goroutine won the race just as the context expired, give
		// back the lock it took (a narrow race window remains after cancel;
		// same as the original code).
		select {
		case <-lockAcquired:
			conn.releaseReadLock()
		default:
		}
		return nil, errors.New("Transaction time out due to database lock")
	}
}
// Pause acquires the global lock, preventing all transactions, and returns
// once the lock is held or timeoutAfter has elapsed. While waiting, the
// globalLockRequested flag stops new read locks so in-flight transactions
// can drain. The caller cannot tell success from timeout except via logs.
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
	conn.logger.Info().Msg("Attempting to acquire global database lock")
	conn.globalLockRequested = 1
	defer func() { conn.globalLockRequested = 0 }()
	timeout := time.After(timeoutAfter)
	for {
		if conn.acquireGlobalLock() {
			conn.logger.Info().Msg("Global database lock acquired")
			return
		}
		select {
		case <-timeout:
			conn.logger.Info().Msg("Timeout: Global database lock abandoned")
			return
		case <-time.After(100 * time.Millisecond):
			// Poll again. (Removed the dead `attempt` counter, which was
			// incremented but never read.)
		}
	}
}
// Resume releases the global lock taken by Pause, allowing transactions to
// acquire read locks again.
func (conn *SafeConn) Resume() {
	conn.releaseGlobalLock()
	conn.logger.Info().Msg("Global database lock released")
}
// Close closes the underlying database connection, first attempting to take
// the global lock so no transaction is mid-flight.
//
// NOTE(review): the boolean result of acquireGlobalLock is ignored — if read
// locks are still held (or the global lock is already taken) the connection
// is closed anyway. Confirm this is the intended shutdown behavior.
func (conn *SafeConn) Close() error {
	conn.logger.Debug().Msg("Acquiring global lock for connection close")
	conn.acquireGlobalLock()
	defer conn.releaseGlobalLock()
	conn.logger.Debug().Msg("Closing database connection")
	return conn.db.Close()
}

View File

@@ -3,6 +3,7 @@ package db
import (
"context"
"projectreshoot/tests"
"strconv"
"sync"
"testing"
"time"
@@ -12,8 +13,12 @@ import (
)
func TestSafeConn(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
@@ -77,8 +82,12 @@ func TestSafeConn(t *testing.T) {
})
}
func TestSafeTX(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()

61
db/safetx.go Normal file
View File

@@ -0,0 +1,61 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
// SafeTX extends sql.Tx for use with SafeConn: it pairs the transaction with
// the owning connection so Commit/Rollback can release the read lock taken by
// SafeConn.Begin. tx is set to nil once the transaction has finished, and all
// methods report an error after that point.
type SafeTX struct {
	tx *sql.Tx
	sc *SafeConn
}
// Query runs a query inside the transaction and returns the resulting rows.
// It fails if the transaction has already been committed or rolled back.
func (stx *SafeTX) Query(
	ctx context.Context,
	query string,
	args ...interface{},
) (*sql.Rows, error) {
	tx := stx.tx
	if tx == nil {
		return nil, errors.New("Cannot query without a transaction")
	}
	return tx.QueryContext(ctx, query, args...)
}
// Exec runs a statement inside the transaction and returns its result.
// It fails if the transaction has already been committed or rolled back.
func (stx *SafeTX) Exec(
	ctx context.Context,
	query string,
	args ...interface{},
) (sql.Result, error) {
	tx := stx.tx
	if tx == nil {
		return nil, errors.New("Cannot exec without a transaction")
	}
	return tx.ExecContext(ctx, query, args...)
}
// Commit finalizes the current transaction and releases the connection's
// read lock. Calling it a second time returns an error.
func (stx *SafeTX) Commit() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot commit without a transaction")
	}
	stx.tx = nil
	err := tx.Commit()
	stx.sc.releaseReadLock()
	return err
}
// Rollback aborts the current transaction and releases the connection's
// read lock. Calling it after the transaction has finished returns an error.
func (stx *SafeTX) Rollback() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot rollback without a transaction")
	}
	stx.tx = nil
	err := tx.Rollback()
	stx.sc.releaseReadLock()
	return err
}

108
deploy/db/backup.sh Executable file
View File

@@ -0,0 +1,108 @@
#!/bin/bash
# backup.sh <environment>
#
# Puts every running app instance into maintenance mode (so nothing touches
# the database), copies the current SQLite file into the backup directory
# with a timestamped name, then releases maintenance mode on exit.
# Prints "Backup created: <path>" on success; migrate.sh greps for that line.
# Exit on error
set -e
if [[ -z "$1" ]]; then
    echo "Usage: $0 <environment>"
    exit 1
fi
ENVR="$1"
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
    echo "Error: environment must be 'production' or 'staging'."
    exit 1
fi
# Directory layout: ACTIVE_DIR holds the <version>.db the app uses, DATA_DIR
# holds the real database file, BACKUP_DIR receives timestamped copies.
ACTIVE_DIR="/home/deploy/$ENVR"
DATA_DIR="/home/deploy/data/$ENVR"
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
# Service template name and instance ports differ per environment.
if [[ "$ENVR" == "production" ]]; then
    SERVICE_NAME="projectreshoot"
    declare -a PORTS=("3000" "3001" "3002")
else
    SERVICE_NAME="$ENVR.projectreshoot"
    declare -a PORTS=("3005" "3006" "3007")
fi
# Send SIGUSR2 to release maintenance mode
release_maintenance() {
    echo "Releasing maintenance mode..."
    for PORT in "${PORTS[@]}"; do
        sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service"
    done
}
# Locate the single <version>.db in the active dir; its basename is the
# current schema version. nullglob makes a no-match glob expand to nothing.
shopt -s nullglob
DB_FILES=("$ACTIVE_DIR"/*.db)
DB_COUNT=${#DB_FILES[@]}
if [[ $DB_COUNT -gt 1 ]]; then
    echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required."
    exit 1
elif [[ $DB_COUNT -eq 0 ]]; then
    echo "Error: No .db file found in $ACTIVE_DIR."
    exit 1
fi
# Extract the filename without extension
DB_FILE="${DB_FILES[0]}"
DB_VER=$(basename "$DB_FILE" .db)
# Send SIGUSR1 to trigger maintenance mode
for PORT in "${PORTS[@]}"; do
    sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service"
done
# From here on, always release maintenance mode on exit (success or failure).
trap release_maintenance EXIT
# Function to check logs for success or failure
# NOTE(review): the `journalctl -f | while read` pipeline below runs the loop
# in a subshell, so `return` only terminates the subshell; the function's exit
# status is the pipeline's status. If neither message ever appears, this tails
# the journal forever — consider adding a timeout. Verify the failure path
# interacts as intended with `set -e` (a non-zero return aborts the script,
# and the trap then releases maintenance mode).
check_logs() {
    local port="$1"
    local service="$SERVICE_NAME@$port.service"
    echo "Waiting for $service to enter maintenance mode..."
    # Check the last few lines first in case the message already appeared
    if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then
        echo "$service successfully entered maintenance mode."
        return 0
    elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then
        echo "Error: $service failed to enter maintenance mode."
        return 1
    fi
    # If not found, continuously watch logs until we get a success or failure message
    sudo journalctl -u "$service" -f --no-pager | while read -r line; do
        if echo "$line" | grep -q "Global database lock acquired"; then
            echo "$service successfully entered maintenance mode."
            pkill -P $$ journalctl # Kill journalctl process once we have success
            return 0
        elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then
            echo "Error: $service failed to enter maintenance mode."
            pkill -P $$ journalctl # Kill journalctl process on failure
            return 1
        fi
    done
}
# Check logs for each service
for PORT in "${PORTS[@]}"; do
    check_logs "$PORT"
done
# Get current datetime in YYYY-MM-DD-HHMM format
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
# Define source and destination paths
SOURCE_DB="$DATA_DIR/$DB_VER.db"
BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db"
# Copy the database file
# Note: the backup is taken from DATA_DIR (the real file), not the
# active-dir entry whose name supplied DB_VER.
if [[ -f "$SOURCE_DB" ]]; then
    cp "$SOURCE_DB" "$BACKUP_DB"
    echo "Backup created: $BACKUP_DB"
else
    echo "Error: Source database file $SOURCE_DB not found."
    exit 1
fi

77
deploy/db/migrate.sh Executable file
View File

@@ -0,0 +1,77 @@
#!/bin/bash
# migrate.sh <environment> <version> <commit-hash>
#
# Backs up the current database (via backup.sh), then — if the target schema
# version differs from the current one — runs the prmigrate binary for the
# given commit against a copy of the backup and links the migrated file into
# the active directory for the subsequent deploy.

# --- argument validation ----------------------------------------------------
if [[ -z "$1" ]]; then
    echo "Usage: $0 <environment> <version> <commit-hash>"
    exit 1
fi
ENVR="$1"
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
    echo "Error: environment must be 'production' or 'staging'."
    exit 1
fi
if [[ -z "$2" ]]; then
    echo "Usage: $0 <environment> <version> <commit-hash>"
    exit 1
fi
TGT_VER="$2"
re='^[0-9]+$'
if ! [[ $TGT_VER =~ $re ]] ; then
    echo "Error: version not a number" >&2
    exit 1
fi
if [ -z "$3" ]; then
    echo "Usage: $0 <environment> <version> <commit-hash>"
    exit 1
fi
COMMIT_HASH=$3
MIGRATION_BIN="/home/deploy/migration-bin"

# --- backup -----------------------------------------------------------------
# Capture backup.sh's own exit status immediately. The previous version
# checked $? after the echo, so it always saw the echo's status (0) and
# could never detect a failed backup.
BACKUP_OUTPUT=$(/bin/bash "${MIGRATION_BIN}/backup.sh" "$ENVR" 2>&1)
BACKUP_STATUS=$?
echo "$BACKUP_OUTPUT"
if [[ $BACKUP_STATUS -ne 0 ]]; then
    exit 1
fi
BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*')
if [[ -z "$BACKUP_FILE" ]]; then
    echo "Error: backup failed"
    exit 1
fi

# Backup filename is <version>-<timestamp>.db; the current schema version is
# the part before the first dash.
FILE_NAME=${BACKUP_FILE##*/}
CUR_VER=${FILE_NAME%%-*}

# Compare versions numerically, forcing base 10 so zero-padded versions such
# as "00009" are not parsed as invalid octal. The previous `[[ x > y ]]`
# comparison was lexicographic, so e.g. 10 compared "less than" 9.
TGT_NUM=$((10#$TGT_VER))
CUR_NUM=$((10#$CUR_VER))
if [[ $TGT_NUM -eq $CUR_NUM ]]; then
    echo "Version same, skipping migration"
    exit 0
elif [[ $TGT_NUM -gt $CUR_NUM ]]; then
    CMD="up-to"
else
    CMD="down-to"
fi

TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
ACTIVE_DIR="/home/deploy/$ENVR"
DATA_DIR="/home/deploy/data/$ENVR"
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db"
UPDATED_COPY="$DATA_DIR/${TGT_VER}.db"
UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db"

# Migrate a copy of the backup so the pre-migration backup itself is kept.
cp "$BACKUP_FILE" "$UPDATED_BACKUP"
failed_cleanup() {
    rm -f "$UPDATED_BACKUP"
}
# Remove the working copy if anything below fails.
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT

echo "Migration in progress from $CUR_VER to $TGT_VER"
"${MIGRATION_BIN}/prmigrate-$COMMIT_HASH" "$UPDATED_BACKUP" "$CMD" "$TGT_VER"
if [ $? -ne 0 ]; then
    echo "Migration failed"
    exit 1
fi
echo "Migration completed"

# Promote the migrated file: copy into the data dir and link it into the
# active dir where deploy.sh expects <version>.db to live.
cp "$UPDATED_BACKUP" "$UPDATED_COPY"
ln -s "$UPDATED_COPY" "$UPDATED_LINK"
echo "Upgraded database linked and ready for deploy"
exit 0

27
deploy/db/migrationcleanup.sh Executable file
View File

@@ -0,0 +1,27 @@
#!/bin/bash
# migrationcleanup.sh <environment> <version>
# Removes stale database symlinks from the active directory, keeping only the
# link for the given (current) schema version.
# Exit on error
set -e

if [[ -z "$1" ]]; then
    echo "Usage: $0 <environment> <version>"
    exit 1
fi
ENVR="$1"
case "$ENVR" in
    production|staging)
        ;;
    *)
        echo "Error: environment must be 'production' or 'staging'."
        exit 1
        ;;
esac

if [[ -z "$2" ]]; then
    echo "Usage: $0 <environment> <version>"
    exit 1
fi
TGT_VER="$2"
if ! [[ "$TGT_VER" =~ ^[0-9]+$ ]]; then
    echo "Error: version not a number" >&2
    exit 1
fi

ACTIVE_DIR="/home/deploy/$ENVR"
# Delete every *.db symlink except the one for the target version.
find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} +

113
deploy/deploy.sh Normal file
View File

@@ -0,0 +1,113 @@
#!/bin/bash
# deploy.sh <commit-hash> <environment>
#
# Promotes a previously-uploaded release binary: asks the binary which schema
# version it needs, runs the database migration, re-points the deploy symlink,
# rolling-restarts each service instance (rolling back the symlink and
# restarting on failure), then cleans up stale database links.
# Exit on error
set -e
# Check if commit hash is passed as an argument
# NOTE(review): the usage text omits the required <environment> second arg.
if [ -z "$1" ]; then
    echo "Usage: $0 <commit-hash>"
    exit 1
fi
COMMIT_HASH=$1
ENVR="$2"
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
    echo "Error: environment must be 'production' or 'staging'."
    exit 1
fi
RELEASES_DIR="/home/deploy/releases/$ENVR"
DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot"
MIGRATION_BIN="/home/deploy/migration-bin"
BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}"
# NOTE(review): this declaration is dead code — it is unconditionally
# overwritten by the if/else below.
declare -a PORTS=("3000" "3001" "3002")
if [[ "$ENVR" == "production" ]]; then
    SERVICE_NAME="projectreshoot"
    declare -a PORTS=("3000" "3001" "3002")
else
    SERVICE_NAME="$ENVR.projectreshoot"
    declare -a PORTS=("3005" "3006" "3007")
fi
# Check if the binary exists
if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then
    echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}"
    exit 1
fi
# Ask the new binary which schema version it requires, then migrate to it.
DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*')
${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH
# NOTE(review): with `set -e`, a failing migrate.sh exits the script before
# this check runs — the block below is unreachable dead code.
if [[ $? -ne 0 ]]; then
    echo "Migration failed"
    exit 1
fi
# Keep a reference to the previous binary from the symlink
if [ -L "${DEPLOY_BIN}" ]; then
    PREVIOUS=$(readlink -f $DEPLOY_BIN)
    echo "Current binary is ${PREVIOUS}, saved for rollback."
else
    echo "No symbolic link found, no previous binary to backup."
    PREVIOUS=""
fi
# Re-point the symlink at the previous binary and restart every instance.
rollback_deployment() {
    if [ -n "$PREVIOUS" ]; then
        echo "Rolling back to previous binary: ${PREVIOUS}"
        ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}"
    else
        echo "No previous binary to roll back to."
    fi
    # wait to restart the services
    sleep 10
    # Restart all services with the previous binary
    for port in "${PORTS[@]}"; do
        SERVICE="${SERVICE_NAME}@${port}.service"
        echo "Restarting $SERVICE..."
        sudo systemctl restart $SERVICE
    done
    echo "Rollback completed."
}
# Copy the binary to the deployment directory
echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..."
ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}"
# Seconds to wait after a restart before health-checking the unit.
WAIT_TIME=5
# Restart one instance; on restart or health-check failure, roll back the
# whole deployment and exit.
restart_service() {
    local port=$1
    local SERVICE="${SERVICE_NAME}@${port}.service"
    echo "Restarting ${SERVICE}..."
    # Restart the service
    if ! sudo systemctl restart "$SERVICE"; then
        echo "Error: Failed to restart ${SERVICE}. Rolling back deployment."
        # Call the rollback function
        rollback_deployment
        exit 1
    fi
    # Wait a few seconds to allow the service to fully start
    echo "Waiting for ${SERVICE} to fully start..."
    sleep $WAIT_TIME
    # Check the status of the service
    if ! systemctl is-active --quiet "${SERVICE}"; then
        echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment."
        # Call the rollback function
        rollback_deployment
        exit 1
    fi
    echo "${SERVICE}.service restarted successfully."
}
for port in "${PORTS[@]}"; do
    restart_service $port
done
echo "Deployment completed successfully."
# Remove old <version>.db symlinks now that the new version is live.
${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER

6
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/pressly/goose/v3 v3.24.1
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.33.0
@@ -19,10 +20,13 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect

12
go.sum
View File

@@ -28,23 +28,31 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=

View File

@@ -94,6 +94,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "server.GetConfig")
}
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Printf("Database version: %s\n", config.DBName)
return nil
}
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = logging.GetLogFile(config.LogDir)
@@ -186,6 +192,7 @@ func main() {
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run test function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
flag.Parse()
@@ -195,6 +202,7 @@ func main() {
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,
}

View File

@@ -17,16 +17,16 @@ import (
)
func TestAuthenticationMiddleware(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
@@ -14,16 +15,16 @@ import (
)
func TestPageLoginRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
@@ -14,16 +15,16 @@ import (
)
func TestReauthRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

69
migrate/migrate.go Normal file
View File

@@ -0,0 +1,69 @@
package main
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
"os"
"strconv"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
//go:embed migrations
var migrationsFS embed.FS
// main is the prmigrate entry point: it applies the goose migrations embedded
// in the binary to an existing SQLite database file, up or down to a target
// version.
//
// Usage: prmigrate <file_path> up-to|down-to <version>
func main() {
	args := os.Args
	if len(args) != 4 {
		fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
		os.Exit(1)
	}
	filePath, direction, versionStr := args[1], args[2], args[3]
	version, err := strconv.Atoi(versionStr)
	if err != nil {
		log.Fatalf("Invalid version number: %v", err)
	}
	// Refuse to run against a missing file — sql.Open would silently create it.
	if _, err := os.Stat(filePath); os.IsNotExist(err) {
		log.Fatalf("Database file does not exist: %v", filePath)
	}
	db, err := sql.Open("sqlite", filePath)
	if err != nil {
		log.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()
	migrations, err := fs.Sub(migrationsFS, "migrations")
	if err != nil {
		log.Fatalf("Failed to get migrations from embedded filesystem")
	}
	provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
	if err != nil {
		log.Fatalf("Failed to create migration provider: %v", err)
	}
	ctx := context.Background()
	switch direction {
	case "up-to":
		_, err = provider.UpTo(ctx, int64(version))
	case "down-to":
		_, err = provider.DownTo(ctx, int64(version))
	default:
		log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
	}
	if err != nil {
		log.Fatalf("Migration failed: %v", err)
	}
	fmt.Println("Migration successful!")
}

View File

@@ -1,5 +1,6 @@
-- +goose Up
-- +goose StatementBegin
PRAGMA foreign_keys=ON;
BEGIN TRANSACTION;
CREATE TABLE IF NOT EXISTS jwtblacklist (
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
exp INTEGER NOT NULL
@@ -16,4 +17,11 @@ AFTER INSERT ON jwtblacklist
BEGIN
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
END;
COMMIT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
DROP TABLE IF EXISTS jwtblacklist;
DROP TABLE IF EXISTS users;
-- +goose StatementEnd

View File

@@ -1,63 +1,83 @@
package tests
import (
"context"
"database/sql"
"fmt"
"io/fs"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
func findSQLFile(filename string) (string, error) {
func findMigrations() (*fs.FS, error) {
dir, err := os.Getwd()
if err != nil {
return nil, err
}
for {
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
return &migrationsdir, nil
}
parent := filepath.Dir(dir)
if parent == dir { // Reached root
return nil, errors.New("Unable to locate migrations directory")
}
dir = parent
}
}
func findTestData() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, err := os.Stat(filepath.Join(dir, filename)); err == nil {
return filepath.Join(dir, filename), nil
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
return filepath.Join(dir, "tests", "testdata.sql"), nil
}
parent := filepath.Dir(dir)
if parent == dir { // Reached root
return "", errors.New(fmt.Sprintf("Unable to locate %s", filename))
return "", errors.New("Unable to locate test data")
}
dir = parent
}
}
// SetupTestDB initializes a test SQLite database with mock data
func SetupTestDB() (*sql.DB, error) {
func SetupTestDB(version int64) (*sql.DB, error) {
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
// Setup the test database
schemaPath, err := findSQLFile("schema.sql")
migrations, err := findMigrations()
if err != nil {
return nil, errors.Wrap(err, "findSchema")
return nil, errors.Wrap(err, "findMigrations")
}
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
if err != nil {
return nil, errors.Wrap(err, "goose.NewProvider")
}
ctx := context.Background()
if _, err := provider.UpTo(ctx, version); err != nil {
return nil, errors.Wrap(err, "provider.UpTo")
}
sqlBytes, err := os.ReadFile(schemaPath)
if err != nil {
return nil, errors.Wrap(err, "os.ReadFile")
}
schemaSQL := string(sqlBytes)
_, err = conn.Exec(schemaSQL)
if err != nil {
return nil, errors.Wrap(err, "tx.Exec")
}
// NOTE: ==================================================
// Load the test data
dataPath, err := findSQLFile("testdata.sql")
dataPath, err := findTestData()
if err != nil {
return nil, errors.Wrap(err, "findSchema")
}
sqlBytes, err = os.ReadFile(dataPath)
sqlBytes, err := os.ReadFile(dataPath)
if err != nil {
return nil, errors.Wrap(err, "os.ReadFile")
}