From 9ea58b096112754eba800bbf7fbf561030aaae74 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 23:31:09 +1100 Subject: [PATCH] Made auth middleware turn skip timeout if maintenance mode is on --- handlers/withtransaction.go | 5 +--- main.go | 2 +- middleware/authentication.go | 56 ++++++++++++++++++++++-------------- server/server.go | 3 +- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go index db963fc..a097719 100644 --- a/handlers/withtransaction.go +++ b/handlers/withtransaction.go @@ -11,7 +11,6 @@ import ( "github.com/rs/zerolog" ) -// A helper function to create a transaction with a cancellable context. func WithTransaction( w http.ResponseWriter, r *http.Request, @@ -24,8 +23,7 @@ func WithTransaction( r *http.Request, ), ) { - // Create a cancellable context from the request context - ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) defer cancel() // Start the transaction @@ -41,6 +39,5 @@ func WithTransaction( return } - // Pass the context and transaction to the handler handler(ctx, tx, w, r) } diff --git a/main.go b/main.go index 03406c6..64b8af2 100644 --- a/main.go +++ b/main.go @@ -120,7 +120,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, conn, &staticFS) + srv := server.NewServer(config, logger, conn, &staticFS, &maint) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, diff --git a/middleware/authentication.go b/middleware/authentication.go index e444da8..354a936 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -3,13 +3,13 @@ package middleware import ( "context" "net/http" + "sync/atomic" "time" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" "projectreshoot/db" - "projectreshoot/handlers" "projectreshoot/jwt" "github.com/pkg/errors" @@ -98,6 +98,7 @@ func Authentication( config *config.Config, conn *db.SafeConn, next http.Handler, + maint *uint32, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/static/css/output.css" || @@ -105,26 +106,37 @@ func Authentication( next.ServeHTTP(w, r) return } - handlers.WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - user, err := getAuthenticatedUser(config, ctx, tx, w, r) - if err != nil { - tx.Rollback() - // User auth failed, delete the cookies to avoid repeat requests - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - logger.Debug(). - Str("remote_addr", r.RemoteAddr). - Err(err). - Msg("Failed to authenticate user") - next.ServeHTTP(w, r) - return - } - tx.Commit() - uctx := contexts.SetUser(r.Context(), user) - newReq := r.WithContext(uctx) - next.ServeHTTP(w, newReq) - }, - ) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + if atomic.LoadUint32(maint) == 1 { + cancel() + } + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + // Failed to start transaction, warn the user they cant login right now + logger.Warn().Err(err).Msg("Request failed to start a transaction") + w.WriteHeader(http.StatusServiceUnavailable) + next.ServeHTTP(w, r) + return + } + user, err := getAuthenticatedUser(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + // User auth failed, delete the cookies to avoid repeat requests + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + logger.Debug(). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("Failed to authenticate user") + next.ServeHTTP(w, r) + return + } + tx.Commit() + uctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(uctx) + next.ServeHTTP(w, newReq) }) } diff --git a/server/server.go b/server/server.go index fa75b0a..8f189ad 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ func NewServer( logger *zerolog.Logger, conn *db.SafeConn, staticFS *http.FileSystem, + maint *uint32, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -29,7 +30,7 @@ func NewServer( // Add middleware here, must be added in reverse order of execution // i.e. First in list will get executed last during the request handling handler = middleware.Logging(logger, handler) - handler = middleware.Authentication(logger, config, conn, handler) + handler = middleware.Authentication(logger, config, conn, handler, maint) // Gzip handler = middleware.Gzip(handler, config.GZIP)