Moved SafeTx and SafeConn to their own files
This commit is contained in:
168
db/connection.go
168
db/connection.go
@@ -1,10 +1,8 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
@@ -12,172 +10,6 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SafeConn struct {
|
|
||||||
db *sql.DB
|
|
||||||
readLockCount uint32
|
|
||||||
globalLockStatus uint32
|
|
||||||
globalLockRequested uint32
|
|
||||||
logger *zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
|
||||||
return &SafeConn{db: db, logger: logger}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extends sql.Tx for use with SafeConn
|
|
||||||
type SafeTX struct {
|
|
||||||
tx *sql.Tx
|
|
||||||
sc *SafeConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
|
||||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.globalLockStatus = 1
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) releaseGlobalLock() {
|
|
||||||
conn.globalLockStatus = 0
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) acquireReadLock() bool {
|
|
||||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.readLockCount += 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SafeConn) releaseReadLock() {
|
|
||||||
conn.readLockCount -= 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Starts a new transaction based on the current context. Will cancel if
|
|
||||||
// the context is closed/cancelled/done
|
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
|
||||||
lockAcquired := make(chan struct{})
|
|
||||||
lockCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-lockCtx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if conn.acquireReadLock() {
|
|
||||||
close(lockAcquired)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-lockAcquired:
|
|
||||||
tx, err := conn.db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
conn.releaseReadLock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &SafeTX{tx: tx, sc: conn}, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
cancel()
|
|
||||||
return nil, errors.New("Transaction time out due to database lock")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the database inside the transaction
|
|
||||||
func (stx *SafeTX) Query(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (*sql.Rows, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot query without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec a statement on the database inside the transaction
|
|
||||||
func (stx *SafeTX) Exec(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (sql.Result, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot exec without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit the current transaction and release the read lock
|
|
||||||
func (stx *SafeTX) Commit() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot commit without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Commit()
|
|
||||||
stx.tx = nil
|
|
||||||
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort the current transaction, releasing the read lock
|
|
||||||
func (stx *SafeTX) Rollback() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot rollback without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Rollback()
|
|
||||||
stx.tx = nil
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a global lock, preventing all transactions
|
|
||||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
|
||||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
|
||||||
conn.globalLockRequested = 1
|
|
||||||
defer func() { conn.globalLockRequested = 0 }()
|
|
||||||
timeout := time.After(timeoutAfter)
|
|
||||||
attempt := 0
|
|
||||||
for {
|
|
||||||
if conn.acquireGlobalLock() {
|
|
||||||
conn.logger.Info().Msg("Global database lock acquired")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
|
||||||
return
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
attempt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release the global lock
|
|
||||||
func (conn *SafeConn) Resume() {
|
|
||||||
conn.releaseGlobalLock()
|
|
||||||
conn.logger.Info().Msg("Global database lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the database connection
|
|
||||||
func (conn *SafeConn) Close() error {
|
|
||||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
|
||||||
conn.acquireGlobalLock()
|
|
||||||
defer conn.releaseGlobalLock()
|
|
||||||
conn.logger.Debug().Msg("Closing database connection")
|
|
||||||
return conn.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a database connection handle for the DB
|
// Returns a database connection handle for the DB
|
||||||
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
|
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
|
||||||
file := fmt.Sprintf("file:%s.db", dbName)
|
file := fmt.Sprintf("file:%s.db", dbName)
|
||||||
|
|||||||
129
db/safeconn.go
Normal file
129
db/safeconn.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SafeConn struct {
|
||||||
|
db *sql.DB
|
||||||
|
readLockCount uint32
|
||||||
|
globalLockStatus uint32
|
||||||
|
globalLockRequested uint32
|
||||||
|
logger *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the provided db handle safe and attach a logger to it
|
||||||
|
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||||
|
return &SafeConn{db: db, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempts to acquire a global lock on the database connection
|
||||||
|
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||||
|
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.globalLockStatus = 1
|
||||||
|
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||||
|
Msg("Global lock acquired")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a global lock on the database connection
|
||||||
|
func (conn *SafeConn) releaseGlobalLock() {
|
||||||
|
conn.globalLockStatus = 0
|
||||||
|
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||||
|
Msg("Global lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
||||||
|
// at the same time
|
||||||
|
func (conn *SafeConn) acquireReadLock() bool {
|
||||||
|
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conn.readLockCount += 1
|
||||||
|
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||||
|
Msg("Read lock acquired")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release a read lock. Decrements read lock count by 1
|
||||||
|
func (conn *SafeConn) releaseReadLock() {
|
||||||
|
conn.readLockCount -= 1
|
||||||
|
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||||
|
Msg("Read lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starts a new transaction based on the current context. Will cancel if
|
||||||
|
// the context is closed/cancelled/done
|
||||||
|
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||||
|
lockAcquired := make(chan struct{})
|
||||||
|
lockCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-lockCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if conn.acquireReadLock() {
|
||||||
|
close(lockAcquired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-lockAcquired:
|
||||||
|
tx, err := conn.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
conn.releaseReadLock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &SafeTX{tx: tx, sc: conn}, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
cancel()
|
||||||
|
return nil, errors.New("Transaction time out due to database lock")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire a global lock, preventing all transactions
|
||||||
|
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||||
|
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||||
|
conn.globalLockRequested = 1
|
||||||
|
defer func() { conn.globalLockRequested = 0 }()
|
||||||
|
timeout := time.After(timeoutAfter)
|
||||||
|
attempt := 0
|
||||||
|
for {
|
||||||
|
if conn.acquireGlobalLock() {
|
||||||
|
conn.logger.Info().Msg("Global database lock acquired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||||
|
return
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
attempt++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release the global lock
|
||||||
|
func (conn *SafeConn) Resume() {
|
||||||
|
conn.releaseGlobalLock()
|
||||||
|
conn.logger.Info().Msg("Global database lock released")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the database connection
|
||||||
|
func (conn *SafeConn) Close() error {
|
||||||
|
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||||
|
conn.acquireGlobalLock()
|
||||||
|
defer conn.releaseGlobalLock()
|
||||||
|
conn.logger.Debug().Msg("Closing database connection")
|
||||||
|
return conn.db.Close()
|
||||||
|
}
|
||||||
61
db/safetx.go
Normal file
61
db/safetx.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extends sql.Tx for use with SafeConn
|
||||||
|
type SafeTX struct {
|
||||||
|
tx *sql.Tx
|
||||||
|
sc *SafeConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query the database inside the transaction
|
||||||
|
func (stx *SafeTX) Query(
|
||||||
|
ctx context.Context,
|
||||||
|
query string,
|
||||||
|
args ...interface{},
|
||||||
|
) (*sql.Rows, error) {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return nil, errors.New("Cannot query without a transaction")
|
||||||
|
}
|
||||||
|
return stx.tx.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec a statement on the database inside the transaction
|
||||||
|
func (stx *SafeTX) Exec(
|
||||||
|
ctx context.Context,
|
||||||
|
query string,
|
||||||
|
args ...interface{},
|
||||||
|
) (sql.Result, error) {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return nil, errors.New("Cannot exec without a transaction")
|
||||||
|
}
|
||||||
|
return stx.tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit the current transaction and release the read lock
|
||||||
|
func (stx *SafeTX) Commit() error {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return errors.New("Cannot commit without a transaction")
|
||||||
|
}
|
||||||
|
err := stx.tx.Commit()
|
||||||
|
stx.tx = nil
|
||||||
|
|
||||||
|
stx.sc.releaseReadLock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort the current transaction, releasing the read lock
|
||||||
|
func (stx *SafeTX) Rollback() error {
|
||||||
|
if stx.tx == nil {
|
||||||
|
return errors.New("Cannot rollback without a transaction")
|
||||||
|
}
|
||||||
|
err := stx.tx.Rollback()
|
||||||
|
stx.tx = nil
|
||||||
|
stx.sc.releaseReadLock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user