refactor: changed file structure

This commit is contained in:
2025-03-05 20:18:28 +11:00
parent 5c1089e0ce
commit 1d9af44d0a
137 changed files with 4986 additions and 581 deletions

68
pkg/db/connection.go Normal file
View File

@@ -0,0 +1,68 @@
package db
import (
"database/sql"
"fmt"
"strconv"
"github.com/pkg/errors"
"github.com/rs/zerolog"
_ "github.com/mattn/go-sqlite3"
)
// Returns a database connection handle for the DB
func ConnectToDatabase(
dbName string,
logger *zerolog.Logger,
) (*SafeConn, error) {
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
wconn, err := sql.Open("sqlite3", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open (rw)")
}
wconn.SetMaxOpenConns(1)
opts = "_synchronous=NORMAL&mode=ro"
file = fmt.Sprintf("file:%s.db?%s", dbName, opts)
rconn, err := sql.Open("sqlite3", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open (ro)")
}
version, err := strconv.Atoi(dbName)
if err != nil {
return nil, errors.Wrap(err, "strconv.Atoi")
}
err = checkDBVersion(rconn, version)
if err != nil {
return nil, errors.Wrap(err, "checkDBVersion")
}
conn := MakeSafe(wconn, rconn, logger)
return conn, nil
}
// Check the database version
func checkDBVersion(db *sql.DB, expectVer int) error {
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
ORDER BY version_id DESC LIMIT 1`
rows, err := db.Query(query)
if err != nil {
return errors.Wrap(err, "db.Query")
}
defer rows.Close()
if rows.Next() {
var version int
err = rows.Scan(&version)
if err != nil {
return errors.Wrap(err, "rows.Scan")
}
if version != expectVer {
return errors.New("Version mismatch")
}
} else {
return errors.New("No version found")
}
return nil
}

162
pkg/db/safeconn.go Normal file
View File

@@ -0,0 +1,162 @@
package db
import (
"context"
"database/sql"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
type SafeConn struct {
wconn *sql.DB
rconn *sql.DB
readLockCount uint32
globalLockStatus uint32
globalLockRequested uint32
logger *zerolog.Logger
}
// Make the provided db handle safe and attach a logger to it
func MakeSafe(wconn *sql.DB, rconn *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{wconn: wconn, rconn: rconn, logger: logger}
}
// Attempts to acquire a global lock on the database connection
func (conn *SafeConn) acquireGlobalLock() bool {
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false
}
conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true
}
// Releases a global lock on the database connection
func (conn *SafeConn) releaseGlobalLock() {
conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
}
// Acquire a read lock on the connection. Multiple read locks can be acquired
// at the same time
func (conn *SafeConn) acquireReadLock() bool {
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
return false
}
conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true
}
// Release a read lock. Decrements read lock count by 1
func (conn *SafeConn) releaseReadLock() {
conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
}
// Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeWTX, error) {
lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-lockCtx.Done():
return
default:
if conn.acquireReadLock() {
close(lockAcquired)
}
}
}()
select {
case <-lockAcquired:
tx, err := conn.wconn.BeginTx(ctx, nil)
if err != nil {
conn.releaseReadLock()
return nil, err
}
return &SafeWTX{tx: tx, sc: conn}, nil
case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock")
}
}
// Starts a new READONLY transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) RBegin(ctx context.Context) (*SafeRTX, error) {
lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-lockCtx.Done():
return
default:
if conn.acquireReadLock() {
close(lockAcquired)
}
}
}()
select {
case <-lockAcquired:
tx, err := conn.rconn.BeginTx(ctx, nil)
if err != nil {
conn.releaseReadLock()
return nil, err
}
return &SafeRTX{tx: tx, sc: conn}, nil
case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock")
}
}
// Acquire a global lock, preventing all transactions
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
conn.logger.Info().Msg("Attempting to acquire global database lock")
conn.globalLockRequested = 1
defer func() { conn.globalLockRequested = 0 }()
timeout := time.After(timeoutAfter)
attempt := 0
for {
if conn.acquireGlobalLock() {
conn.logger.Info().Msg("Global database lock acquired")
return
}
select {
case <-timeout:
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
return
case <-time.After(100 * time.Millisecond):
attempt++
}
}
}
// Release the global lock
func (conn *SafeConn) Resume() {
conn.releaseGlobalLock()
conn.logger.Info().Msg("Global database lock released")
}
// Close the database connection
func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.wconn.Close()
}

143
pkg/db/safeconntx_test.go Normal file
View File

@@ -0,0 +1,143 @@
package db
import (
"context"
"projectreshoot/pkg/tests"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeConn(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
wconn, rconn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(wconn, rconn, logger)
defer sconn.Close()
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
tx.Commit()
engaged.Wait()
assert.Equal(t, uint32(1), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
sconn.Resume()
})
t.Run("Lock abandons after timeout", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
sconn.Pause(250 * time.Millisecond)
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
tx.Commit()
})
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
_, err = sconn.Begin(ctx)
require.Error(t, err)
tx.Commit()
engaged.Wait()
_, err = sconn.Begin(ctx)
require.Error(t, err)
sconn.Resume()
tx, err = sconn.Begin(t.Context())
require.NoError(t, err)
tx.Commit()
})
}
func TestSafeTX(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
wconn, rconn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(wconn, rconn, logger)
defer sconn.Close()
t.Run("Commit releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Commit()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Rollback releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Rollback()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Multiple RTX can gain read lock", func(t *testing.T) {
tx1, err := sconn.RBegin(t.Context())
require.NoError(t, err)
tx2, err := sconn.RBegin(t.Context())
require.NoError(t, err)
tx3, err := sconn.RBegin(t.Context())
require.NoError(t, err)
tx1.Commit()
tx2.Commit()
tx3.Commit()
})
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
defer sconn.releaseGlobalLock()
_, err := sconn.Begin(ctx)
require.Error(t, err)
})
t.Run("Lock acquires if lock released", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
var wg sync.WaitGroup
wg.Add(1)
go func() {
tx, err := sconn.Begin(ctx)
require.NoError(t, err)
tx.Commit()
wg.Done()
}()
sconn.releaseGlobalLock()
wg.Wait()
})
}

163
pkg/db/safetx.go Normal file
View File

@@ -0,0 +1,163 @@
package db
import (
"context"
"database/sql"
"regexp"
"strings"
"github.com/pkg/errors"
)
type SafeTX interface {
Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(ctx context.Context, query string, args ...interface{}) (*sql.Row, error)
Commit() error
Rollback() error
}
// Extends sql.Tx for use with SafeConn
type SafeWTX struct {
tx *sql.Tx
sc *SafeConn
}
type SafeRTX struct {
tx *sql.Tx
sc *SafeConn
}
func isWriteOperation(query string) bool {
query = strings.TrimSpace(query)
query = strings.ToUpper(query)
writeOpsRegex := `^(INSERT|UPDATE|DELETE|REPLACE|MERGE|CREATE|DROP|ALTER|TRUNCATE)\s+`
re := regexp.MustCompile(writeOpsRegex)
return re.MatchString(query)
}
// Query the database inside the transaction
func (stx *SafeRTX) Query(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Rows, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
if isWriteOperation(query) {
return nil, errors.New("Cannot query with a write operation")
}
rows, err := stx.tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "tx.QueryContext")
}
return rows, nil
}
// Query the database inside the transaction
func (stx *SafeWTX) Query(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Rows, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
if isWriteOperation(query) {
return nil, errors.New("Cannot query with a write operation")
}
rows, err := stx.tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "tx.QueryContext")
}
return rows, nil
}
// Query a row from the database inside the transaction
func (stx *SafeRTX) QueryRow(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Row, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
if isWriteOperation(query) {
return nil, errors.New("Cannot query with a write operation")
}
return stx.tx.QueryRowContext(ctx, query, args...), nil
}
// Query a row from the database inside the transaction
func (stx *SafeWTX) QueryRow(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Row, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
if isWriteOperation(query) {
return nil, errors.New("Cannot query with a write operation")
}
return stx.tx.QueryRowContext(ctx, query, args...), nil
}
// Exec a statement on the database inside the transaction
func (stx *SafeWTX) Exec(
ctx context.Context,
query string,
args ...interface{},
) (sql.Result, error) {
if stx.tx == nil {
return nil, errors.New("Cannot exec without a transaction")
}
res, err := stx.tx.ExecContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "tx.ExecContext")
}
return res, nil
}
// Commit the current transaction and release the read lock
func (stx *SafeRTX) Commit() error {
if stx.tx == nil {
return errors.New("Cannot commit without a transaction")
}
err := stx.tx.Commit()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Commit the current transaction and release the read lock
func (stx *SafeWTX) Commit() error {
if stx.tx == nil {
return errors.New("Cannot commit without a transaction")
}
err := stx.tx.Commit()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Abort the current transaction, releasing the read lock
func (stx *SafeRTX) Rollback() error {
if stx.tx == nil {
return errors.New("Cannot rollback without a transaction")
}
err := stx.tx.Rollback()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Abort the current transaction, releasing the read lock
func (stx *SafeWTX) Rollback() error {
if stx.tx == nil {
return errors.New("Cannot rollback without a transaction")
}
err := stx.tx.Rollback()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}