Added proper debug logging to safeconn methods
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
query.sql
|
||||
*.db
|
||||
.logs/
|
||||
server.log
|
||||
tmp/
|
||||
projectreshoot
|
||||
static/css/output.css
|
||||
|
||||
@@ -4,21 +4,23 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type SafeConn struct {
|
||||
db *sql.DB
|
||||
readLockCount int32
|
||||
globalLockStatus int32
|
||||
readLockCount uint32
|
||||
globalLockStatus uint32
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func MakeSafe(db *sql.DB) *SafeConn {
|
||||
return &SafeConn{db: db}
|
||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||
return &SafeConn{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
@@ -28,27 +30,35 @@ type SafeTX struct {
|
||||
}
|
||||
|
||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||
if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 {
|
||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
atomic.StoreInt32(&conn.globalLockStatus, 1)
|
||||
conn.globalLockStatus = 1
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *SafeConn) releaseGlobalLock() {
|
||||
atomic.StoreInt32(&conn.globalLockStatus, 0)
|
||||
conn.globalLockStatus = 0
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock released")
|
||||
}
|
||||
|
||||
func (conn *SafeConn) acquireReadLock() bool {
|
||||
if atomic.LoadInt32(&conn.globalLockStatus) == 1 {
|
||||
if conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
atomic.AddInt32(&conn.readLockCount, 1)
|
||||
conn.readLockCount += 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *SafeConn) releaseReadLock() {
|
||||
atomic.AddInt32(&conn.readLockCount, -1)
|
||||
conn.readLockCount -= 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock released")
|
||||
}
|
||||
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
@@ -134,32 +144,40 @@ func (stx *SafeTX) Rollback() error {
|
||||
func (conn *SafeConn) Pause() {
|
||||
for !conn.acquireGlobalLock() {
|
||||
// TODO: add a timeout?
|
||||
// TODO: failed to acquire lock: print info with readLockCount
|
||||
// every second, or update it dynamically
|
||||
}
|
||||
fmt.Println("Global database lock acquired")
|
||||
// force logger to log to Stdout
|
||||
log := conn.logger.With().Logger().Output(os.Stdout)
|
||||
log.Info().Msg("Global database lock acquired")
|
||||
}
|
||||
|
||||
// Resume allows transactions to proceed.
|
||||
func (conn *SafeConn) Resume() {
|
||||
conn.releaseGlobalLock()
|
||||
fmt.Println("Global database lock released")
|
||||
// force logger to log to Stdout
|
||||
log := conn.logger.With().Logger().Output(os.Stdout)
|
||||
log.Info().Msg("Global database lock released")
|
||||
}
|
||||
|
||||
// Close the database connection
|
||||
func (conn *SafeConn) Close() error {
|
||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||
conn.acquireGlobalLock()
|
||||
defer conn.releaseGlobalLock()
|
||||
conn.logger.Debug().Msg("Closing database connection")
|
||||
return conn.db.Close()
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
conn := MakeSafe(db)
|
||||
conn := MakeSafe(db, logger)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -23,7 +23,9 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.61.13 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -16,6 +16,11 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
@@ -42,8 +47,8 @@ golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtD
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -51,8 +56,9 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||
|
||||
33
main.go
33
main.go
@@ -23,20 +23,21 @@ import (
|
||||
"projectreshoot/server"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
var embeddedStatic embed.FS
|
||||
|
||||
// Gets the static files
|
||||
func getStaticFiles() (http.FileSystem, error) {
|
||||
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
|
||||
if _, err := os.Stat("static"); err == nil {
|
||||
// Use actual filesystem in development
|
||||
fmt.Println("Using filesystem for static files")
|
||||
logger.Debug().Msg("Using filesystem for static files")
|
||||
return http.Dir("static"), nil
|
||||
} else {
|
||||
// Use embedded filesystem in production
|
||||
fmt.Println("Using embedded static files")
|
||||
logger.Debug().Msg("Using embedded static files")
|
||||
subFS, err := fs.Sub(embeddedStatic, "static")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fs.Sub")
|
||||
@@ -47,7 +48,7 @@ func getStaticFiles() (http.FileSystem, error) {
|
||||
|
||||
var maint uint32 // atomic: 1 if in maintenance mode
|
||||
|
||||
func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
|
||||
func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) {
|
||||
ch := make(chan os.Signal, 1)
|
||||
srv.RegisterOnShutdown(func() {
|
||||
close(ch)
|
||||
@@ -58,14 +59,16 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
|
||||
case syscall.SIGUSR1:
|
||||
if atomic.LoadUint32(&maint) != 1 {
|
||||
atomic.StoreUint32(&maint, 1)
|
||||
fmt.Println("Signal received: Starting maintenance")
|
||||
fmt.Println("Attempting to acquire database lock")
|
||||
log := logger.With().Logger().Output(os.Stdout)
|
||||
log.Info().Msg("Signal received: Starting maintenance")
|
||||
log.Info().Msg("Attempting to acquire database lock")
|
||||
conn.Pause()
|
||||
}
|
||||
case syscall.SIGUSR2:
|
||||
if atomic.LoadUint32(&maint) != 0 {
|
||||
fmt.Println("Signal received: Maintenance over")
|
||||
fmt.Println("Releasing database lock")
|
||||
log := logger.With().Logger().Output(os.Stdout)
|
||||
log.Info().Msg("Signal received: Maintenance over")
|
||||
log.Info().Msg("Releasing database lock")
|
||||
conn.Resume()
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
}
|
||||
@@ -109,13 +112,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "logging.GetLogger")
|
||||
}
|
||||
|
||||
conn, err := db.ConnectToDatabase(config.DBName)
|
||||
conn, err := db.ConnectToDatabase(config.DBName, logger)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
staticFS, err := getStaticFiles()
|
||||
staticFS, err := getStaticFiles(logger)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
@@ -136,13 +139,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
}
|
||||
|
||||
// Setups a channel to listen for os.Signal
|
||||
handleMaintSignals(conn, httpServer)
|
||||
handleMaintSignals(conn, httpServer, logger)
|
||||
|
||||
// Runs the http server
|
||||
go func() {
|
||||
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
|
||||
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err)
|
||||
logger.Error().Err(err).Msg("Error listening and serving")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -156,11 +159,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err)
|
||||
logger.Error().Err(err).Msg("Error shutting down server")
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
fmt.Fprintln(w, "Shutting down")
|
||||
logger.Info().Msg("Shutting down")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -115,8 +115,8 @@ func Authentication(
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
// Failed to start transaction, warn the user they cant login right now
|
||||
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
||||
// Failed to start transaction, send 503 code to client
|
||||
logger.Warn().Err(err).Msg("Skipping Auth - unable to start a transaction")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
@@ -16,15 +17,15 @@ import (
|
||||
)
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -38,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) {
|
||||
w.Write([]byte(strconv.Itoa(user.ID)))
|
||||
}
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
authHandler := Authentication(logger, cfg, sconn, testHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
@@ -13,24 +14,26 @@ import (
|
||||
)
|
||||
|
||||
func TestPageLoginRequired(t *testing.T) {
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
loginRequiredHandler := RequiresLogin(testHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
@@ -13,25 +14,27 @@ import (
|
||||
)
|
||||
|
||||
func TestReauthRequired(t *testing.T) {
|
||||
logger := tests.NilLogger()
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
reauthRequiredHandler := RequiresFresh(testHandler)
|
||||
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user