package hws

import (
	"errors"
	"net/http"
)

type (
	// Middleware wraps an http.Handler and returns the wrapping handler.
	Middleware func(h http.Handler) http.Handler

	// MiddlewareFunc inspects a request and returns the request to pass on
	// down the chain, plus an optional *HWSError. See Server.NewMiddleware
	// for how the returned values are handled.
	MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
)

// AddMiddleware registers all the middleware on the server's handler.
//
// Middleware runs in the order it is provided: middleware[0] sees the
// request first. At request time, the chain executes in this order:
//
//	startTimer -> gzip (only if s.GZIP) -> middleware[0..n-1] -> logging -> handler
//
// Server.AddRoutes must be called before AddMiddleware, and AddMiddleware can
// only be called once. An error is returned if either precondition is not
// met, if any provided middleware is nil, or if a middleware returns a nil
// handler. On error, the server's handler is left unchanged.
func (s *Server) AddMiddleware(middleware ...Middleware) error {
	if !s.routes {
		return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
	}
	if s.middleware {
		return errors.New("Server.AddMiddleware already called")
	}
	// Validate up front so a bad argument never leaves a half-built chain.
	for _, m := range middleware {
		if m == nil {
			return errors.New("Server.AddMiddleware: nil middleware provided")
		}
	}

	// Build the chain from the inside out in a local variable. The first
	// wrapper applied is the innermost, so it executes last at request time.
	h := logging(s.server.Handler, s.logger)

	// Wrap user middleware in reverse so that middleware[0] ends up outermost
	// among them and therefore runs first.
	for i := len(middleware) - 1; i >= 0; i-- {
		h = middleware[i](h)
		if h == nil {
			return errors.New("Server.AddMiddleware: middleware returned a nil handler")
		}
	}

	if s.GZIP {
		h = addgzip(h)
	}

	// Applied last, so it is outermost and runs before everything else.
	h = startTimer(h)

	s.server.Handler = h
	s.middleware = true
	return nil
}

// NewMiddleware returns a new Middleware for the server built from
// middlewareFunc.
//
// For each request, middlewareFunc is called with the response writer and
// incoming request:
//   - If it returns a non-nil *HWSError, s.ThrowError is called with it. If
//     HWSError.RenderErrorPage is true, the chain stops there. Otherwise the
//     next handler is called with the original, incoming request.
//   - If it returns no error, the next handler is called with the returned
//     request. If the returned request is nil, the incoming request is used
//     instead, so a nil request is never passed down the chain.
func (s *Server) NewMiddleware(middlewareFunc MiddlewareFunc) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			newReq, herr := middlewareFunc(w, r)
			if herr != nil {
				s.ThrowError(w, r, *herr)
				if herr.RenderErrorPage {
					return
				}
				// Non-terminating error: continue with the incoming request,
				// since the returned request may not be meaningful when an
				// error is reported.
				next.ServeHTTP(w, r)
				return
			}
			if newReq == nil {
				// Guard against (nil, nil): downstream handlers expect a
				// non-nil *http.Request.
				newReq = r
			}
			next.ServeHTTP(w, newReq)
		})
	}
}