diff --git a/hws/errors.go b/hws/errors.go new file mode 100644 index 0000000..0e21c7e --- /dev/null +++ b/hws/errors.go @@ -0,0 +1,39 @@ +package hws + +import "net/http" + +type HWSError struct { + statusCode int // HTTP Status code + message string // Error message + error error // Error +} + +type ErrorPage func(statusCode int, w http.ResponseWriter, r *http.Request) error + +func NewError(statusCode int, msg string, err error) *HWSError { + return &HWSError{ + statusCode: statusCode, + message: msg, + error: err, + } +} + +func (server *Server) AddErrorPage(page ErrorPage) { + server.errorPage = page +} + +func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error *HWSError) { + w.WriteHeader(error.statusCode) + server.logger.logger.Error().Err(error.error).Msg(error.message) + if server.errorPage != nil { + err := server.errorPage(error.statusCode, w, r) + if err != nil { + server.logger.logger.Error().Err(err).Msg("Failed to render error page") + } + } +} + +func (server *Server) ThrowWarn(w http.ResponseWriter, error *HWSError) { + w.WriteHeader(error.statusCode) + server.logger.logger.Warn().Err(error.error).Msg(error.message) +} diff --git a/hws/go.mod b/hws/go.mod new file mode 100644 index 0000000..def5f55 --- /dev/null +++ b/hws/go.mod @@ -0,0 +1,14 @@ +module git.haelnorr.com/h/golib/hws + +go 1.25.5 + +require ( + github.com/pkg/errors v0.9.1 + github.com/rs/zerolog v1.34.0 +) + +require ( + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + golang.org/x/sys v0.12.0 // indirect +) diff --git a/hws/go.sum b/hws/go.sum new file mode 100644 index 0000000..1f7edd4 --- /dev/null +++ b/hws/go.sum @@ -0,0 +1,16 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/hws/gzip.go b/hws/gzip.go new file mode 100644 index 0000000..07293de --- /dev/null +++ b/hws/gzip.go @@ -0,0 +1,31 @@ +package hws + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +func addgzip(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w} + next.ServeHTTP(gzw, r) + }) +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} diff --git a/hws/logger.go b/hws/logger.go new file mode 100644 index 0000000..5b90c3a --- /dev/null +++ b/hws/logger.go @@ -0,0 +1,44 @@ +package hws + +import ( + "errors" + "fmt" + "net/url" + + "github.com/rs/zerolog" +) + +type logger struct { + logger *zerolog.Logger + ignoredPaths []string +} + +// Server.AddLogger adds a logger to the server to use for request logging. +func (server *Server) AddLogger(zlogger *zerolog.Logger) error { + if zlogger == nil { + return errors.New("Unable to add logger, no logger provided") + } + server.logger = &logger{ + logger: zlogger, + } + return nil +} + +// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for. +// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL +// Useful for ignoring requests to CSS files or favicons +func (server *Server) LoggerIgnorePaths(paths ...string) error { + for _, path := range paths { + u, err := url.Parse(path) + valid := err == nil && + u.Scheme == "" && + u.Host == "" && + u.RawQuery == "" && + u.Fragment == "" + if !valid { + return fmt.Errorf("Invalid path: '%s'", path) + } + } + server.logger.ignoredPaths = paths + return nil +} diff --git a/hws/middleware.go b/hws/middleware.go new file mode 100644 index 0000000..f4d7c9b --- /dev/null +++ b/hws/middleware.go @@ -0,0 +1,51 @@ +package hws + +import ( + "errors" + "net/http" +) + +type Middleware func(h http.Handler) http.Handler +type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError) + +// Server.AddMiddleware registers all the middleware. +// Middleware will be run in the order that they are provided. +func (server *Server) AddMiddleware(middleware ...Middleware) error { + if !server.routes { + return errors.New("Server.AddRoutes must be called before Server.AddMiddleware") + } + + // RUN LOGGING MIDDLEWARE FIRST + server.server.Handler = logging(server.server.Handler, server.logger) + + // LOOP PROVIDED MIDDLEWARE IN REVERSE order + for i := len(middleware); i > 0; i-- { + server.server.Handler = middleware[i-1](server.server.Handler) + } + + // RUN GZIP + if server.gzip { + server.server.Handler = addgzip(server.server.Handler) + } + // RUN TIMER MIDDLEWARE LAST + server.server.Handler = startTimer(server.server.Handler) + + server.middleware = true + + return nil +} + +func (server *Server) NewMiddleware( + middlewareFunc MiddlewareFunc, +) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + newReq, herr := middlewareFunc(w, r) + if herr != nil { + server.ThrowError(w, r, herr) + return + } + next.ServeHTTP(w, newReq) + }) + } +} diff --git a/hws/middleware_logging.go b/hws/middleware_logging.go new file mode 100644 index 0000000..b506622 --- /dev/null +++ b/hws/middleware_logging.go @@ -0,0 +1,38 @@ +package hws + +import ( + "net/http" + "slices" + "time" +) + +// Middleware to add logs to console with details of the request +func logging(next http.Handler, logger *logger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if logger == nil { + next.ServeHTTP(w, r) + return + } + if slices.Contains(logger.ignoredPaths, r.URL.Path) { + next.ServeHTTP(w, r) + return + } + start, err := getStartTime(r.Context()) + if err != nil { + logger.logger.Error().Err(err) + return + } + wrapped := &wrappedWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } + next.ServeHTTP(wrapped, r) + logger.logger.Info(). + Int("status", wrapped.statusCode). + Str("method", r.Method). + Str("resource", r.URL.Path). + Dur("time_elapsed", time.Since(start)). + Str("remote_addr", r.Header.Get("X-Forwarded-For")). + Msg("Served") + }) +} diff --git a/hws/middleware_timer.go b/hws/middleware_timer.go new file mode 100644 index 0000000..df0e690 --- /dev/null +++ b/hws/middleware_timer.go @@ -0,0 +1,33 @@ +package hws + +import ( + "context" + "errors" + "net/http" + "time" +) + +func startTimer(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ctx := setStart(r.Context(), start) + newReq := r.WithContext(ctx) + next.ServeHTTP(w, newReq) + }, + ) +} + +// Set the start time of the request +func setStart(ctx context.Context, time time.Time) context.Context { + return context.WithValue(ctx, "hws context key request-timer", time) +} + +// Get the start time of the request +func getStartTime(ctx context.Context) (time.Time, error) { + start, ok := ctx.Value("hws context key request-timer").(time.Time) + if !ok { + return time.Time{}, errors.New("Failed to get start time of request") + } + return start, nil +} diff --git a/hws/responsewriter.go b/hws/responsewriter.go new file mode 100644 index 0000000..6f72a8b --- /dev/null +++ b/hws/responsewriter.go @@ -0,0 +1,15 @@ +package hws + +import "net/http" + +// Wraps the http.ResponseWriter, adding a statusCode field +type wrappedWriter struct { + http.ResponseWriter + statusCode int +} + +// Extends WriteHeader to the ResponseWriter to add the status code +func (w *wrappedWriter) WriteHeader(statusCode int) { + w.ResponseWriter.WriteHeader(statusCode) + w.statusCode = statusCode +} diff --git a/hws/routes.go b/hws/routes.go new file mode 100644 index 0000000..6805e0f --- /dev/null +++ b/hws/routes.go @@ -0,0 +1,62 @@ +package hws + +import ( + "errors" + "fmt" + "net/http" +) + +type Route struct { + Path string // Absolute path to the requested resource + Method Method // HTTP Method + Handler http.Handler // Handler to use for the request +} + +type Method string + +const ( + MethodGET Method = "GET" + MethodPOST Method = "POST" + MethodPUT Method = "PUT" + MethodHEAD Method = "HEAD" + MethodDELETE Method = "DELETE" + MethodCONNECT Method = "CONNECT" + MethodOPTIONS Method = "OPTIONS" + MethodTRACE Method = "TRACE" + MethodPATCH Method = "PATCH" +) + +// Server.AddRoutes registers the page handlers for the server. +// At least one route must be provided. +func (server *Server) AddRoutes(routes ...Route) error { + if len(routes) == 0 { + return errors.New("No routes provided") + } + mux := http.NewServeMux() + mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {}) + for _, route := range routes { + if !validMethod(route.Method) { + return fmt.Errorf("Invalid method %s for path %s", route.Method, route.Path) + } + if route.Handler == nil { + return fmt.Errorf("No handler provided for %s %s", route.Method, route.Path) + } + pattern := fmt.Sprintf("%s %s", route.Method, route.Path) + mux.Handle(pattern, route.Handler) + } + + server.server.Handler = mux + server.routes = true + + return nil +} + +func validMethod(m Method) bool { + switch m { + case MethodGET, MethodPOST, MethodPUT, MethodHEAD, + MethodDELETE, MethodCONNECT, MethodOPTIONS, MethodTRACE, MethodPATCH: + return true + default: + return false + } +} diff --git a/hws/safefileserver.go b/hws/safefileserver.go new file mode 100644 index 0000000..7435971 --- /dev/null +++ b/hws/safefileserver.go @@ -0,0 +1,52 @@ +package hws + +import ( + "net/http" + "os" + + "github.com/pkg/errors" +) + +// Wrapper for default FileSystem +type justFilesFilesystem struct { + fs http.FileSystem +} + +// Wrapper for default File +type neuteredReaddirFile struct { + http.File +} + +// Modifies the behavior of FileSystem.Open to return the neutered version of File +func (fs justFilesFilesystem) Open(name string) (http.File, error) { + f, err := fs.fs.Open(name) + if err != nil { + return nil, err + } + + // Check if the requested path is a directory + // and explicitly return an error to trigger a 404 + fileInfo, err := f.Stat() + if err != nil { + return nil, err + } + if fileInfo.IsDir() { + return nil, os.ErrNotExist + } + + return neuteredReaddirFile{f}, nil +} + +// Overrides the Readdir method of File to always return nil +func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { + return nil, nil +} + +func SafeFileServer(fileSystem *http.FileSystem) (http.Handler, error) { + if fileSystem == nil { + return nil, errors.New("No file system provided") + } + nfs := justFilesFilesystem{*fileSystem} + fs := http.FileServer(nfs) + return fs, nil +} diff --git a/hws/server.go b/hws/server.go new file mode 100644 index 0000000..ca08f43 --- /dev/null +++ b/hws/server.go @@ -0,0 +1,84 @@ +package hws + +import ( + "context" + "fmt" + "net" + "net/http" + "time" + + "github.com/pkg/errors" +) + +type Server struct { + server *http.Server + logger *logger + routes bool + middleware bool + gzip bool + errorPage ErrorPage +} + +// NewServer returns a new hws.Server with the specified parameters. +// The timeout options are specified in seconds +func NewServer( + host string, + port string, + readHeaderTimeout time.Duration, + writeTimeout time.Duration, + idleTimeout time.Duration, + gzip bool, +) (*Server, error) { + // TODO: test that host and port are valid values + httpServer := &http.Server{ + Addr: net.JoinHostPort(host, port), + ReadHeaderTimeout: readHeaderTimeout * time.Second, + WriteTimeout: writeTimeout * time.Second, + IdleTimeout: idleTimeout * time.Second, + } + server := &Server{ + server: httpServer, + routes: false, + gzip: gzip, + } + return server, nil +} + +func (server *Server) Start() error { + if !server.routes { + return errors.New("Server.AddRoutes must be run before starting the server") + } + if !server.middleware { + err := server.AddMiddleware() + if err != nil { + return errors.Wrap(err, "server.AddMiddleware") + } + } + + go func() { + if server.logger == nil { + fmt.Printf("Listening for requests on %s", server.server.Addr) + } else { + server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests") + } + if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if server.logger == nil { + fmt.Printf("Server encountered a fatal error: %s", err.Error()) + } else { + server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error") + } + } + }() + + return nil +} + +func (server *Server) Shutdown(ctx context.Context) { + if err := server.server.Shutdown(ctx); err != nil { + if server.logger == nil { + fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error()) + } else { + server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server") + } + } +}