diff --git a/.github/workflows/deploy_staging.yaml b/.github/workflows/deploy_staging.yaml index aab9831..da7556f 100644 --- a/.github/workflows/deploy_staging.yaml +++ b/.github/workflows/deploy_staging.yaml @@ -33,11 +33,15 @@ jobs: - name: Build the binary run: make build SUFFIX=-staging-$GITHUB_SHA + - name: Build the migration binary + run: make migrate SUFFIX=-staging-$GITHUB_SHA + - name: Deploy to Server env: USER: deploy HOST: projectreshoot.com DIR: /home/deploy/releases/staging + MIG_DIR: /home/deploy/migration-bin DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }} run: | mkdir -p ~/.ssh @@ -49,7 +53,13 @@ jobs: echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR - scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR - ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA + ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR + scp -i ~/.ssh/id_ed25519 prmigrate-${GITHUB_SHA} $USER@$HOST:$MIG_DIR + + scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging diff --git a/.gitignore b/.gitignore index 5b22f92..9e3f92b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ query.sql .logs/ server.log tmp/ +prmigrate projectreshoot static/css/output.css view/**/*_templ.go diff --git a/Makefile b/Makefile index e23833d..54cdec2 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ # Makefile .PHONY: build +.PHONY: migrate BINARY_NAME=projectreshoot @@ -29,3 +30,8 @@ test: clean: go clean + +migrate: + go mod tidy && \ + go generate && \ + go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate diff --git a/config/config.go b/config/config.go index 66a17a8..3358dfd 100644 --- a/config/config.go +++ b/config/config.go @@ -21,7 +21,7 @@ type Config struct { ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds WriteTimeout time.Duration // Timeout for writing requests in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds - DBName string // Filename of the db (doesnt include file extension) + DBName string // Filename of the db - hardcoded and doubles as DB version DBLockTimeout time.Duration // Timeout for acquiring database lock SecretKey string // Secret key for signing tokens AccessTokenExpiry int64 // Access token expiry in minutes @@ -87,7 +87,7 @@ func GetConfig(args map[string]string) (*Config, error) { ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2), WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), - DBName: GetEnvDefault("DB_NAME", "projectreshoot"), + DBName: "00001", DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60), SecretKey: os.Getenv("SECRET_KEY"), AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), diff --git a/db/connection.go b/db/connection.go index 78e4071..7343ee5 100644 --- a/db/connection.go +++ b/db/connection.go @@ -1,10 +1,9 @@ package db import ( - "context" "database/sql" "fmt" - "time" + "strconv" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -12,179 +11,47 @@ import ( _ "modernc.org/sqlite" ) -type SafeConn struct { - db *sql.DB - readLockCount uint32 - globalLockStatus uint32 - globalLockRequested uint32 - logger *zerolog.Logger -} - -func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { - return &SafeConn{db: db, logger: logger} -} - -// Extends sql.Tx for use with SafeConn -type SafeTX struct { - tx *sql.Tx - sc *SafeConn -} - -func (conn *SafeConn) acquireGlobalLock() bool { - if conn.readLockCount > 0 || conn.globalLockStatus == 1 { - return false - } - conn.globalLockStatus = 1 - conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). - Msg("Global lock acquired") - return true -} - -func (conn *SafeConn) releaseGlobalLock() { - conn.globalLockStatus = 0 - conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). - Msg("Global lock released") -} - -func (conn *SafeConn) acquireReadLock() bool { - if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { - return false - } - conn.readLockCount += 1 - conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). - Msg("Read lock acquired") - return true -} - -func (conn *SafeConn) releaseReadLock() { - conn.readLockCount -= 1 - conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). - Msg("Read lock released") -} - -// Starts a new transaction based on the current context. Will cancel if -// the context is closed/cancelled/done -func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { - lockAcquired := make(chan struct{}) - lockCtx, cancel := context.WithCancel(ctx) - defer cancel() - - go func() { - select { - case <-lockCtx.Done(): - return - default: - if conn.acquireReadLock() { - close(lockAcquired) - } - } - }() - - select { - case <-lockAcquired: - tx, err := conn.db.BeginTx(ctx, nil) - if err != nil { - conn.releaseReadLock() - return nil, err - } - return &SafeTX{tx: tx, sc: conn}, nil - case <-ctx.Done(): - cancel() - return nil, errors.New("Transaction time out due to database lock") - } -} - -// Query the database inside the transaction -func (stx *SafeTX) Query( - ctx context.Context, - query string, - args ...interface{}, -) (*sql.Rows, error) { - if stx.tx == nil { - return nil, errors.New("Cannot query without a transaction") - } - return stx.tx.QueryContext(ctx, query, args...) -} - -// Exec a statement on the database inside the transaction -func (stx *SafeTX) Exec( - ctx context.Context, - query string, - args ...interface{}, -) (sql.Result, error) { - if stx.tx == nil { - return nil, errors.New("Cannot exec without a transaction") - } - return stx.tx.ExecContext(ctx, query, args...) -} - -// Commit the current transaction and release the read lock -func (stx *SafeTX) Commit() error { - if stx.tx == nil { - return errors.New("Cannot commit without a transaction") - } - err := stx.tx.Commit() - stx.tx = nil - - stx.sc.releaseReadLock() - return err -} - -// Abort the current transaction, releasing the read lock -func (stx *SafeTX) Rollback() error { - if stx.tx == nil { - return errors.New("Cannot rollback without a transaction") - } - err := stx.tx.Rollback() - stx.tx = nil - stx.sc.releaseReadLock() - return err -} - -// Acquire a global lock, preventing all transactions -func (conn *SafeConn) Pause(timeoutAfter time.Duration) { - conn.logger.Info().Msg("Attempting to acquire global database lock") - conn.globalLockRequested = 1 - defer func() { conn.globalLockRequested = 0 }() - timeout := time.After(timeoutAfter) - attempt := 0 - for { - if conn.acquireGlobalLock() { - conn.logger.Info().Msg("Global database lock acquired") - return - } - select { - case <-timeout: - conn.logger.Info().Msg("Timeout: Global database lock abandoned") - return - case <-time.After(100 * time.Millisecond): - attempt++ - } - } -} - -// Release the global lock -func (conn *SafeConn) Resume() { - conn.releaseGlobalLock() - conn.logger.Info().Msg("Global database lock released") -} - -// Close the database connection -func (conn *SafeConn) Close() error { - conn.logger.Debug().Msg("Acquiring global lock for connection close") - conn.acquireGlobalLock() - defer conn.releaseGlobalLock() - conn.logger.Debug().Msg("Closing database connection") - return conn.db.Close() -} - // Returns a database connection handle for the DB -func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) { +func ConnectToDatabase( + dbName string, + logger *zerolog.Logger, +) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } + version, err := strconv.Atoi(dbName) + if err != nil { + return nil, errors.Wrap(err, "strconv.Atoi") + } + err = checkDBVersion(db, version) + if err != nil { + return nil, errors.Wrap(err, "checkDBVersion") + } conn := MakeSafe(db, logger) return conn, nil } + +// Check the database version +func checkDBVersion(db *sql.DB, expectVer int) error { + query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1 + ORDER BY version_id DESC LIMIT 1` + rows, err := db.Query(query) + if err != nil { + return errors.Wrap(err, "checkDBVersion") + } + if rows.Next() { + var version int + err = rows.Scan(&version) + if err != nil { + return errors.Wrap(err, "rows.Scan") + } + if version != expectVer { + return errors.New("Version mismatch") + } + } else { + return errors.New("No version found") + } + return nil +} diff --git a/db/safeconn.go b/db/safeconn.go new file mode 100644 index 0000000..4d66f70 --- /dev/null +++ b/db/safeconn.go @@ -0,0 +1,129 @@ +package db + +import ( + "context" + "database/sql" + "time" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +type SafeConn struct { + db *sql.DB + readLockCount uint32 + globalLockStatus uint32 + globalLockRequested uint32 + logger *zerolog.Logger +} + +// Make the provided db handle safe and attach a logger to it +func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { + return &SafeConn{db: db, logger: logger} +} + +// Attempts to acquire a global lock on the database connection +func (conn *SafeConn) acquireGlobalLock() bool { + if conn.readLockCount > 0 || conn.globalLockStatus == 1 { + return false + } + conn.globalLockStatus = 1 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock acquired") + return true +} + +// Releases a global lock on the database connection +func (conn *SafeConn) releaseGlobalLock() { + conn.globalLockStatus = 0 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock released") +} + +// Acquire a read lock on the connection. Multiple read locks can be acquired +// at the same time +func (conn *SafeConn) acquireReadLock() bool { + if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { + return false + } + conn.readLockCount += 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock acquired") + return true +} + +// Release a read lock. Decrements read lock count by 1 +func (conn *SafeConn) releaseReadLock() { + conn.readLockCount -= 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock released") +} + +// Starts a new transaction based on the current context. Will cancel if +// the context is closed/cancelled/done +func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { + lockAcquired := make(chan struct{}) + lockCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + select { + case <-lockCtx.Done(): + return + default: + if conn.acquireReadLock() { + close(lockAcquired) + } + } + }() + + select { + case <-lockAcquired: + tx, err := conn.db.BeginTx(ctx, nil) + if err != nil { + conn.releaseReadLock() + return nil, err + } + return &SafeTX{tx: tx, sc: conn}, nil + case <-ctx.Done(): + cancel() + return nil, errors.New("Transaction time out due to database lock") + } +} + +// Acquire a global lock, preventing all transactions +func (conn *SafeConn) Pause(timeoutAfter time.Duration) { + conn.logger.Info().Msg("Attempting to acquire global database lock") + conn.globalLockRequested = 1 + defer func() { conn.globalLockRequested = 0 }() + timeout := time.After(timeoutAfter) + attempt := 0 + for { + if conn.acquireGlobalLock() { + conn.logger.Info().Msg("Global database lock acquired") + return + } + select { + case <-timeout: + conn.logger.Info().Msg("Timeout: Global database lock abandoned") + return + case <-time.After(100 * time.Millisecond): + attempt++ + } + } +} + +// Release the global lock +func (conn *SafeConn) Resume() { + conn.releaseGlobalLock() + conn.logger.Info().Msg("Global database lock released") +} + +// Close the database connection +func (conn *SafeConn) Close() error { + conn.logger.Debug().Msg("Acquiring global lock for connection close") + conn.acquireGlobalLock() + defer conn.releaseGlobalLock() + conn.logger.Debug().Msg("Closing database connection") + return conn.db.Close() +} diff --git a/db/connection_test.go b/db/safeconntx_test.go similarity index 91% rename from db/connection_test.go rename to db/safeconntx_test.go index 9723526..b1816d7 100644 --- a/db/connection_test.go +++ b/db/safeconntx_test.go @@ -3,6 +3,7 @@ package db import ( "context" "projectreshoot/tests" + "strconv" "sync" "testing" "time" @@ -12,8 +13,12 @@ import ( ) func TestSafeConn(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := MakeSafe(conn, logger) defer sconn.Close() @@ -77,8 +82,12 @@ func TestSafeConn(t *testing.T) { }) } func TestSafeTX(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := MakeSafe(conn, logger) defer sconn.Close() diff --git a/db/safetx.go b/db/safetx.go new file mode 100644 index 0000000..5a3f06e --- /dev/null +++ b/db/safetx.go @@ -0,0 +1,61 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" +) + +// Extends sql.Tx for use with SafeConn +type SafeTX struct { + tx *sql.Tx + sc *SafeConn +} + +// Query the database inside the transaction +func (stx *SafeTX) Query( + ctx context.Context, + query string, + args ...interface{}, +) (*sql.Rows, error) { + if stx.tx == nil { + return nil, errors.New("Cannot query without a transaction") + } + return stx.tx.QueryContext(ctx, query, args...) +} + +// Exec a statement on the database inside the transaction +func (stx *SafeTX) Exec( + ctx context.Context, + query string, + args ...interface{}, +) (sql.Result, error) { + if stx.tx == nil { + return nil, errors.New("Cannot exec without a transaction") + } + return stx.tx.ExecContext(ctx, query, args...) +} + +// Commit the current transaction and release the read lock +func (stx *SafeTX) Commit() error { + if stx.tx == nil { + return errors.New("Cannot commit without a transaction") + } + err := stx.tx.Commit() + stx.tx = nil + + stx.sc.releaseReadLock() + return err +} + +// Abort the current transaction, releasing the read lock +func (stx *SafeTX) Rollback() error { + if stx.tx == nil { + return errors.New("Cannot rollback without a transaction") + } + err := stx.tx.Rollback() + stx.tx = nil + stx.sc.releaseReadLock() + return err +} diff --git a/deploy/db/backup.sh b/deploy/db/backup.sh new file mode 100755 index 0000000..ea89236 --- /dev/null +++ b/deploy/db/backup.sh @@ -0,0 +1,108 @@ +#!/bin/bash + +# Exit on error +set -e + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi + +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +ACTIVE_DIR="/home/deploy/$ENVR" +DATA_DIR="/home/deploy/data/$ENVR" +BACKUP_DIR="/home/deploy/data/backups/$ENVR" +if [[ "$ENVR" == "production" ]]; then + SERVICE_NAME="projectreshoot" + declare -a PORTS=("3000" "3001" "3002") +else + SERVICE_NAME="$ENVR.projectreshoot" + declare -a PORTS=("3005" "3006" "3007") +fi + +# Send SIGUSR2 to release maintenance mode +release_maintenance() { + echo "Releasing maintenance mode..." + for PORT in "${PORTS[@]}"; do + sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service" + done +} + +shopt -s nullglob +DB_FILES=("$ACTIVE_DIR"/*.db) +DB_COUNT=${#DB_FILES[@]} + +if [[ $DB_COUNT -gt 1 ]]; then + echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required." + exit 1 +elif [[ $DB_COUNT -eq 0 ]]; then + echo "Error: No .db file found in $ACTIVE_DIR." + exit 1 +fi + +# Extract the filename without extension +DB_FILE="${DB_FILES[0]}" +DB_VER=$(basename "$DB_FILE" .db) + +# Send SIGUSR1 to trigger maintenance mode +for PORT in "${PORTS[@]}"; do + sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service" +done +trap release_maintenance EXIT + +# Function to check logs for success or failure +check_logs() { + local port="$1" + local service="$SERVICE_NAME@$port.service" + + echo "Waiting for $service to enter maintenance mode..." + + # Check the last few lines first in case the message already appeared + if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then + echo "$service successfully entered maintenance mode." + return 0 + elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then + echo "Error: $service failed to enter maintenance mode." + return 1 + fi + + # If not found, continuously watch logs until we get a success or failure message + sudo journalctl -u "$service" -f --no-pager | while read -r line; do + if echo "$line" | grep -q "Global database lock acquired"; then + echo "$service successfully entered maintenance mode." + pkill -P $$ journalctl # Kill journalctl process once we have success + return 0 + elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then + echo "Error: $service failed to enter maintenance mode." + pkill -P $$ journalctl # Kill journalctl process on failure + return 1 + fi + done +} + +# Check logs for each service +for PORT in "${PORTS[@]}"; do + check_logs "$PORT" +done + +# Get current datetime in YYYY-MM-DD-HHMM format +TIMESTAMP=$(date +"%Y-%m-%d-%H%M") + +# Define source and destination paths +SOURCE_DB="$DATA_DIR/$DB_VER.db" +BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db" + +# Copy the database file +if [[ -f "$SOURCE_DB" ]]; then + cp "$SOURCE_DB" "$BACKUP_DB" + echo "Backup created: $BACKUP_DB" +else + echo "Error: Source database file $SOURCE_DB not found." + exit 1 +fi + + diff --git a/deploy/db/migrate.sh b/deploy/db/migrate.sh new file mode 100755 index 0000000..436fde7 --- /dev/null +++ b/deploy/db/migrate.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +if [[ -z "$2" ]]; then + echo "Usage: $0 " + exit 1 +fi +TGT_VER="$2" +re='^[0-9]+$' +if ! [[ $TGT_VER =~ $re ]] ; then + echo "Error: version not a number" >&2 + exit 1 +fi +if [ -z "$3" ]; then + echo "Usage: $0 " + exit 1 +fi +COMMIT_HASH=$3 +MIGRATION_BIN="/home/deploy/migration-bin" +BACKUP_OUTPUT=$(/bin/bash ${MIGRATION_BIN}/backup.sh "$ENVR" 2>&1) +echo "$BACKUP_OUTPUT" +if [[ $? -ne 0 ]]; then + exit 1 +fi +BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*') +if [[ -z "$BACKUP_FILE" ]]; then + echo "Error: backup failed" + exit 1 +fi + +FILE_NAME=${BACKUP_FILE##*/} +CUR_VER=${FILE_NAME%%-*} +if [[ $((+$TGT_VER)) == $((+$CUR_VER)) ]]; then + echo "Version same, skipping migration" + exit 0 +fi +if [[ $((+$TGT_VER)) > $((+$CUR_VER)) ]]; then + CMD="up-to" +fi +if [[ $((+$TGT_VER)) < $((+$CUR_VER)) ]]; then + CMD="down-to" +fi +TIMESTAMP=$(date +"%Y-%m-%d-%H%M") + +ACTIVE_DIR="/home/deploy/$ENVR" +DATA_DIR="/home/deploy/data/$ENVR" +BACKUP_DIR="/home/deploy/data/backups/$ENVR" +UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db" +UPDATED_COPY="$DATA_DIR/${TGT_VER}.db" +UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db" + +cp $BACKUP_FILE $UPDATED_BACKUP +failed_cleanup() { + rm $UPDATED_BACKUP +} +trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT + +echo "Migration in progress from $CUR_VER to $TGT_VER" +${MIGRATION_BIN}/prmigrate-$COMMIT_HASH $UPDATED_BACKUP $CMD $TGT_VER +if [ $? -ne 0 ]; then + echo "Migration failed" + exit 1 +fi +echo "Migration completed" + +cp $UPDATED_BACKUP $UPDATED_COPY +ln -s $UPDATED_COPY $UPDATED_LINK +echo "Upgraded database linked and ready for deploy" +exit 0 diff --git a/deploy/db/migrationcleanup.sh b/deploy/db/migrationcleanup.sh new file mode 100755 index 0000000..9de3c5a --- /dev/null +++ b/deploy/db/migrationcleanup.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Exit on error +set -e + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi + +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +if [[ -z "$2" ]]; then + echo "Usage: $0 " + exit 1 +fi +TGT_VER="$2" +re='^[0-9]+$' +if ! [[ $TGT_VER =~ $re ]] ; then + echo "Error: version not a number" >&2 + exit 1 +fi +ACTIVE_DIR="/home/deploy/$ENVR" +find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} + diff --git a/deploy/deploy.sh b/deploy/deploy.sh new file mode 100644 index 0000000..dd198e6 --- /dev/null +++ b/deploy/deploy.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# Exit on error +set -e + +# Check if commit hash is passed as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi +COMMIT_HASH=$1 +ENVR="$2" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi + +RELEASES_DIR="/home/deploy/releases/$ENVR" +DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot" +MIGRATION_BIN="/home/deploy/migration-bin" +BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}" +declare -a PORTS=("3000" "3001" "3002") +if [[ "$ENVR" == "production" ]]; then + SERVICE_NAME="projectreshoot" + declare -a PORTS=("3000" "3001" "3002") +else + SERVICE_NAME="$ENVR.projectreshoot" + declare -a PORTS=("3005" "3006" "3007") +fi + +# Check if the binary exists +if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then + echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}" + exit 1 +fi +DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*') +${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH +if [[ $? -ne 0 ]]; then + echo "Migration failed" + exit 1 +fi + +# Keep a reference to the previous binary from the symlink +if [ -L "${DEPLOY_BIN}" ]; then + PREVIOUS=$(readlink -f $DEPLOY_BIN) + echo "Current binary is ${PREVIOUS}, saved for rollback." +else + echo "No symbolic link found, no previous binary to backup." + PREVIOUS="" +fi + +rollback_deployment() { + if [ -n "$PREVIOUS" ]; then + echo "Rolling back to previous binary: ${PREVIOUS}" + ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}" + else + echo "No previous binary to roll back to." + fi + + # wait to restart the services + sleep 10 + + # Restart all services with the previous binary + for port in "${PORTS[@]}"; do + SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting $SERVICE..." + sudo systemctl restart $SERVICE + done + + echo "Rollback completed." +} + +# Copy the binary to the deployment directory +echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..." +ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}" + +WAIT_TIME=5 +restart_service() { + local port=$1 + local SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting ${SERVICE}..." + + # Restart the service + if ! sudo systemctl restart "$SERVICE"; then + echo "Error: Failed to restart ${SERVICE}. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + # Wait a few seconds to allow the service to fully start + echo "Waiting for ${SERVICE} to fully start..." + sleep $WAIT_TIME + + # Check the status of the service + if ! systemctl is-active --quiet "${SERVICE}"; then + echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + echo "${SERVICE}.service restarted successfully." +} + +for port in "${PORTS[@]}"; do + restart_service $port +done + +echo "Deployment completed successfully." +${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER diff --git a/go.mod b/go.mod index 0852a78..16638f8 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 + github.com/pressly/goose/v3 v3.24.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 @@ -19,10 +20,13 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index 9206c36..b4bb05c 100644 --- a/go.sum +++ b/go.sum @@ -28,23 +28,31 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY= +github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= -golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= diff --git a/main.go b/main.go index ede093c..a68fb5e 100644 --- a/main.go +++ b/main.go @@ -94,6 +94,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "server.GetConfig") } + // Return the version of the database required + if args["dbver"] == "true" { + fmt.Printf("Database version: %s\n", config.DBName) + return nil + } + var logfile *os.File = nil if config.LogOutput == "both" || config.LogOutput == "file" { logfile, err = logging.GetLogFile(config.LogDir) @@ -186,6 +192,7 @@ func main() { host := flag.String("host", "", "Override host to listen on") port := flag.String("port", "", "Override port to listen on") test := flag.Bool("test", false, "Run test function instead of main program") + dbver := flag.Bool("dbver", false, "Get the version of the database required") loglevel := flag.String("loglevel", "", "Set log level") logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)") flag.Parse() @@ -195,6 +202,7 @@ func main() { "host": *host, "port": *port, "test": strconv.FormatBool(*test), + "dbver": strconv.FormatBool(*dbver), "loglevel": *loglevel, "logoutput": *logoutput, } diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 6ce807c..bb3dceb 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -17,16 +17,16 @@ import ( ) func TestAuthenticationMiddleware(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index c6efcba..6150f72 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" "sync/atomic" "testing" @@ -14,16 +15,16 @@ import ( ) func TestPageLoginRequired(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index a1f2083..bfb40e8 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" "sync/atomic" "testing" @@ -14,16 +15,16 @@ import ( ) func TestReauthRequired(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) logger := tests.NilLogger() - // Basic setup - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) require.NoError(t, err) sconn := db.MakeSafe(conn, logger) defer sconn.Close() - cfg, err := tests.TestConfig() - require.NoError(t, err) - // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 0000000..dd8a99e --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log" + "os" + "strconv" + + "github.com/pressly/goose/v3" + + _ "modernc.org/sqlite" +) + +//go:embed migrations +var migrationsFS embed.FS + +func main() { + if len(os.Args) != 4 { + fmt.Println("Usage: prmigrate up-to|down-to ") + os.Exit(1) + } + + filePath := os.Args[1] + direction := os.Args[2] + versionStr := os.Args[3] + + version, err := strconv.Atoi(versionStr) + if err != nil { + log.Fatalf("Invalid version number: %v", err) + } + if _, err := os.Stat(filePath); os.IsNotExist(err) { + log.Fatalf("Database file does not exist: %v", filePath) + } + db, err := sql.Open("sqlite", filePath) + if err != nil { + log.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + migrations, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + log.Fatalf("Failed to get migrations from embedded filesystem") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations) + if err != nil { + log.Fatalf("Failed to create migration provider: %v", err) + } + + ctx := context.Background() + + switch direction { + case "up-to": + _, err = provider.UpTo(ctx, int64(version)) + case "down-to": + _, err = provider.DownTo(ctx, int64(version)) + default: + log.Fatalf("Invalid direction: use 'up-to' or 'down-to'") + } + + if err != nil { + log.Fatalf("Migration failed: %v", err) + } + + fmt.Println("Migration successful!") +} diff --git a/schema.sql b/migrate/migrations/00001_init.sql similarity index 68% rename from schema.sql rename to migrate/migrations/00001_init.sql index 986d312..1ada972 100644 --- a/schema.sql +++ b/migrate/migrations/00001_init.sql @@ -1,5 +1,6 @@ +-- +goose Up +-- +goose StatementBegin PRAGMA foreign_keys=ON; -BEGIN TRANSACTION; CREATE TABLE IF NOT EXISTS jwtblacklist ( jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'), exp INTEGER NOT NULL @@ -16,4 +17,11 @@ AFTER INSERT ON jwtblacklist BEGIN DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now'); END; -COMMIT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TRIGGER IF EXISTS cleanup_expired_tokens; +DROP TABLE IF EXISTS jwtblacklist; +DROP TABLE IF EXISTS users; +-- +goose StatementEnd diff --git a/tests/database.go b/tests/database.go index 549db2b..0010636 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,63 +1,83 @@ package tests import ( + "context" "database/sql" - "fmt" + "io/fs" "os" "path/filepath" "github.com/pkg/errors" + "github.com/pressly/goose/v3" _ "modernc.org/sqlite" ) -func findSQLFile(filename string) (string, error) { +func findMigrations() (*fs.FS, error) { + dir, err := os.Getwd() + if err != nil { + return nil, err + } + + for { + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations")) + return &migrationsdir, nil + } + + parent := filepath.Dir(dir) + if parent == dir { // Reached root + return nil, errors.New("Unable to locate migrations directory") + } + dir = parent + } +} + +func findTestData() (string, error) { dir, err := os.Getwd() if err != nil { return "", err } for { - if _, err := os.Stat(filepath.Join(dir, filename)); err == nil { - return filepath.Join(dir, filename), nil + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + return filepath.Join(dir, "tests", "testdata.sql"), nil } parent := filepath.Dir(dir) if parent == dir { // Reached root - return "", errors.New(fmt.Sprintf("Unable to locate %s", filename)) + return "", errors.New("Unable to locate test data") } dir = parent } } -// SetupTestDB initializes a test SQLite database with mock data -func SetupTestDB() (*sql.DB, error) { +func SetupTestDB(version int64) (*sql.DB, error) { conn, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, errors.Wrap(err, "sql.Open") } - // Setup the test database - schemaPath, err := findSQLFile("schema.sql") + + migrations, err := findMigrations() if err != nil { - return nil, errors.Wrap(err, "findSchema") + return nil, errors.Wrap(err, "findMigrations") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations) + if err != nil { + return nil, errors.Wrap(err, "goose.NewProvider") + } + ctx := context.Background() + if _, err := provider.UpTo(ctx, version); err != nil { + return nil, errors.Wrap(err, "provider.UpTo") } - sqlBytes, err := os.ReadFile(schemaPath) - if err != nil { - return nil, errors.Wrap(err, "os.ReadFile") - } - schemaSQL := string(sqlBytes) - - _, err = conn.Exec(schemaSQL) - if err != nil { - return nil, errors.Wrap(err, "tx.Exec") - } + // NOTE: ================================================== // Load the test data - dataPath, err := findSQLFile("testdata.sql") + dataPath, err := findTestData() if err != nil { return nil, errors.Wrap(err, "findSchema") } - sqlBytes, err = os.ReadFile(dataPath) + sqlBytes, err := os.ReadFile(dataPath) if err != nil { return nil, errors.Wrap(err, "os.ReadFile") } diff --git a/testdata.sql b/tests/testdata.sql similarity index 100% rename from testdata.sql rename to tests/testdata.sql