diff --git a/db/connection.go b/db/connection.go index 17311d3..432fe1e 100644 --- a/db/connection.go +++ b/db/connection.go @@ -4,17 +4,17 @@ import ( "context" "database/sql" "fmt" - "sync" + "sync/atomic" "github.com/pkg/errors" _ "modernc.org/sqlite" ) -// Wraps the database handle, providing a mutex to safely manage transactions type SafeConn struct { - db *sql.DB - mux sync.RWMutex + db *sql.DB + readLockCount int32 + globalLockStatus int32 } func MakeSafe(db *sql.DB) *SafeConn { @@ -27,24 +27,63 @@ type SafeTX struct { sc *SafeConn } +func (conn *SafeConn) acquireGlobalLock() bool { + if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 { + return false + } + atomic.StoreInt32(&conn.globalLockStatus, 1) + fmt.Println("=====================GLOBAL LOCK ACQUIRED==================") + return true +} + +func (conn *SafeConn) releaseGlobalLock() { + atomic.StoreInt32(&conn.globalLockStatus, 0) + fmt.Println("=====================GLOBAL LOCK RELEASED==================") +} + +func (conn *SafeConn) acquireReadLock() bool { + if atomic.LoadInt32(&conn.globalLockStatus) == 1 { + return false + } + atomic.AddInt32(&conn.readLockCount, 1) + fmt.Println("=====================READ LOCK ACQUIRED==================") + return true +} + +func (conn *SafeConn) releaseReadLock() { + atomic.AddInt32(&conn.readLockCount, -1) + fmt.Println("=====================READ LOCK RELEASED==================") +} + // Starts a new transaction based on the current context. Will cancel if // the context is closed/cancelled/done func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { lockAcquired := make(chan struct{}) + lockCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { - conn.mux.RLock() - close(lockAcquired) + select { + case <-lockCtx.Done(): + fmt.Println("=====================READ LOCK ABANDONED==================") + return + default: + if conn.acquireReadLock() { + close(lockAcquired) // Lock acquired + } + } }() select { case <-lockAcquired: tx, err := conn.db.BeginTx(ctx, nil) if err != nil { - conn.mux.RUnlock() + conn.releaseReadLock() return nil, err } return &SafeTX{tx: tx, sc: conn}, nil case <-ctx.Done(): + cancel() return nil, errors.New("Transaction time out due to database lock") } } @@ -81,7 +120,7 @@ func (stx *SafeTX) Commit() error { err := stx.tx.Commit() stx.tx = nil - stx.releaseLock() + stx.sc.releaseReadLock() return err } @@ -92,31 +131,30 @@ func (stx *SafeTX) Rollback() error { } err := stx.tx.Rollback() stx.tx = nil - stx.releaseLock() + stx.sc.releaseReadLock() return err } -// Release the read lock for the transaction -func (stx *SafeTX) releaseLock() { - if stx.sc != nil { - stx.sc.mux.RUnlock() - } -} - // Pause blocks new transactions for a backup. func (conn *SafeConn) Pause() { - conn.mux.Lock() // Blocks all new transactions + for !conn.acquireGlobalLock() { + // TODO: add a timeout? + } + fmt.Println("Global database lock acquired") } // Resume allows transactions to proceed. func (conn *SafeConn) Resume() { - conn.mux.Unlock() + conn.releaseGlobalLock() + fmt.Println("Global database lock released") } // Close the database connection func (conn *SafeConn) Close() error { - conn.mux.Lock() - defer conn.mux.Unlock() + fmt.Println("=====================DB LOCKING FOR SHUTDOWN==================") + conn.acquireGlobalLock() + defer conn.releaseGlobalLock() + fmt.Println("=====================DB LOCKED FOR SHUTDOWN==================") return conn.db.Close() } diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go index d81884f..db963fc 100644 --- a/handlers/withtransaction.go +++ b/handlers/withtransaction.go @@ -25,13 +25,12 @@ func WithTransaction( ), ) { // Create a cancellable context from the request context - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) defer cancel() // Start the transaction tx, err := conn.Begin(ctx) if err != nil { - tx.Rollback() logger.Warn().Err(err).Msg("Request failed to start a transaction") w.WriteHeader(http.StatusServiceUnavailable) page.Error( diff --git a/main.go b/main.go index 84ed2fd..03406c6 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,8 @@ import ( "os/signal" "strconv" "sync" + "sync/atomic" + "syscall" "time" "projectreshoot/config" @@ -43,6 +45,36 @@ func getStaticFiles() (http.FileSystem, error) { } } +var maint uint32 // atomic: 1 if in maintenance mode + +func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { + ch := make(chan os.Signal, 1) + srv.RegisterOnShutdown(func() { + close(ch) + }) + go func() { + for sig := range ch { + switch sig { + case syscall.SIGUSR1: + if atomic.LoadUint32(&maint) != 1 { + atomic.StoreUint32(&maint, 1) + fmt.Println("Signal received: Starting maintenance") + fmt.Println("Attempting to acquire database lock") + conn.Pause() + } + case syscall.SIGUSR2: + if atomic.LoadUint32(&maint) != 0 { + fmt.Println("Signal received: Maintenance over") + fmt.Println("Releasing database lock") + conn.Resume() + atomic.StoreUint32(&maint, 0) + } + } + } + }() + signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2) +} + // Initializes and runs the server func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) @@ -103,6 +135,10 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return nil } + // Setups a channel to listen for os.Signal + handleMaintSignals(conn, httpServer) + + // Runs the http server go func() { fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -110,6 +146,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } }() + // Handles graceful shutdown var wg sync.WaitGroup wg.Add(1) go func() {