From 1253c6499d0186034f36c1119bac0644b01cb54f Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Wed, 12 Feb 2025 12:48:10 +1100 Subject: [PATCH] Added config options for http request timeouts --- config/config.go | 6 ++++++ config/environment.go | 33 +++++++++++++++++++++++++++++++++ contexts/keys.go | 1 + contexts/request_timer.go | 21 +++++++++++++++++++++ main.go | 6 +++--- middleware/logging.go | 7 ++++++- middleware/start.go | 18 ++++++++++++++++++ server/server.go | 3 +++ 8 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 contexts/request_timer.go create mode 100644 middleware/start.go diff --git a/config/config.go b/config/config.go index 1209fb6..e7a0a8e 100644 --- a/config/config.go +++ b/config/config.go @@ -16,6 +16,9 @@ type Config struct { Port string // Port to listen on TrustedHost string // Domain/Hostname to accept as trusted SSL bool // Flag for SSL Mode + ReadHeaderTimeout int // Timeout for reading request headers in seconds + WriteTimeout int // Timeout for writing requests in seconds + IdleTimeout int // Timeout for idle connections in seconds TursoDBName string // DB Name for Turso DB/Branch TursoToken string // Bearer token for Turso DB/Branch SecretKey string // Secret key for signing tokens @@ -81,6 +84,9 @@ func GetConfig(args map[string]string) (*Config, error) { Port: port, TrustedHost: os.Getenv("TRUSTED_HOST"), SSL: GetEnvBool("SSL_MODE", false), + ReadHeaderTimeout: GetEnvInt("READ_HEADER_TIMEOUT", 2), + WriteTimeout: GetEnvInt("WRITE_TIMEOUT", 10), + IdleTimeout: GetEnvInt("IDLE_TIMEOUT", 120), TursoDBName: os.Getenv("TURSO_DB_NAME"), TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), SecretKey: os.Getenv("SECRET_KEY"), diff --git a/config/environment.go b/config/environment.go index a00cdc4..2875b7b 100644 --- a/config/environment.go +++ b/config/environment.go @@ -4,6 +4,7 @@ import ( "os" "strconv" "strings" + "time" ) // Get an environment variable, specifying a default value if its not set @@ -15,6 +16,38 @@ func GetEnvDefault(key string, defaultValue string) string { return val } +// Get an environment variable as a time.Duration, specifying a default value if its +// not set or can't be parsed properly +func GetEnvDur(key string, defaultValue time.Duration) time.Duration { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + + intVal, err := strconv.Atoi(val) + if err != nil { + return defaultValue + } + return time.Duration(intVal) + +} + +// Get an environment variable as an int, specifying a default value if its +// not set or can't be parsed properly into an int +func GetEnvInt(key string, defaultValue int) int { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + + intVal, err := strconv.Atoi(val) + if err != nil { + return defaultValue + } + return intVal + +} + // Get an environment variable as an int64, specifying a default value if its // not set or can't be parsed properly into an int64 func GetEnvInt64(key string, defaultValue int64) int64 { diff --git a/contexts/keys.go b/contexts/keys.go index e5a08df..0875c3c 100644 --- a/contexts/keys.go +++ b/contexts/keys.go @@ -8,4 +8,5 @@ func (c contextKey) String() string { var ( contextKeyAuthorizedUser = contextKey("auth-user") + contextKeyRequestTime = contextKey("req-time") ) diff --git a/contexts/request_timer.go b/contexts/request_timer.go new file mode 100644 index 0000000..d5be372 --- /dev/null +++ b/contexts/request_timer.go @@ -0,0 +1,21 @@ +package contexts + +import ( + "context" + "errors" + "time" +) + +// Set the start time of the request +func SetStart(ctx context.Context, time time.Time) context.Context { + return context.WithValue(ctx, contextKeyRequestTime, time) +} + +// Get the start time of the request +func GetStartTime(ctx context.Context) (time.Time, error) { + start, ok := ctx.Value(contextKeyRequestTime).(time.Time) + if !ok { + return time.Time{}, errors.New("Failed to get start time of request") + } + return start, nil +} diff --git a/main.go b/main.go index 2d00ad7..3e7f1b7 100644 --- a/main.go +++ b/main.go @@ -66,9 +66,9 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, - ReadHeaderTimeout: 2 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, + ReadHeaderTimeout: config.ReadHeaderTimeout * time.Second, + WriteTimeout: config.WriteTimeout * time.Second, + IdleTimeout: config.IdleTimeout * time.Second, } // Runs function for testing in dev if --test flag true diff --git a/middleware/logging.go b/middleware/logging.go index 878acf8..abae797 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "projectreshoot/contexts" "time" "github.com/rs/zerolog" @@ -22,7 +23,11 @@ func (w *wrappedWriter) WriteHeader(statusCode int) { // Middleware to add logs to console with details of the request func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() + start, err := contexts.GetStartTime(r.Context()) + if err != nil { + // Handle failure here. internal server error maybe + return + } wrapped := &wrappedWriter{ ResponseWriter: w, statusCode: http.StatusOK, diff --git a/middleware/start.go b/middleware/start.go new file mode 100644 index 0000000..fb9a66c --- /dev/null +++ b/middleware/start.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "net/http" + "projectreshoot/contexts" + "time" +) + +func StartTimer(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ctx := contexts.SetStart(r.Context(), start) + newReq := r.WithContext(ctx) + next.ServeHTTP(w, newReq) + }, + ) +} diff --git a/server/server.go b/server/server.go index 9979389..c43b225 100644 --- a/server/server.go +++ b/server/server.go @@ -32,5 +32,8 @@ func NewServer( // Serve the favicon and exluded files before any middleware is added handler = middleware.ExcludedFiles(handler) handler = middleware.Favicon(handler) + + // Start the timer for the request chain so logger can have accurate info + handler = middleware.StartTimer(handler) return handler }