package hws

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/pkg/errors"
	"k8s.io/apimachinery/pkg/util/validation"
)

// Server wraps an *http.Server and tracks whether it has started answering
// requests.
type Server struct {
	// GZIP is copied from Config.GZIP by NewServer.
	GZIP bool

	server *http.Server

	// logger is optional; when nil, Start reports to stdout via fmt.Printf.
	logger *logger

	// routes must be true before Start will run (presumably set by AddRoutes).
	routes bool

	// middleware records whether AddMiddleware has been applied; Start
	// applies it when false.
	middleware bool

	// errorPage is not used in this file. NOTE(review): presumably consumed
	// by error-page helpers elsewhere in the package.
	errorPage ErrorPageFunc

	// ready is closed once Start observes GET /healthz returning 200 OK.
	ready chan struct{}
}

// Ready returns a channel that is closed once Start has observed the server
// answering GET /healthz with 200 OK. Shutdown replaces the channel, so call
// Ready again after a restart.
func (server *Server) Ready() <-chan struct{} {
	return server.ready
}

// IsReady reports whether the current ready channel is closed, i.e. Start
// succeeded and Shutdown has not completed since.
func (server *Server) IsReady() bool {
	select {
	case <-server.ready:
		return true
	default:
		return false
	}
}

// Addr returns the server's network address in host:port form.
func (server *Server) Addr() string {
	return server.server.Addr
}

// Handler returns the server's HTTP handler for testing purposes.
func (server *Server) Handler() http.Handler {
	return server.server.Handler
}

// NewServer returns a new hws.Server with the specified configuration.
//
// Zero-valued fields of *config are replaced with defaults, and the defaults
// are written back into the caller's Config:
//   - Host: "127.0.0.1"
//   - Port: 3000
//   - ReadHeaderTimeout: 2s
//   - WriteTimeout: 10s
//   - IdleTimeout: 120s
//
// It returns an error if config is nil, or if config.Host is neither a valid
// RFC 1123 subdomain name nor an IP address.
func NewServer(config *Config) (*Server, error) {
	if config == nil {
		return nil, errors.New("Config cannot be nil")
	}

	// Apply defaults for undefined fields.
	if config.Host == "" {
		config.Host = "127.0.0.1"
	}
	if config.Port == 0 {
		config.Port = 3000
	}
	if config.ReadHeaderTimeout == 0 {
		config.ReadHeaderTimeout = 2 * time.Second
	}
	if config.WriteTimeout == 0 {
		config.WriteTimeout = 10 * time.Second
	}
	if config.IdleTimeout == 0 {
		config.IdleTimeout = 120 * time.Second
	}

	if !isValidHostname(config.Host) {
		return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
	}

	httpServer := &http.Server{
		// JoinHostPort brackets IPv6 literals ("[::1]:3000"); plain
		// "host:port" formatting produced "::1:3000", which the listener
		// cannot parse. IPv4 addresses and hostnames come out unchanged.
		Addr:              net.JoinHostPort(config.Host, fmt.Sprint(config.Port)),
		ReadHeaderTimeout: config.ReadHeaderTimeout,
		WriteTimeout:      config.WriteTimeout,
		IdleTimeout:       config.IdleTimeout,
	}

	server := &Server{
		server: httpServer,
		routes: false,
		GZIP:   config.GZIP,
		ready:  make(chan struct{}),
	}
	return server, nil
}

// Start launches the HTTP server in a background goroutine. It blocks until
// the server answers GET /healthz with 200 OK (see Ready), ctx is done, or the
// listener exits.
//
// AddRoutes must have been called first. AddMiddleware is applied here if it
// has not been already. If the server does not become ready, Start closes the
// underlying http.Server and returns an error. This way a caller is never left
// with a running listener that Shutdown refuses to stop.
func (server *Server) Start(ctx context.Context) error {
	if ctx == nil {
		return errors.New("Context cannot be nil")
	}
	// A second Start while running would close the ready channel twice.
	if server.IsReady() {
		return errors.New("Server is already running")
	}
	if !server.routes {
		return errors.New("Server.AddRoutes must be run before starting the server")
	}
	if !server.middleware {
		if err := server.AddMiddleware(); err != nil {
			return errors.Wrap(err, "server.AddMiddleware")
		}
	}

	// waitCtx is cancelled as soon as the serve goroutine exits. A bind
	// failure (e.g. port already in use) therefore ends the readiness wait
	// immediately, instead of polling until ctx expires (or forever, if ctx
	// has no deadline).
	waitCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Buffered so the goroutine never blocks, even after Start has returned.
	serveErr := make(chan error, 1)

	go func() {
		if server.logger == nil {
			fmt.Printf("Listening for requests on %s\n", server.server.Addr)
		} else {
			server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
		}
		err := server.server.ListenAndServe()
		if err != nil && err != http.ErrServerClosed {
			if server.logger == nil {
				fmt.Printf("Server encountered a fatal error: %s\n", err.Error())
			} else {
				server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
			}
		}
		// Send before cancelling so Start can see why the wait was aborted.
		serveErr <- err
		cancel()
	}()

	if err := server.waitUntilReady(waitCtx); err != nil {
		select {
		case sErr := <-serveErr:
			return errors.Wrap(sErr, "server stopped before becoming ready")
		default:
		}
		// Best effort: the readiness failure is the error worth reporting.
		_ = server.server.Close()
		return errors.Wrap(err, "server did not become ready")
	}
	return nil
}

// Shutdown gracefully stops a running server via http.Server.Shutdown,
// waiting for active connections until ctx is done. On success it replaces the
// ready channel so IsReady reports false.
//
// Per net/http, an http.Server may not be reused after Shutdown, so a later
// Start on the same Server will not serve again.
func (server *Server) Shutdown(ctx context.Context) error {
	if !server.IsReady() {
		return errors.New("Server isn't running")
	}
	if ctx == nil {
		return errors.New("Context cannot be nil")
	}
	err := server.server.Shutdown(ctx)
	if err != nil {
		return errors.Wrap(err, "Failed to shutdown the server gracefully")
	}
	server.ready = make(chan struct{})
	return nil
}

// isValidHostname reports whether host is a valid RFC 1123 subdomain name
// (per validation.IsDNS1123Subdomain) or a valid IPv4/IPv6 address (per
// validation.IsValidIP).
func isValidHostname(host string) bool {
	// Validate as a hostname.
	if errs := validation.IsDNS1123Subdomain(host); len(errs) == 0 {
		return true
	}
	// Check IPv4 / IPv6.
	if errs := validation.IsValidIP(nil, host); len(errs) == 0 {
		return true
	}
	return false
}

// healthCheckURL returns the URL waitUntilReady polls for the listen address
// addr. A listener on an unspecified address (0.0.0.0 or ::) accepts
// connections on every interface. Dialing that address is not portable (Windows
// rejects it), so the probe targets the matching loopback address instead.
func healthCheckURL(addr string) string {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return "http://" + addr + "/healthz"
	}
	if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
		if ip.To4() != nil {
			host = "127.0.0.1"
		} else {
			host = "::1"
		}
	}
	return "http://" + net.JoinHostPort(host, port) + "/healthz"
}

// waitUntilReady polls GET /healthz on the server every 50ms. When the
// response is 200 OK it closes server.ready and returns nil. If ctx is done
// first, it returns ctx.Err().
func (server *Server) waitUntilReady(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthCheckURL(server.server.Addr), nil)
	if err != nil {
		return errors.Wrap(err, "building health check request")
	}
	// The per-attempt timeout stops one hung /healthz request from stalling
	// polling when ctx has no deadline. The request is bound to ctx, so
	// cancellation also aborts an in-flight probe.
	client := &http.Client{Timeout: time.Second}

	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()
	closeOnce := sync.Once{}
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			resp, err := client.Do(req)
			if err != nil {
				continue // not accepting yet
			}
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				closeOnce.Do(func() { close(server.ready) })
				return nil
			}
		}
	}
}