Added proper debug logging to safeconn methods

This commit is contained in:
2025-02-18 10:02:42 +11:00
parent b4b57c14cb
commit c2de8d254a
9 changed files with 83 additions and 45 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
query.sql
*.db
.logs/
server.log
tmp/
projectreshoot
static/css/output.css

View File

@@ -4,21 +4,23 @@ import (
"context"
"database/sql"
"fmt"
"sync/atomic"
"os"
"github.com/pkg/errors"
"github.com/rs/zerolog"
_ "modernc.org/sqlite"
)
type SafeConn struct {
db *sql.DB
readLockCount int32
globalLockStatus int32
readLockCount uint32
globalLockStatus uint32
logger *zerolog.Logger
}
func MakeSafe(db *sql.DB) *SafeConn {
return &SafeConn{db: db}
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{db: db, logger: logger}
}
// Extends sql.Tx for use with SafeConn
@@ -28,27 +30,35 @@ type SafeTX struct {
}
func (conn *SafeConn) acquireGlobalLock() bool {
if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 {
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false
}
atomic.StoreInt32(&conn.globalLockStatus, 1)
conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true
}
func (conn *SafeConn) releaseGlobalLock() {
atomic.StoreInt32(&conn.globalLockStatus, 0)
conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
}
func (conn *SafeConn) acquireReadLock() bool {
if atomic.LoadInt32(&conn.globalLockStatus) == 1 {
if conn.globalLockStatus == 1 {
return false
}
atomic.AddInt32(&conn.readLockCount, 1)
conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true
}
func (conn *SafeConn) releaseReadLock() {
atomic.AddInt32(&conn.readLockCount, -1)
conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
}
// Starts a new transaction based on the current context. Will cancel if
@@ -134,32 +144,40 @@ func (stx *SafeTX) Rollback() error {
func (conn *SafeConn) Pause() {
for !conn.acquireGlobalLock() {
// TODO: add a timeout?
// TODO: failed to acquire lock: print info with readLockCount
// every second, or update it dynamically
}
fmt.Println("Global database lock acquired")
// force logger to log to Stdout
log := conn.logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Global database lock acquired")
}
// Resume allows transactions to proceed.
func (conn *SafeConn) Resume() {
conn.releaseGlobalLock()
fmt.Println("Global database lock released")
// force logger to log to Stdout
log := conn.logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Global database lock released")
}
// Close the database connection
func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.db.Close()
}
// Returns a database connection handle for the DB
func ConnectToDatabase(dbName string) (*SafeConn, error) {
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
conn := MakeSafe(db)
conn := MakeSafe(db, logger)
return conn, nil
}

2
go.mod
View File

@@ -23,7 +23,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect

12
go.sum
View File

@@ -16,6 +16,11 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -42,8 +47,8 @@ golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtD
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -51,8 +56,9 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=

33
main.go
View File

@@ -23,20 +23,21 @@ import (
"projectreshoot/server"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
//go:embed static/*
var embeddedStatic embed.FS
// Gets the static files
func getStaticFiles() (http.FileSystem, error) {
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
if _, err := os.Stat("static"); err == nil {
// Use actual filesystem in development
fmt.Println("Using filesystem for static files")
logger.Debug().Msg("Using filesystem for static files")
return http.Dir("static"), nil
} else {
// Use embedded filesystem in production
fmt.Println("Using embedded static files")
logger.Debug().Msg("Using embedded static files")
subFS, err := fs.Sub(embeddedStatic, "static")
if err != nil {
return nil, errors.Wrap(err, "fs.Sub")
@@ -47,7 +48,7 @@ func getStaticFiles() (http.FileSystem, error) {
var maint uint32 // atomic: 1 if in maintenance mode
func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) {
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
close(ch)
@@ -58,14 +59,16 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
fmt.Println("Signal received: Starting maintenance")
fmt.Println("Attempting to acquire database lock")
log := logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Signal received: Starting maintenance")
log.Info().Msg("Attempting to acquire database lock")
conn.Pause()
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
fmt.Println("Signal received: Maintenance over")
fmt.Println("Releasing database lock")
log := logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Signal received: Maintenance over")
log.Info().Msg("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
@@ -109,13 +112,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "logging.GetLogger")
}
conn, err := db.ConnectToDatabase(config.DBName)
conn, err := db.ConnectToDatabase(config.DBName, logger)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
defer conn.Close()
staticFS, err := getStaticFiles()
staticFS, err := getStaticFiles(logger)
if err != nil {
return errors.Wrap(err, "getStaticFiles")
}
@@ -136,13 +139,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
}
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer)
handleMaintSignals(conn, httpServer, logger)
// Runs the http server
go func() {
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err)
logger.Error().Err(err).Msg("Error listening and serving")
}
}()
@@ -156,11 +159,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err)
logger.Error().Err(err).Msg("Error shutting down server")
}
}()
wg.Wait()
fmt.Fprintln(w, "Shutting down")
logger.Info().Msg("Shutting down")
return nil
}

View File

@@ -115,8 +115,8 @@ func Authentication(
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
// Failed to start transaction, warn the user they cant login right now
logger.Warn().Err(err).Msg("Request failed to start a transaction")
// Failed to start transaction, send 503 code to client
logger.Warn().Err(err).Msg("Skipping Auth - unable to start a transaction")
w.WriteHeader(http.StatusServiceUnavailable)
next.ServeHTTP(w, r)
return

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/contexts"
@@ -16,15 +17,15 @@ import (
)
func TestAuthenticationMiddleware(t *testing.T) {
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -38,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) {
w.Write([]byte(strconv.Itoa(user.ID)))
}
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, sconn, testHandler)
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"projectreshoot/db"
@@ -13,24 +14,26 @@ import (
)
func TestPageLoginRequired(t *testing.T) {
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
loginRequiredHandler := RequiresLogin(testHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"projectreshoot/db"
@@ -13,25 +14,27 @@ import (
)
func TestReauthRequired(t *testing.T) {
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
reauthRequiredHandler := RequiresFresh(testHandler)
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()