Added timeout for acquiring global lock

This commit is contained in:
2025-02-18 19:57:30 +11:00
parent 38d47cdf63
commit c0bbe687f9
4 changed files with 37 additions and 20 deletions

View File

@@ -22,6 +22,7 @@ type Config struct {
WriteTimeout time.Duration // Timeout for writing requests in seconds WriteTimeout time.Duration // Timeout for writing requests in seconds
IdleTimeout time.Duration // Timeout for idle connections in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds
DBName string // Filename of the db (doesnt include file extension) DBName string // Filename of the db (doesnt include file extension)
DBLockTimeout time.Duration // Timeout for acquiring database lock
SecretKey string // Secret key for signing tokens SecretKey string // Secret key for signing tokens
AccessTokenExpiry int64 // Access token expiry in minutes AccessTokenExpiry int64 // Access token expiry in minutes
RefreshTokenExpiry int64 // Refresh token expiry in minutes RefreshTokenExpiry int64 // Refresh token expiry in minutes
@@ -87,6 +88,7 @@ func GetConfig(args map[string]string) (*Config, error) {
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
DBName: GetEnvDefault("DB_NAME", "projectreshoot"), DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
SecretKey: os.Getenv("SECRET_KEY"), SecretKey: os.Getenv("SECRET_KEY"),
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day

View File

@@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -75,7 +76,7 @@ func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
return return
default: default:
if conn.acquireReadLock() { if conn.acquireReadLock() {
close(lockAcquired) // Lock acquired close(lockAcquired)
} }
} }
}() }()
@@ -118,7 +119,7 @@ func (stx *SafeTX) Exec(
return stx.tx.ExecContext(ctx, query, args...) return stx.tx.ExecContext(ctx, query, args...)
} }
// Commit commits the transaction and releases the lock. // Commit the current transaction and release the read lock
func (stx *SafeTX) Commit() error { func (stx *SafeTX) Commit() error {
if stx.tx == nil { if stx.tx == nil {
return errors.New("Cannot commit without a transaction") return errors.New("Cannot commit without a transaction")
@@ -130,7 +131,7 @@ func (stx *SafeTX) Commit() error {
return err return err
} }
// Rollback aborts the transaction. // Abort the current transaction, releasing the read lock
func (stx *SafeTX) Rollback() error { func (stx *SafeTX) Rollback() error {
if stx.tx == nil { if stx.tx == nil {
return errors.New("Cannot rollback without a transaction") return errors.New("Cannot rollback without a transaction")
@@ -141,21 +142,31 @@ func (stx *SafeTX) Rollback() error {
return err return err
} }
// Pause blocks new transactions for a backup. // Acquire a global lock, preventing all transactions
func (conn *SafeConn) Pause() { func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
conn.globalLockRequested = 1 // force logger to log to Stdout so the signalling process can check
for !conn.acquireGlobalLock() {
// TODO: add a timeout?
// TODO: failed to acquire lock: print info with readLockCount
// every second, or update it dynamically
}
// force logger to log to Stdout
log := conn.logger.With().Logger().Output(os.Stdout) log := conn.logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Global database lock acquired") log.Info().Msg("Attempting to acquire global database lock")
conn.globalLockRequested = 0 conn.globalLockRequested = 1
defer func() { conn.globalLockRequested = 0 }()
timeout := time.After(timeoutAfter)
attempt := 0
for {
if conn.acquireGlobalLock() {
log.Info().Msg("Global database lock acquired")
return
}
select {
case <-timeout:
log.Info().Msg("Timeout: Global database lock abandoned")
return
case <-time.After(100 * time.Millisecond):
attempt++
}
}
} }
// Resume allows transactions to proceed. // Release the global lock
func (conn *SafeConn) Resume() { func (conn *SafeConn) Resume() {
conn.releaseGlobalLock() conn.releaseGlobalLock()
// force logger to log to Stdout // force logger to log to Stdout
@@ -179,8 +190,6 @@ func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sql.Open") return nil, errors.Wrap(err, "sql.Open")
} }
conn := MakeSafe(db, logger) conn := MakeSafe(db, logger)
return conn, nil return conn, nil
} }

View File

@@ -3,6 +3,7 @@ package handlers
import ( import (
"context" "context"
"net/http" "net/http"
"time"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/cookies" "projectreshoot/cookies"

11
main.go
View File

@@ -48,7 +48,12 @@ func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
var maint uint32 // atomic: 1 if in maintenance mode var maint uint32 // atomic: 1 if in maintenance mode
func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) { func handleMaintSignals(
conn *db.SafeConn,
srv *http.Server,
logger *zerolog.Logger,
config *config.Config,
) {
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() { srv.RegisterOnShutdown(func() {
close(ch) close(ch)
@@ -62,7 +67,7 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Log
log := logger.With().Logger().Output(os.Stdout) log := logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Signal received: Starting maintenance") log.Info().Msg("Signal received: Starting maintenance")
log.Info().Msg("Attempting to acquire database lock") log.Info().Msg("Attempting to acquire database lock")
conn.Pause() conn.Pause(config.DBLockTimeout * time.Second)
} }
case syscall.SIGUSR2: case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 { if atomic.LoadUint32(&maint) != 0 {
@@ -139,7 +144,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
} }
// Setups a channel to listen for os.Signal // Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer, logger) handleMaintSignals(conn, httpServer, logger, config)
// Runs the http server // Runs the http server
go func() { go func() {