Simplified the database layer by removing custom interface wrappers and using standard library *sql.DB and *sql.Tx types directly.

Changes:
- Removed DBConnection and DBTransaction interfaces from database.go
- Removed NewDBConnection() wrapper function
- Updated TokenGenerator to use *sql.DB instead of DBConnection
- Updated all validation and revocation methods to accept *sql.Tx
- Updated TableManager to work with *sql.DB directly
- Updated all tests to use db.Begin() instead of custom wrappers
- Fixed GeneratorConfig.DB field (was DBConn)
- Updated documentation in doc.go with correct API usage

Benefits:
- Simpler API with fewer abstractions
- Works directly with database/sql standard library
- Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB)
- Easier to understand and maintain
- No unnecessary wrapper layers

Breaking changes:
- GeneratorConfig.DBConn renamed to GeneratorConfig.DB
- Removed NewDBConnection() function - pass *sql.DB directly
- ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction
- Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
182 lines
3.9 KiB
Go
182 lines
3.9 KiB
Go
package hws
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/pkg/errors"
	"k8s.io/apimachinery/pkg/util/validation"
)
|
|
|
|
type Server struct {
|
|
GZIP bool
|
|
server *http.Server
|
|
logger *logger
|
|
routes bool
|
|
middleware bool
|
|
errorPage ErrorPageFunc
|
|
ready chan struct{}
|
|
}
|
|
|
|
// Ready returns a channel that is closed when the server is started
|
|
func (server *Server) Ready() <-chan struct{} {
|
|
return server.ready
|
|
}
|
|
|
|
// IsReady checks if the server is running
|
|
func (server *Server) IsReady() bool {
|
|
select {
|
|
case <-server.ready:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Addr returns the server's network address
|
|
func (server *Server) Addr() string {
|
|
return server.server.Addr
|
|
}
|
|
|
|
// Handler exposes the underlying http.Server's handler so tests can drive it
// directly. NewServer does not set a handler, so this is nil until one is
// installed (presumably by AddRoutes/AddMiddleware — verify).
func (server *Server) Handler() http.Handler {
	httpServer := server.server
	return httpServer.Handler
}
|
|
|
|
// NewServer returns a new hws.Server with the specified configuration.
//
// Zero-valued fields of config are filled in place with defaults, so the
// caller's Config is modified:
//   - Host: "127.0.0.1"
//   - Port: 3000
//   - ReadHeaderTimeout: 2s
//   - WriteTimeout: 10s
//   - IdleTimeout: 120s
//
// An error is returned when config is nil, or when Host is neither a valid
// DNS-1123 subdomain nor an IP address. The returned server has no routes
// registered yet (routes is false), so Start refuses to run until they are.
func NewServer(config *Config) (*Server, error) {
	if config == nil {
		return nil, errors.New("Config cannot be nil")
	}

	// Apply defaults for undefined fields
	if config.Host == "" {
		config.Host = "127.0.0.1"
	}
	if config.Port == 0 {
		config.Port = 3000
	}
	if config.ReadHeaderTimeout == 0 {
		config.ReadHeaderTimeout = 2 * time.Second
	}
	if config.WriteTimeout == 0 {
		config.WriteTimeout = 10 * time.Second
	}
	if config.IdleTimeout == 0 {
		config.IdleTimeout = 120 * time.Second
	}

	if !isValidHostname(config.Host) {
		return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
	}

	httpServer := &http.Server{
		// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:3000"). Plain
		// "host:port" formatting would yield "::1:3000", which net.Listen
		// rejects even though isValidHostname accepts IPv6 hosts.
		Addr:              net.JoinHostPort(config.Host, fmt.Sprint(config.Port)),
		ReadHeaderTimeout: config.ReadHeaderTimeout,
		WriteTimeout:      config.WriteTimeout,
		IdleTimeout:       config.IdleTimeout,
	}

	server := &Server{
		server: httpServer,
		routes: false,
		GZIP:   config.GZIP,
		ready:  make(chan struct{}),
	}
	return server, nil
}
|
|
|
|
// Start launches the HTTP server in a background goroutine and blocks until
// the server answers its health check, ctx is done, or the listener exits.
//
// If middleware has not been added yet, Start adds it via AddMiddleware.
// It returns an error when ctx is nil, when routes have not been added, when
// AddMiddleware fails, when ListenAndServe exits before the server became
// ready (e.g. the address is already in use), or when ctx is done before the
// server became ready. ListenAndServe failures that happen after Start has
// returned (other than http.ErrServerClosed) are only logged.
func (server *Server) Start(ctx context.Context) error {
	if ctx == nil {
		return errors.New("Context cannot be nil")
	}
	if !server.routes {
		return errors.New("Server.AddRoutes must be run before starting the server")
	}
	if !server.middleware {
		err := server.AddMiddleware()
		if err != nil {
			return errors.Wrap(err, "server.AddMiddleware")
		}
	}

	// readyCtx lets the serving goroutine abort the readiness wait as soon as
	// ListenAndServe returns, instead of polling until the caller's ctx ends.
	readyCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Buffered so the goroutine never blocks once Start has returned.
	serveErr := make(chan error, 1)

	go func() {
		if server.logger == nil {
			fmt.Printf("Listening for requests on %s\n", server.server.Addr)
		} else {
			server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
		}
		err := server.server.ListenAndServe()
		// Publish the result before cancelling so Start can report it.
		serveErr <- err
		cancel()
		if err != nil && err != http.ErrServerClosed {
			if server.logger == nil {
				fmt.Printf("Server encountered a fatal error: %s\n", err.Error())
			} else {
				server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
			}
		}
	}()

	if err := server.waitUntilReady(readyCtx); err != nil {
		select {
		case listenErr := <-serveErr:
			return errors.Wrap(listenErr, "server.ListenAndServe")
		default:
			return errors.Wrap(err, "server.waitUntilReady")
		}
	}
	return nil
}
|
|
|
|
// Shutdown gracefully stops a running server via http.Server.Shutdown, which
// waits for active connections to finish until ctx is done.
//
// It returns an error if the server is not ready, if ctx is nil, or if the
// underlying shutdown fails (including ctx expiring first). On success the
// ready channel is replaced with a fresh, open one, so IsReady reports false.
// Channels handed out earlier by Ready remain closed.
//
// NOTE(review): net/http documents that an http.Server may not be reused
// after Shutdown, so a later Start on this Server is not expected to serve
// again — confirm whether restart is intended.
func (server *Server) Shutdown(ctx context.Context) error {
	switch {
	case !server.IsReady():
		return errors.New("Server isn't running")
	case ctx == nil:
		return errors.New("Context cannot be nil")
	}

	if shutdownErr := server.server.Shutdown(ctx); shutdownErr != nil {
		return errors.Wrap(shutdownErr, "Failed to shutdown the server gracefully")
	}

	server.ready = make(chan struct{})
	return nil
}
|
|
|
|
func isValidHostname(host string) bool {
|
|
// Validate as IP or hostname
|
|
if errs := validation.IsDNS1123Subdomain(host); len(errs) == 0 {
|
|
return true
|
|
}
|
|
|
|
// Check IPv4 / IPv6
|
|
if errs := validation.IsValidIP(nil, host); len(errs) == 0 {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (server *Server) waitUntilReady(ctx context.Context) error {
|
|
ticker := time.NewTicker(50 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
closeOnce := sync.Once{}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
|
|
case <-ticker.C:
|
|
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
|
if err != nil {
|
|
continue // not accepting yet
|
|
}
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
closeOnce.Do(func() { close(server.ready) })
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|