Changed to using atomics as mutex was causing deadlock

This commit is contained in:
2025-02-17 23:14:13 +11:00
parent 19f26d62a3
commit 556c93fc49
3 changed files with 96 additions and 22 deletions

View File

@@ -4,17 +4,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sync" "sync/atomic"
"github.com/pkg/errors" "github.com/pkg/errors"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
// Wraps the database handle, providing a mutex to safely manage transactions
type SafeConn struct { type SafeConn struct {
db *sql.DB db *sql.DB
mux sync.RWMutex readLockCount int32
globalLockStatus int32
} }
func MakeSafe(db *sql.DB) *SafeConn { func MakeSafe(db *sql.DB) *SafeConn {
@@ -27,24 +27,63 @@ type SafeTX struct {
sc *SafeConn sc *SafeConn
} }
func (conn *SafeConn) acquireGlobalLock() bool {
if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 {
return false
}
atomic.StoreInt32(&conn.globalLockStatus, 1)
fmt.Println("=====================GLOBAL LOCK ACQUIRED==================")
return true
}
func (conn *SafeConn) releaseGlobalLock() {
atomic.StoreInt32(&conn.globalLockStatus, 0)
fmt.Println("=====================GLOBAL LOCK RELEASED==================")
}
func (conn *SafeConn) acquireReadLock() bool {
if atomic.LoadInt32(&conn.globalLockStatus) == 1 {
return false
}
atomic.AddInt32(&conn.readLockCount, 1)
fmt.Println("=====================READ LOCK ACQUIRED==================")
return true
}
func (conn *SafeConn) releaseReadLock() {
atomic.AddInt32(&conn.readLockCount, -1)
fmt.Println("=====================READ LOCK RELEASED==================")
}
// Starts a new transaction based on the current context. Will cancel if // Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done // the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
lockAcquired := make(chan struct{}) lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() { go func() {
conn.mux.RLock() select {
close(lockAcquired) case <-lockCtx.Done():
fmt.Println("=====================READ LOCK ABANDONED==================")
return
default:
if conn.acquireReadLock() {
close(lockAcquired) // Lock acquired
}
}
}() }()
select { select {
case <-lockAcquired: case <-lockAcquired:
tx, err := conn.db.BeginTx(ctx, nil) tx, err := conn.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
conn.mux.RUnlock() conn.releaseReadLock()
return nil, err return nil, err
} }
return &SafeTX{tx: tx, sc: conn}, nil return &SafeTX{tx: tx, sc: conn}, nil
case <-ctx.Done(): case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock") return nil, errors.New("Transaction time out due to database lock")
} }
} }
@@ -81,7 +120,7 @@ func (stx *SafeTX) Commit() error {
err := stx.tx.Commit() err := stx.tx.Commit()
stx.tx = nil stx.tx = nil
stx.releaseLock() stx.sc.releaseReadLock()
return err return err
} }
@@ -92,31 +131,30 @@ func (stx *SafeTX) Rollback() error {
} }
err := stx.tx.Rollback() err := stx.tx.Rollback()
stx.tx = nil stx.tx = nil
stx.releaseLock() stx.sc.releaseReadLock()
return err return err
} }
// Release the read lock for the transaction
func (stx *SafeTX) releaseLock() {
if stx.sc != nil {
stx.sc.mux.RUnlock()
}
}
// Pause blocks new transactions for a backup. // Pause blocks new transactions for a backup.
func (conn *SafeConn) Pause() { func (conn *SafeConn) Pause() {
conn.mux.Lock() // Blocks all new transactions for !conn.acquireGlobalLock() {
// TODO: add a timeout?
}
fmt.Println("Global database lock acquired")
} }
// Resume allows transactions to proceed. // Resume allows transactions to proceed.
func (conn *SafeConn) Resume() { func (conn *SafeConn) Resume() {
conn.mux.Unlock() conn.releaseGlobalLock()
fmt.Println("Global database lock released")
} }
// Close the database connection // Close the database connection
func (conn *SafeConn) Close() error { func (conn *SafeConn) Close() error {
conn.mux.Lock() fmt.Println("=====================DB LOCKING FOR SHUTDOWN==================")
defer conn.mux.Unlock() conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
fmt.Println("=====================DB LOCKED FOR SHUTDOWN==================")
return conn.db.Close() return conn.db.Close()
} }

View File

@@ -25,13 +25,12 @@ func WithTransaction(
), ),
) { ) {
// Create a cancellable context from the request context // Create a cancellable context from the request context
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.Begin(ctx)
if err != nil { if err != nil {
tx.Rollback()
logger.Warn().Err(err).Msg("Request failed to start a transaction") logger.Warn().Err(err).Msg("Request failed to start a transaction")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
page.Error( page.Error(

37
main.go
View File

@@ -13,6 +13,8 @@ import (
"os/signal" "os/signal"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"syscall"
"time" "time"
"projectreshoot/config" "projectreshoot/config"
@@ -43,6 +45,36 @@ func getStaticFiles() (http.FileSystem, error) {
} }
} }
var maint uint32 // atomic: 1 if in maintenance mode
func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
fmt.Println("Signal received: Starting maintenance")
fmt.Println("Attempting to acquire database lock")
conn.Pause()
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
fmt.Println("Signal received: Maintenance over")
fmt.Println("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}
// Initializes and runs the server // Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error { func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
@@ -103,6 +135,10 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return nil return nil
} }
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer)
// Runs the http server
go func() { go func() {
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
@@ -110,6 +146,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
} }
}() }()
// Handles graceful shutdown
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {