Added proper debug logging to safeconn methods

This commit is contained in:
2025-02-18 10:02:42 +11:00
parent b4b57c14cb
commit c2de8d254a
9 changed files with 83 additions and 45 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
query.sql query.sql
*.db *.db
.logs/ .logs/
server.log
tmp/ tmp/
projectreshoot projectreshoot
static/css/output.css static/css/output.css

View File

@@ -4,21 +4,23 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sync/atomic" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
type SafeConn struct { type SafeConn struct {
db *sql.DB db *sql.DB
readLockCount int32 readLockCount uint32
globalLockStatus int32 globalLockStatus uint32
logger *zerolog.Logger
} }
func MakeSafe(db *sql.DB) *SafeConn { func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{db: db} return &SafeConn{db: db, logger: logger}
} }
// Extends sql.Tx for use with SafeConn // Extends sql.Tx for use with SafeConn
@@ -28,27 +30,35 @@ type SafeTX struct {
} }
func (conn *SafeConn) acquireGlobalLock() bool { func (conn *SafeConn) acquireGlobalLock() bool {
if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 { if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false return false
} }
atomic.StoreInt32(&conn.globalLockStatus, 1) conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true return true
} }
func (conn *SafeConn) releaseGlobalLock() { func (conn *SafeConn) releaseGlobalLock() {
atomic.StoreInt32(&conn.globalLockStatus, 0) conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
} }
func (conn *SafeConn) acquireReadLock() bool { func (conn *SafeConn) acquireReadLock() bool {
if atomic.LoadInt32(&conn.globalLockStatus) == 1 { if conn.globalLockStatus == 1 {
return false return false
} }
atomic.AddInt32(&conn.readLockCount, 1) conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true return true
} }
func (conn *SafeConn) releaseReadLock() { func (conn *SafeConn) releaseReadLock() {
atomic.AddInt32(&conn.readLockCount, -1) conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
} }
// Starts a new transaction based on the current context. Will cancel if // Starts a new transaction based on the current context. Will cancel if
@@ -134,32 +144,40 @@ func (stx *SafeTX) Rollback() error {
func (conn *SafeConn) Pause() { func (conn *SafeConn) Pause() {
for !conn.acquireGlobalLock() { for !conn.acquireGlobalLock() {
// TODO: add a timeout? // TODO: add a timeout?
// TODO: failed to acquire lock: print info with readLockCount
// every second, or update it dynamically
} }
fmt.Println("Global database lock acquired") // force logger to log to Stdout
log := conn.logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Global database lock acquired")
} }
// Resume allows transactions to proceed. // Resume allows transactions to proceed.
func (conn *SafeConn) Resume() { func (conn *SafeConn) Resume() {
conn.releaseGlobalLock() conn.releaseGlobalLock()
fmt.Println("Global database lock released") // force logger to log to Stdout
log := conn.logger.With().Logger().Output(os.Stdout)
log.Info().Msg("Global database lock released")
} }
// Close the database connection // Close the database connection
func (conn *SafeConn) Close() error { func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock() conn.acquireGlobalLock()
defer conn.releaseGlobalLock() defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.db.Close() return conn.db.Close()
} }
// Returns a database connection handle for the DB // Returns a database connection handle for the DB
func ConnectToDatabase(dbName string) (*SafeConn, error) { func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName) file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite", file) db, err := sql.Open("sqlite", file)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sql.Open") return nil, errors.Wrap(err, "sql.Open")
} }
conn := MakeSafe(db) conn := MakeSafe(db, logger)
return conn, nil return conn, nil
} }

2
go.mod
View File

@@ -23,7 +23,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.61.13 // indirect modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

12
go.sum
View File

@@ -16,6 +16,11 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -42,8 +47,8 @@ golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtD
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -51,8 +56,9 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=

33
main.go
View File

@@ -23,20 +23,21 @@ import (
"projectreshoot/server" "projectreshoot/server"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
//go:embed static/* //go:embed static/*
var embeddedStatic embed.FS var embeddedStatic embed.FS
// Gets the static files // Gets the static files
func getStaticFiles() (http.FileSystem, error) { func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
if _, err := os.Stat("static"); err == nil { if _, err := os.Stat("static"); err == nil {
// Use actual filesystem in development // Use actual filesystem in development
fmt.Println("Using filesystem for static files") logger.Debug().Msg("Using filesystem for static files")
return http.Dir("static"), nil return http.Dir("static"), nil
} else { } else {
// Use embedded filesystem in production // Use embedded filesystem in production
fmt.Println("Using embedded static files") logger.Debug().Msg("Using embedded static files")
subFS, err := fs.Sub(embeddedStatic, "static") subFS, err := fs.Sub(embeddedStatic, "static")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fs.Sub") return nil, errors.Wrap(err, "fs.Sub")
@@ -47,7 +48,7 @@ func getStaticFiles() (http.FileSystem, error) {
var maint uint32 // atomic: 1 if in maintenance mode var maint uint32 // atomic: 1 if in maintenance mode
func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) {
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() { srv.RegisterOnShutdown(func() {
close(ch) close(ch)
@@ -58,14 +59,16 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server) {
case syscall.SIGUSR1: case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 { if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1) atomic.StoreUint32(&maint, 1)
fmt.Println("Signal received: Starting maintenance") log := logger.With().Logger().Output(os.Stdout)
fmt.Println("Attempting to acquire database lock") log.Info().Msg("Signal received: Starting maintenance")
log.Info().Msg("Attempting to acquire database lock")
conn.Pause() conn.Pause()
} }
case syscall.SIGUSR2: case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 { if atomic.LoadUint32(&maint) != 0 {
fmt.Println("Signal received: Maintenance over") log := logger.With().Logger().Output(os.Stdout)
fmt.Println("Releasing database lock") log.Info().Msg("Signal received: Maintenance over")
log.Info().Msg("Releasing database lock")
conn.Resume() conn.Resume()
atomic.StoreUint32(&maint, 0) atomic.StoreUint32(&maint, 0)
} }
@@ -109,13 +112,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "logging.GetLogger") return errors.Wrap(err, "logging.GetLogger")
} }
conn, err := db.ConnectToDatabase(config.DBName) conn, err := db.ConnectToDatabase(config.DBName, logger)
if err != nil { if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase") return errors.Wrap(err, "db.ConnectToDatabase")
} }
defer conn.Close() defer conn.Close()
staticFS, err := getStaticFiles() staticFS, err := getStaticFiles(logger)
if err != nil { if err != nil {
return errors.Wrap(err, "getStaticFiles") return errors.Wrap(err, "getStaticFiles")
} }
@@ -136,13 +139,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
} }
// Setups a channel to listen for os.Signal // Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer) handleMaintSignals(conn, httpServer, logger)
// Runs the http server // Runs the http server
go func() { go func() {
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err) logger.Error().Err(err).Msg("Error listening and serving")
} }
}() }()
@@ -156,11 +159,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel() defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil { if err := httpServer.Shutdown(shutdownCtx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err) logger.Error().Err(err).Msg("Error shutting down server")
} }
}() }()
wg.Wait() wg.Wait()
fmt.Fprintln(w, "Shutting down") logger.Info().Msg("Shutting down")
return nil return nil
} }

View File

@@ -115,8 +115,8 @@ func Authentication(
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.Begin(ctx)
if err != nil { if err != nil {
// Failed to start transaction, warn the user they cant login right now // Failed to start transaction, send 503 code to client
logger.Warn().Err(err).Msg("Request failed to start a transaction") logger.Warn().Err(err).Msg("Skipping Auth - unable to start a transaction")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv" "strconv"
"sync/atomic"
"testing" "testing"
"projectreshoot/contexts" "projectreshoot/contexts"
@@ -16,15 +17,15 @@ import (
) )
func TestAuthenticationMiddleware(t *testing.T) { func TestAuthenticationMiddleware(t *testing.T) {
logger := tests.NilLogger()
// Basic setup // Basic setup
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB()
require.NoError(t, err) require.NoError(t, err)
sconn := db.MakeSafe(conn) sconn := db.MakeSafe(conn, logger)
defer sconn.Close() defer sconn.Close()
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware // Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -38,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) {
w.Write([]byte(strconv.Itoa(user.ID))) w.Write([]byte(strconv.Itoa(user.ID)))
} }
}) })
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server // Add the middleware and create the server
authHandler := Authentication(logger, cfg, sconn, testHandler) authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
require.NoError(t, err) require.NoError(t, err)
server := httptest.NewServer(authHandler) server := httptest.NewServer(authHandler)
defer server.Close() defer server.Close()

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic"
"testing" "testing"
"projectreshoot/db" "projectreshoot/db"
@@ -13,24 +14,26 @@ import (
) )
func TestPageLoginRequired(t *testing.T) { func TestPageLoginRequired(t *testing.T) {
logger := tests.NilLogger()
// Basic setup // Basic setup
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB()
require.NoError(t, err) require.NoError(t, err)
sconn := db.MakeSafe(conn) sconn := db.MakeSafe(conn, logger)
defer sconn.Close() defer sconn.Close()
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware // Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server // Add the middleware and create the server
loginRequiredHandler := RequiresLogin(testHandler) loginRequiredHandler := RequiresLogin(testHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler) server := httptest.NewServer(authHandler)
defer server.Close() defer server.Close()

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic"
"testing" "testing"
"projectreshoot/db" "projectreshoot/db"
@@ -13,25 +14,27 @@ import (
) )
func TestReauthRequired(t *testing.T) { func TestReauthRequired(t *testing.T) {
logger := tests.NilLogger()
// Basic setup // Basic setup
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB()
require.NoError(t, err) require.NoError(t, err)
sconn := db.MakeSafe(conn) sconn := db.MakeSafe(conn, logger)
defer sconn.Close() defer sconn.Close()
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware // Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server // Add the middleware and create the server
reauthRequiredHandler := RequiresFresh(testHandler) reauthRequiredHandler := RequiresFresh(testHandler)
loginRequiredHandler := RequiresLogin(reauthRequiredHandler) loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler) server := httptest.NewServer(authHandler)
defer server.Close() defer server.Close()