From f0f285301abd10c11faf2e0e4f088eb4954d685e Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Thu, 20 Feb 2025 19:06:41 +1100 Subject: [PATCH] Moved SafeTx and SafeConn to their own files --- db/connection.go | 168 ----------------------------------------------- db/safeconn.go | 129 ++++++++++++++++++++++++++++++++++++ db/safetx.go | 61 +++++++++++++++++ 3 files changed, 190 insertions(+), 168 deletions(-) create mode 100644 db/safeconn.go create mode 100644 db/safetx.go diff --git a/db/connection.go b/db/connection.go index 78e4071..5458d85 100644 --- a/db/connection.go +++ b/db/connection.go @@ -1,10 +1,8 @@ package db import ( - "context" "database/sql" "fmt" - "time" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -12,172 +10,6 @@ import ( _ "modernc.org/sqlite" ) -type SafeConn struct { - db *sql.DB - readLockCount uint32 - globalLockStatus uint32 - globalLockRequested uint32 - logger *zerolog.Logger -} - -func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { - return &SafeConn{db: db, logger: logger} -} - -// Extends sql.Tx for use with SafeConn -type SafeTX struct { - tx *sql.Tx - sc *SafeConn -} - -func (conn *SafeConn) acquireGlobalLock() bool { - if conn.readLockCount > 0 || conn.globalLockStatus == 1 { - return false - } - conn.globalLockStatus = 1 - conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). - Msg("Global lock acquired") - return true -} - -func (conn *SafeConn) releaseGlobalLock() { - conn.globalLockStatus = 0 - conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). - Msg("Global lock released") -} - -func (conn *SafeConn) acquireReadLock() bool { - if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { - return false - } - conn.readLockCount += 1 - conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). - Msg("Read lock acquired") - return true -} - -func (conn *SafeConn) releaseReadLock() { - conn.readLockCount -= 1 - conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). - Msg("Read lock released") -} - -// Starts a new transaction based on the current context. Will cancel if -// the context is closed/cancelled/done -func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { - lockAcquired := make(chan struct{}) - lockCtx, cancel := context.WithCancel(ctx) - defer cancel() - - go func() { - select { - case <-lockCtx.Done(): - return - default: - if conn.acquireReadLock() { - close(lockAcquired) - } - } - }() - - select { - case <-lockAcquired: - tx, err := conn.db.BeginTx(ctx, nil) - if err != nil { - conn.releaseReadLock() - return nil, err - } - return &SafeTX{tx: tx, sc: conn}, nil - case <-ctx.Done(): - cancel() - return nil, errors.New("Transaction time out due to database lock") - } -} - -// Query the database inside the transaction -func (stx *SafeTX) Query( - ctx context.Context, - query string, - args ...interface{}, -) (*sql.Rows, error) { - if stx.tx == nil { - return nil, errors.New("Cannot query without a transaction") - } - return stx.tx.QueryContext(ctx, query, args...) -} - -// Exec a statement on the database inside the transaction -func (stx *SafeTX) Exec( - ctx context.Context, - query string, - args ...interface{}, -) (sql.Result, error) { - if stx.tx == nil { - return nil, errors.New("Cannot exec without a transaction") - } - return stx.tx.ExecContext(ctx, query, args...) -} - -// Commit the current transaction and release the read lock -func (stx *SafeTX) Commit() error { - if stx.tx == nil { - return errors.New("Cannot commit without a transaction") - } - err := stx.tx.Commit() - stx.tx = nil - - stx.sc.releaseReadLock() - return err -} - -// Abort the current transaction, releasing the read lock -func (stx *SafeTX) Rollback() error { - if stx.tx == nil { - return errors.New("Cannot rollback without a transaction") - } - err := stx.tx.Rollback() - stx.tx = nil - stx.sc.releaseReadLock() - return err -} - -// Acquire a global lock, preventing all transactions -func (conn *SafeConn) Pause(timeoutAfter time.Duration) { - conn.logger.Info().Msg("Attempting to acquire global database lock") - conn.globalLockRequested = 1 - defer func() { conn.globalLockRequested = 0 }() - timeout := time.After(timeoutAfter) - attempt := 0 - for { - if conn.acquireGlobalLock() { - conn.logger.Info().Msg("Global database lock acquired") - return - } - select { - case <-timeout: - conn.logger.Info().Msg("Timeout: Global database lock abandoned") - return - case <-time.After(100 * time.Millisecond): - attempt++ - } - } -} - -// Release the global lock -func (conn *SafeConn) Resume() { - conn.releaseGlobalLock() - conn.logger.Info().Msg("Global database lock released") -} - -// Close the database connection -func (conn *SafeConn) Close() error { - conn.logger.Debug().Msg("Acquiring global lock for connection close") - conn.acquireGlobalLock() - defer conn.releaseGlobalLock() - conn.logger.Debug().Msg("Closing database connection") - return conn.db.Close() -} - // Returns a database connection handle for the DB func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) diff --git a/db/safeconn.go b/db/safeconn.go new file mode 100644 index 0000000..4d66f70 --- /dev/null +++ b/db/safeconn.go @@ -0,0 +1,129 @@ +package db + +import ( + "context" + "database/sql" + "time" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +type SafeConn struct { + db *sql.DB + readLockCount uint32 + globalLockStatus uint32 + globalLockRequested uint32 + logger *zerolog.Logger +} + +// Make the provided db handle safe and attach a logger to it +func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { + return &SafeConn{db: db, logger: logger} +} + +// Attempts to acquire a global lock on the database connection +func (conn *SafeConn) acquireGlobalLock() bool { + if conn.readLockCount > 0 || conn.globalLockStatus == 1 { + return false + } + conn.globalLockStatus = 1 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock acquired") + return true +} + +// Releases a global lock on the database connection +func (conn *SafeConn) releaseGlobalLock() { + conn.globalLockStatus = 0 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock released") +} + +// Acquire a read lock on the connection. Multiple read locks can be acquired +// at the same time +func (conn *SafeConn) acquireReadLock() bool { + if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { + return false + } + conn.readLockCount += 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock acquired") + return true +} + +// Release a read lock. Decrements read lock count by 1 +func (conn *SafeConn) releaseReadLock() { + conn.readLockCount -= 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock released") +} + +// Starts a new transaction based on the current context. Will cancel if +// the context is closed/cancelled/done +func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { + lockAcquired := make(chan struct{}) + lockCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + select { + case <-lockCtx.Done(): + return + default: + if conn.acquireReadLock() { + close(lockAcquired) + } + } + }() + + select { + case <-lockAcquired: + tx, err := conn.db.BeginTx(ctx, nil) + if err != nil { + conn.releaseReadLock() + return nil, err + } + return &SafeTX{tx: tx, sc: conn}, nil + case <-ctx.Done(): + cancel() + return nil, errors.New("Transaction time out due to database lock") + } +} + +// Acquire a global lock, preventing all transactions +func (conn *SafeConn) Pause(timeoutAfter time.Duration) { + conn.logger.Info().Msg("Attempting to acquire global database lock") + conn.globalLockRequested = 1 + defer func() { conn.globalLockRequested = 0 }() + timeout := time.After(timeoutAfter) + attempt := 0 + for { + if conn.acquireGlobalLock() { + conn.logger.Info().Msg("Global database lock acquired") + return + } + select { + case <-timeout: + conn.logger.Info().Msg("Timeout: Global database lock abandoned") + return + case <-time.After(100 * time.Millisecond): + attempt++ + } + } +} + +// Release the global lock +func (conn *SafeConn) Resume() { + conn.releaseGlobalLock() + conn.logger.Info().Msg("Global database lock released") +} + +// Close the database connection +func (conn *SafeConn) Close() error { + conn.logger.Debug().Msg("Acquiring global lock for connection close") + conn.acquireGlobalLock() + defer conn.releaseGlobalLock() + conn.logger.Debug().Msg("Closing database connection") + return conn.db.Close() +} diff --git a/db/safetx.go b/db/safetx.go new file mode 100644 index 0000000..5a3f06e --- /dev/null +++ b/db/safetx.go @@ -0,0 +1,61 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" +) + +// Extends sql.Tx for use with SafeConn +type SafeTX struct { + tx *sql.Tx + sc *SafeConn +} + +// Query the database inside the transaction +func (stx *SafeTX) Query( + ctx context.Context, + query string, + args ...interface{}, +) (*sql.Rows, error) { + if stx.tx == nil { + return nil, errors.New("Cannot query without a transaction") + } + return stx.tx.QueryContext(ctx, query, args...) +} + +// Exec a statement on the database inside the transaction +func (stx *SafeTX) Exec( + ctx context.Context, + query string, + args ...interface{}, +) (sql.Result, error) { + if stx.tx == nil { + return nil, errors.New("Cannot exec without a transaction") + } + return stx.tx.ExecContext(ctx, query, args...) +} + +// Commit the current transaction and release the read lock +func (stx *SafeTX) Commit() error { + if stx.tx == nil { + return errors.New("Cannot commit without a transaction") + } + err := stx.tx.Commit() + stx.tx = nil + + stx.sc.releaseReadLock() + return err +} + +// Abort the current transaction, releasing the read lock +func (stx *SafeTX) Rollback() error { + if stx.tx == nil { + return errors.New("Cannot rollback without a transaction") + } + err := stx.tx.Rollback() + stx.tx = nil + stx.sc.releaseReadLock() + return err +}