From c0bbe687f90b735ccd2ba5b22e5762f1fef630de Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 19:57:30 +1100 Subject: [PATCH] Added timeout for acquiring global lock --- config/config.go | 2 ++ db/connection.go | 43 ++++++++++++++++++++++++++----------------- handlers/login.go | 1 + main.go | 11 ++++++++--- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/config/config.go b/config/config.go index f4e60d7..66a17a8 100644 --- a/config/config.go +++ b/config/config.go @@ -22,6 +22,7 @@ type Config struct { WriteTimeout time.Duration // Timeout for writing requests in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds DBName string // Filename of the db (doesnt include file extension) + DBLockTimeout time.Duration // Timeout for acquiring database lock SecretKey string // Secret key for signing tokens AccessTokenExpiry int64 // Access token expiry in minutes RefreshTokenExpiry int64 // Refresh token expiry in minutes @@ -87,6 +88,7 @@ func GetConfig(args map[string]string) (*Config, error) { WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), DBName: GetEnvDefault("DB_NAME", "projectreshoot"), + DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60), SecretKey: os.Getenv("SECRET_KEY"), AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day diff --git a/db/connection.go b/db/connection.go index 9bb39bd..7a19106 100644 --- a/db/connection.go +++ b/db/connection.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "os" + "time" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -75,7 +76,7 @@ func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { return default: if conn.acquireReadLock() { - close(lockAcquired) // Lock acquired + close(lockAcquired) } } }() @@ -118,7 +119,7 @@ func (stx *SafeTX) Exec( return stx.tx.ExecContext(ctx, query, args...) } -// Commit commits the transaction and releases the lock. +// Commit the current transaction and release the read lock func (stx *SafeTX) Commit() error { if stx.tx == nil { return errors.New("Cannot commit without a transaction") @@ -130,7 +131,7 @@ func (stx *SafeTX) Commit() error { return err } -// Rollback aborts the transaction. +// Abort the current transaction, releasing the read lock func (stx *SafeTX) Rollback() error { if stx.tx == nil { return errors.New("Cannot rollback without a transaction") @@ -141,21 +142,31 @@ func (stx *SafeTX) Rollback() error { return err } -// Pause blocks new transactions for a backup. -func (conn *SafeConn) Pause() { - conn.globalLockRequested = 1 - for !conn.acquireGlobalLock() { - // TODO: add a timeout? - // TODO: failed to acquire lock: print info with readLockCount - // every second, or update it dynamically - } - // force logger to log to Stdout +// Acquire a global lock, preventing all transactions +func (conn *SafeConn) Pause(timeoutAfter time.Duration) { + // force logger to log to Stdout so the signalling process can check log := conn.logger.With().Logger().Output(os.Stdout) - log.Info().Msg("Global database lock acquired") - conn.globalLockRequested = 0 + log.Info().Msg("Attempting to acquire global database lock") + conn.globalLockRequested = 1 + defer func() { conn.globalLockRequested = 0 }() + timeout := time.After(timeoutAfter) + attempt := 0 + for { + if conn.acquireGlobalLock() { + log.Info().Msg("Global database lock acquired") + return + } + select { + case <-timeout: + log.Info().Msg("Timeout: Global database lock abandoned") + return + case <-time.After(100 * time.Millisecond): + attempt++ + } + } } -// Resume allows transactions to proceed. +// Release the global lock func (conn *SafeConn) Resume() { conn.releaseGlobalLock() // force logger to log to Stdout @@ -179,8 +190,6 @@ func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := MakeSafe(db, logger) - return conn, nil } diff --git a/handlers/login.go b/handlers/login.go index 8788b01..85f6c1c 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" diff --git a/main.go b/main.go index baa7350..70c1dc9 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,12 @@ func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { var maint uint32 // atomic: 1 if in maintenance mode -func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) { +func handleMaintSignals( + conn *db.SafeConn, + srv *http.Server, + logger *zerolog.Logger, + config *config.Config, +) { ch := make(chan os.Signal, 1) srv.RegisterOnShutdown(func() { close(ch) @@ -62,7 +67,7 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Log log := logger.With().Logger().Output(os.Stdout) log.Info().Msg("Signal received: Starting maintenance") log.Info().Msg("Attempting to acquire database lock") - conn.Pause() + conn.Pause(config.DBLockTimeout * time.Second) } case syscall.SIGUSR2: if atomic.LoadUint32(&maint) != 0 { @@ -139,7 +144,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Setups a channel to listen for os.Signal - handleMaintSignals(conn, httpServer, logger) + handleMaintSignals(conn, httpServer, logger, config) // Runs the http server go func() {