Changed to using atomics as mutex was causing deadlock
This commit is contained in:
@@ -4,17 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Wraps the database handle, providing a mutex to safely manage transactions
|
|
||||||
type SafeConn struct {
|
type SafeConn struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
mux sync.RWMutex
|
readLockCount int32
|
||||||
|
globalLockStatus int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeSafe(db *sql.DB) *SafeConn {
|
func MakeSafe(db *sql.DB) *SafeConn {
|
||||||
@@ -27,24 +27,63 @@ type SafeTX struct {
|
|||||||
sc *SafeConn
|
sc *SafeConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||||
|
if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
atomic.StoreInt32(&conn.globalLockStatus, 1)
|
||||||
|
fmt.Println("=====================GLOBAL LOCK ACQUIRED==================")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SafeConn) releaseGlobalLock() {
|
||||||
|
atomic.StoreInt32(&conn.globalLockStatus, 0)
|
||||||
|
fmt.Println("=====================GLOBAL LOCK RELEASED==================")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SafeConn) acquireReadLock() bool {
|
||||||
|
if atomic.LoadInt32(&conn.globalLockStatus) == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
atomic.AddInt32(&conn.readLockCount, 1)
|
||||||
|
fmt.Println("=====================READ LOCK ACQUIRED==================")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SafeConn) releaseReadLock() {
|
||||||
|
atomic.AddInt32(&conn.readLockCount, -1)
|
||||||
|
fmt.Println("=====================READ LOCK RELEASED==================")
|
||||||
|
}
|
||||||
|
|
||||||
// Starts a new transaction based on the current context. Will cancel if
|
// Starts a new transaction based on the current context. Will cancel if
|
||||||
// the context is closed/cancelled/done
|
// the context is closed/cancelled/done
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||||
lockAcquired := make(chan struct{})
|
lockAcquired := make(chan struct{})
|
||||||
|
lockCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn.mux.RLock()
|
select {
|
||||||
close(lockAcquired)
|
case <-lockCtx.Done():
|
||||||
|
fmt.Println("=====================READ LOCK ABANDONED==================")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if conn.acquireReadLock() {
|
||||||
|
close(lockAcquired) // Lock acquired
|
||||||
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-lockAcquired:
|
case <-lockAcquired:
|
||||||
tx, err := conn.db.BeginTx(ctx, nil)
|
tx, err := conn.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.mux.RUnlock()
|
conn.releaseReadLock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &SafeTX{tx: tx, sc: conn}, nil
|
return &SafeTX{tx: tx, sc: conn}, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
cancel()
|
||||||
return nil, errors.New("Transaction time out due to database lock")
|
return nil, errors.New("Transaction time out due to database lock")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,7 +120,7 @@ func (stx *SafeTX) Commit() error {
|
|||||||
err := stx.tx.Commit()
|
err := stx.tx.Commit()
|
||||||
stx.tx = nil
|
stx.tx = nil
|
||||||
|
|
||||||
stx.releaseLock()
|
stx.sc.releaseReadLock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,31 +131,30 @@ func (stx *SafeTX) Rollback() error {
|
|||||||
}
|
}
|
||||||
err := stx.tx.Rollback()
|
err := stx.tx.Rollback()
|
||||||
stx.tx = nil
|
stx.tx = nil
|
||||||
stx.releaseLock()
|
stx.sc.releaseReadLock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the read lock for the transaction
|
|
||||||
func (stx *SafeTX) releaseLock() {
|
|
||||||
if stx.sc != nil {
|
|
||||||
stx.sc.mux.RUnlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pause blocks new transactions for a backup.
|
// Pause blocks new transactions for a backup.
|
||||||
func (conn *SafeConn) Pause() {
|
func (conn *SafeConn) Pause() {
|
||||||
conn.mux.Lock() // Blocks all new transactions
|
for !conn.acquireGlobalLock() {
|
||||||
|
// TODO: add a timeout?
|
||||||
|
}
|
||||||
|
fmt.Println("Global database lock acquired")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume allows transactions to proceed.
|
// Resume allows transactions to proceed.
|
||||||
func (conn *SafeConn) Resume() {
|
func (conn *SafeConn) Resume() {
|
||||||
conn.mux.Unlock()
|
conn.releaseGlobalLock()
|
||||||
|
fmt.Println("Global database lock released")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the database connection
|
// Close the database connection
|
||||||
func (conn *SafeConn) Close() error {
|
func (conn *SafeConn) Close() error {
|
||||||
conn.mux.Lock()
|
fmt.Println("=====================DB LOCKING FOR SHUTDOWN==================")
|
||||||
defer conn.mux.Unlock()
|
conn.acquireGlobalLock()
|
||||||
|
defer conn.releaseGlobalLock()
|
||||||
|
fmt.Println("=====================DB LOCKED FOR SHUTDOWN==================")
|
||||||
return conn.db.Close()
|
return conn.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,13 +25,12 @@ func WithTransaction(
|
|||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
// Create a cancellable context from the request context
|
// Create a cancellable context from the request context
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.Begin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
|
||||||
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
page.Error(
|
page.Error(
|
||||||
|
|||||||
37
main.go
37
main.go
@@ -13,6 +13,8 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
@@ -43,6 +45,36 @@ func getStaticFiles() (http.FileSystem, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var maint uint32 // atomic: 1 if in maintenance mode
|
||||||
|
|
||||||
|
func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
|
||||||
|
ch := make(chan os.Signal, 1)
|
||||||
|
srv.RegisterOnShutdown(func() {
|
||||||
|
close(ch)
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
for sig := range ch {
|
||||||
|
switch sig {
|
||||||
|
case syscall.SIGUSR1:
|
||||||
|
if atomic.LoadUint32(&maint) != 1 {
|
||||||
|
atomic.StoreUint32(&maint, 1)
|
||||||
|
fmt.Println("Signal received: Starting maintenance")
|
||||||
|
fmt.Println("Attempting to acquire database lock")
|
||||||
|
conn.Pause()
|
||||||
|
}
|
||||||
|
case syscall.SIGUSR2:
|
||||||
|
if atomic.LoadUint32(&maint) != 0 {
|
||||||
|
fmt.Println("Signal received: Maintenance over")
|
||||||
|
fmt.Println("Releasing database lock")
|
||||||
|
conn.Resume()
|
||||||
|
atomic.StoreUint32(&maint, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
||||||
|
}
|
||||||
|
|
||||||
// Initializes and runs the server
|
// Initializes and runs the server
|
||||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
@@ -103,6 +135,10 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setups a channel to listen for os.Signal
|
||||||
|
handleMaintSignals(conn, httpServer)
|
||||||
|
|
||||||
|
// Runs the http server
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
|
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
@@ -110,6 +146,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Handles graceful shutdown
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
Reference in New Issue
Block a user