package hws

import (
	"errors"
	"net/http"
)

// Middleware wraps an http.Handler and returns the wrapped handler.
type Middleware func(h http.Handler) http.Handler

// MiddlewareFunc is the simplified middleware signature accepted by
// Server.NewMiddleware. It receives the response writer and the incoming
// request, and returns the request to pass down the chain together with an
// optional *HWSError describing a failure.
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)

// AddMiddleware registers the provided middleware on the server's handler.
//
// User middleware runs in the order it is provided: middleware[0] sees the
// request first. The resulting request path through the wrapped handler is:
//
//	startTimer -> addgzip (only if server.GZIP) -> middleware[0] ... middleware[n-1] -> logging -> routes
//
// AddRoutes must have been called first, and AddMiddleware may only be called
// once. An error is returned if either requirement is violated, or if any of
// the provided middleware is nil. In every error case the server's handler is
// left unchanged.
func (server *Server) AddMiddleware(middleware ...Middleware) error {
	if !server.routes {
		return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
	}
	if server.middleware {
		return errors.New("Server.AddMiddleware already called")
	}
	// Validate everything before touching the handler so a bad argument can't
	// leave the server half-wrapped (calling a nil Middleware would panic).
	for _, mw := range middleware {
		if mw == nil {
			return errors.New("Server.AddMiddleware: nil middleware provided")
		}
	}

	// Handlers are built inside-out: each wrap below sits outside the previous
	// one, so later wraps run earlier on an incoming request.

	// Logging is wrapped first, directly around the routes, making it the
	// innermost layer.
	server.server.Handler = logging(server.server.Handler, server.logger)

	// Wrap user middleware in reverse so middleware[0] ends up outermost of the
	// user-supplied layers and therefore runs first among them.
	for i := len(middleware) - 1; i >= 0; i-- {
		server.server.Handler = middleware[i](server.server.Handler)
	}

	if server.GZIP {
		server.server.Handler = addgzip(server.server.Handler)
	}

	// The timer is wrapped last, making it the outermost handler.
	server.server.Handler = startTimer(server.server.Handler)

	server.middleware = true
	return nil
}

// NewMiddleware adapts a MiddlewareFunc into a Middleware for this server.
//
// For each request, middlewareFunc is called with the response writer and the
// request, then:
//   - If it returns a nil *HWSError, the next handler is called with the
//     returned request, or with the original request if the returned one is nil.
//   - If it returns a non-nil *HWSError, server.ThrowError is called with it.
//     When HWSError.RenderErrorPage is true the chain is terminated there;
//     otherwise the next handler is called with the original request.
func (server *Server) NewMiddleware(
	middlewareFunc MiddlewareFunc,
) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			newReq, herr := middlewareFunc(w, r)
			if herr != nil {
				server.ThrowError(w, r, *herr)
				if herr.RenderErrorPage {
					return
				}
				// Non-fatal error: continue with the original request; any
				// request returned alongside the error is discarded.
				next.ServeHTTP(w, r)
				return
			}
			if newReq == nil {
				// A MiddlewareFunc returning (nil, nil) would otherwise hand a
				// nil *http.Request to the next handler and panic.
				newReq = r
			}
			next.ServeHTTP(w, newReq)
		})
	}
}