diff --git a/.gitignore b/.gitignore index 37839ec..5b22f92 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ query.sql *.db .logs/ +server.log tmp/ projectreshoot static/css/output.css diff --git a/db/connection.go b/db/connection.go index 199417f..2a88505 100644 --- a/db/connection.go +++ b/db/connection.go @@ -4,21 +4,23 @@ import ( "context" "database/sql" "fmt" - "sync/atomic" + "os" "github.com/pkg/errors" + "github.com/rs/zerolog" _ "modernc.org/sqlite" ) type SafeConn struct { db *sql.DB - readLockCount int32 - globalLockStatus int32 + readLockCount uint32 + globalLockStatus uint32 + logger *zerolog.Logger } -func MakeSafe(db *sql.DB) *SafeConn { - return &SafeConn{db: db} +func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { + return &SafeConn{db: db, logger: logger} } // Extends sql.Tx for use with SafeConn @@ -28,27 +30,35 @@ type SafeTX struct { } func (conn *SafeConn) acquireGlobalLock() bool { - if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 { + if conn.readLockCount > 0 || conn.globalLockStatus == 1 { return false } - atomic.StoreInt32(&conn.globalLockStatus, 1) + conn.globalLockStatus = 1 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock acquired") return true } func (conn *SafeConn) releaseGlobalLock() { - atomic.StoreInt32(&conn.globalLockStatus, 0) + conn.globalLockStatus = 0 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock released") } func (conn *SafeConn) acquireReadLock() bool { - if atomic.LoadInt32(&conn.globalLockStatus) == 1 { + if conn.globalLockStatus == 1 { return false } - atomic.AddInt32(&conn.readLockCount, 1) + conn.readLockCount += 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock acquired") return true } func (conn *SafeConn) releaseReadLock() { - atomic.AddInt32(&conn.readLockCount, -1) + conn.readLockCount -= 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock released") } // Starts a new transaction based on the current context. Will cancel if @@ -134,32 +144,40 @@ func (stx *SafeTX) Rollback() error { func (conn *SafeConn) Pause() { for !conn.acquireGlobalLock() { // TODO: add a timeout? + // TODO: failed to acquire lock: print info with readLockCount + // every second, or update it dynamically } - fmt.Println("Global database lock acquired") + // force logger to log to Stdout + log := conn.logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Global database lock acquired") } // Resume allows transactions to proceed. func (conn *SafeConn) Resume() { conn.releaseGlobalLock() - fmt.Println("Global database lock released") + // force logger to log to Stdout + log := conn.logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Global database lock released") } // Close the database connection func (conn *SafeConn) Close() error { + conn.logger.Debug().Msg("Acquiring global lock for connection close") conn.acquireGlobalLock() defer conn.releaseGlobalLock() + conn.logger.Debug().Msg("Closing database connection") return conn.db.Close() } // Returns a database connection handle for the DB -func ConnectToDatabase(dbName string) (*SafeConn, error) { +func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := MakeSafe(db) + conn := MakeSafe(db, logger) return conn, nil } diff --git a/go.mod b/go.mod index 98e11af..0852a78 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,9 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.61.13 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index baa5432..9206c36 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,11 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -42,8 +47,8 @@ golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtD golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -51,8 +56,9 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= diff --git a/main.go b/main.go index 64b8af2..baa7350 100644 --- a/main.go +++ b/main.go @@ -23,20 +23,21 @@ import ( "projectreshoot/server" "github.com/pkg/errors" + "github.com/rs/zerolog" ) //go:embed static/* var embeddedStatic embed.FS // Gets the static files -func getStaticFiles() (http.FileSystem, error) { +func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { if _, err := os.Stat("static"); err == nil { // Use actual filesystem in development - fmt.Println("Using filesystem for static files") + logger.Debug().Msg("Using filesystem for static files") return http.Dir("static"), nil } else { // Use embedded filesystem in production - fmt.Println("Using embedded static files") + logger.Debug().Msg("Using embedded static files") subFS, err := fs.Sub(embeddedStatic, "static") if err != nil { return nil, errors.Wrap(err, "fs.Sub") @@ -47,7 +48,7 @@ func getStaticFiles() (http.FileSystem, error) { var maint uint32 // atomic: 1 if in maintenance mode -func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { +func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) { ch := make(chan os.Signal, 1) srv.RegisterOnShutdown(func() { close(ch) @@ -58,14 +59,16 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { case syscall.SIGUSR1: if atomic.LoadUint32(&maint) != 1 { atomic.StoreUint32(&maint, 1) - fmt.Println("Signal received: Starting maintenance") - fmt.Println("Attempting to acquire database lock") + log := logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Signal received: Starting maintenance") + log.Info().Msg("Attempting to acquire database lock") conn.Pause() } case syscall.SIGUSR2: if atomic.LoadUint32(&maint) != 0 { - fmt.Println("Signal received: Maintenance over") - fmt.Println("Releasing database lock") + log := logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Signal received: Maintenance over") + log.Info().Msg("Releasing database lock") conn.Resume() atomic.StoreUint32(&maint, 0) } @@ -109,13 +112,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } - conn, err := db.ConnectToDatabase(config.DBName) + conn, err := db.ConnectToDatabase(config.DBName, logger) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") } defer conn.Close() - staticFS, err := getStaticFiles() + staticFS, err := getStaticFiles(logger) if err != nil { return errors.Wrap(err, "getStaticFiles") } @@ -136,13 +139,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Setups a channel to listen for os.Signal - handleMaintSignals(conn, httpServer) + handleMaintSignals(conn, httpServer, logger) // Runs the http server go func() { - fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) + logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests") if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err) + logger.Error().Err(err).Msg("Error listening and serving") } }() @@ -156,11 +159,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) defer cancel() if err := httpServer.Shutdown(shutdownCtx); err != nil { - fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err) + logger.Error().Err(err).Msg("Error shutting down server") } }() wg.Wait() - fmt.Fprintln(w, "Shutting down") + logger.Info().Msg("Shutting down") return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 354a936..92837e5 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -115,8 +115,8 @@ func Authentication( // Start the transaction tx, err := conn.Begin(ctx) if err != nil { - // Failed to start transaction, warn the user they cant login right now - logger.Warn().Err(err).Msg("Request failed to start a transaction") + // Failed to start transaction, send 503 code to client + logger.Warn().Err(err).Msg("Skipping Auth - unable to start a transaction") w.WriteHeader(http.StatusServiceUnavailable) next.ServeHTTP(w, r) return diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 95583af..6ce807c 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "sync/atomic" "testing" "projectreshoot/contexts" @@ -16,15 +17,15 @@ import ( ) func TestAuthenticationMiddleware(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) { w.Write([]byte(strconv.Itoa(user.ID))) } }) - + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server - authHandler := Authentication(logger, cfg, sconn, testHandler) + authHandler := Authentication(logger, cfg, sconn, testHandler, &maint) require.NoError(t, err) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index 80b0a15..c6efcba 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "sync/atomic" "testing" "projectreshoot/db" @@ -13,24 +14,26 @@ import ( ) func TestPageLoginRequired(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server loginRequiredHandler := RequiresLogin(testHandler) - authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 595e4e7..a1f2083 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "sync/atomic" "testing" "projectreshoot/db" @@ -13,25 +14,27 @@ import ( ) func TestReauthRequired(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server reauthRequiredHandler := RequiresFresh(testHandler) loginRequiredHandler := RequiresLogin(reauthRequiredHandler) - authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close()