From 789d1e75a7b095982b99915421b29fb40811eb68 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 22:49:43 +1100 Subject: [PATCH] Added tests for SafeConn and SafeTx --- db/connection.go | 12 ++-- db/connection_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 8 deletions(-) create mode 100644 db/connection_test.go diff --git a/db/connection.go b/db/connection.go index 7a19106..56652e2 100644 --- a/db/connection.go +++ b/db/connection.go @@ -144,21 +144,19 @@ func (stx *SafeTX) Rollback() error { // Acquire a global lock, preventing all transactions func (conn *SafeConn) Pause(timeoutAfter time.Duration) { - // force logger to log to Stdout so the signalling process can check - log := conn.logger.With().Logger().Output(os.Stdout) - log.Info().Msg("Attempting to acquire global database lock") + conn.logger.Info().Msg("Attempting to acquire global database lock") conn.globalLockRequested = 1 defer func() { conn.globalLockRequested = 0 }() timeout := time.After(timeoutAfter) attempt := 0 for { if conn.acquireGlobalLock() { - log.Info().Msg("Global database lock acquired") + conn.logger.Info().Msg("Global database lock acquired") return } select { case <-timeout: - log.Info().Msg("Timeout: Global database lock abandoned") + conn.logger.Info().Msg("Timeout: Global database lock abandoned") return case <-time.After(100 * time.Millisecond): attempt++ @@ -169,9 +167,7 @@ func (conn *SafeConn) Pause(timeoutAfter time.Duration) { // Release the global lock func (conn *SafeConn) Resume() { conn.releaseGlobalLock() - // force logger to log to Stdout - log := conn.logger.With().Logger().Output(os.Stdout) - log.Info().Msg("Global database lock released") + conn.logger.Info().Msg("Global database lock released") } // Close the database connection diff --git a/db/connection_test.go b/db/connection_test.go new file mode 100644 index 0000000..9723526 --- /dev/null +++ b/db/connection_test.go @@ -0,0 +1,134 @@ +package db + +import ( + "context" + "projectreshoot/tests" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSafeConn(t *testing.T) { + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := MakeSafe(conn, logger) + defer sconn.Close() + + t.Run("Global lock waits for read locks to finish", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + var requested sync.WaitGroup + var engaged sync.WaitGroup + requested.Add(1) + engaged.Add(1) + go func() { + requested.Done() + sconn.Pause(5 * time.Second) + engaged.Done() + }() + requested.Wait() + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(1), sconn.globalLockRequested) + tx.Commit() + engaged.Wait() + assert.Equal(t, uint32(1), sconn.globalLockStatus) + assert.Equal(t, uint32(0), sconn.globalLockRequested) + sconn.Resume() + }) + t.Run("Lock abandons after timeout", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + sconn.Pause(250 * time.Millisecond) + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(0), sconn.globalLockRequested) + tx.Commit() + }) + t.Run("Pause blocks transactions and resume allows", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + var requested sync.WaitGroup + var engaged sync.WaitGroup + requested.Add(1) + engaged.Add(1) + go func() { + requested.Done() + sconn.Pause(5 * time.Second) + engaged.Done() + }() + requested.Wait() + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(1), sconn.globalLockRequested) + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + _, err = sconn.Begin(ctx) + require.Error(t, err) + tx.Commit() + engaged.Wait() + _, err = sconn.Begin(ctx) + require.Error(t, err) + sconn.Resume() + tx, err = sconn.Begin(t.Context()) + require.NoError(t, err) + tx.Commit() + }) +} +func TestSafeTX(t *testing.T) { + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := MakeSafe(conn, logger) + defer sconn.Close() + + t.Run("Commit releases lock", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + assert.Equal(t, uint32(1), sconn.readLockCount) + tx.Commit() + assert.Equal(t, uint32(0), sconn.readLockCount) + }) + t.Run("Rollback releases lock", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + assert.Equal(t, uint32(1), sconn.readLockCount) + tx.Rollback() + assert.Equal(t, uint32(0), sconn.readLockCount) + }) + t.Run("Multiple TX can gain read lock", func(t *testing.T) { + tx1, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx2, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx3, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx1.Commit() + tx2.Commit() + tx3.Commit() + }) + t.Run("Lock acquiring times out after timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + sconn.acquireGlobalLock() + defer sconn.releaseGlobalLock() + _, err := sconn.Begin(ctx) + require.Error(t, err) + }) + t.Run("Lock acquires if lock released", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + sconn.acquireGlobalLock() + var wg sync.WaitGroup + wg.Add(1) + go func() { + tx, err := sconn.Begin(ctx) + require.NoError(t, err) + tx.Commit() + wg.Done() + }() + sconn.releaseGlobalLock() + wg.Wait() + }) +}