Made auth middleware turn skip timeout if maintenance mode is on
This commit is contained in:
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A helper function to create a transaction with a cancellable context.
|
|
||||||
func WithTransaction(
|
func WithTransaction(
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
@@ -24,8 +23,7 @@ func WithTransaction(
|
|||||||
r *http.Request,
|
r *http.Request,
|
||||||
),
|
),
|
||||||
) {
|
) {
|
||||||
// Create a cancellable context from the request context
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
@@ -41,6 +39,5 @@ func WithTransaction(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass the context and transaction to the handler
|
|
||||||
handler(ctx, tx, w, r)
|
handler(ctx, tx, w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
2
main.go
2
main.go
@@ -120,7 +120,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
return errors.Wrap(err, "getStaticFiles")
|
return errors.Wrap(err, "getStaticFiles")
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := server.NewServer(config, logger, conn, &staticFS)
|
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
"projectreshoot/contexts"
|
"projectreshoot/contexts"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/cookies"
|
||||||
"projectreshoot/db"
|
"projectreshoot/db"
|
||||||
"projectreshoot/handlers"
|
|
||||||
"projectreshoot/jwt"
|
"projectreshoot/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -98,6 +98,7 @@ func Authentication(
|
|||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *db.SafeConn,
|
conn *db.SafeConn,
|
||||||
next http.Handler,
|
next http.Handler,
|
||||||
|
maint *uint32,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
if r.URL.Path == "/static/css/output.css" ||
|
||||||
@@ -105,26 +106,37 @@ func Authentication(
|
|||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
handlers.WithTransaction(w, r, logger, conn,
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
defer cancel()
|
||||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
if atomic.LoadUint32(maint) == 1 {
|
||||||
if err != nil {
|
cancel()
|
||||||
tx.Rollback()
|
}
|
||||||
// User auth failed, delete the cookies to avoid repeat requests
|
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
// Start the transaction
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
tx, err := conn.Begin(ctx)
|
||||||
logger.Debug().
|
if err != nil {
|
||||||
Str("remote_addr", r.RemoteAddr).
|
// Failed to start transaction, warn the user they cant login right now
|
||||||
Err(err).
|
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
||||||
Msg("Failed to authenticate user")
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
||||||
uctx := contexts.SetUser(r.Context(), user)
|
if err != nil {
|
||||||
newReq := r.WithContext(uctx)
|
tx.Rollback()
|
||||||
next.ServeHTTP(w, newReq)
|
// User auth failed, delete the cookies to avoid repeat requests
|
||||||
},
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
)
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
|
logger.Debug().
|
||||||
|
Str("remote_addr", r.RemoteAddr).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to authenticate user")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
uctx := contexts.SetUser(r.Context(), user)
|
||||||
|
newReq := r.WithContext(uctx)
|
||||||
|
next.ServeHTTP(w, newReq)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ func NewServer(
|
|||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
conn *db.SafeConn,
|
conn *db.SafeConn,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
|
maint *uint32,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
addRoutes(
|
addRoutes(
|
||||||
@@ -29,7 +30,7 @@ func NewServer(
|
|||||||
// Add middleware here, must be added in reverse order of execution
|
// Add middleware here, must be added in reverse order of execution
|
||||||
// i.e. First in list will get executed last during the request handling
|
// i.e. First in list will get executed last during the request handling
|
||||||
handler = middleware.Logging(logger, handler)
|
handler = middleware.Logging(logger, handler)
|
||||||
handler = middleware.Authentication(logger, config, conn, handler)
|
handler = middleware.Authentication(logger, config, conn, handler, maint)
|
||||||
|
|
||||||
// Gzip
|
// Gzip
|
||||||
handler = middleware.Gzip(handler, config.GZIP)
|
handler = middleware.Gzip(handler, config.GZIP)
|
||||||
|
|||||||
Reference in New Issue
Block a user