Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers and using standard library *sql.DB and *sql.Tx types directly. Changes: - Removed DBConnection and DBTransaction interfaces from database.go - Removed NewDBConnection() wrapper function - Updated TokenGenerator to use *sql.DB instead of DBConnection - Updated all validation and revocation methods to accept *sql.Tx - Updated TableManager to work with *sql.DB directly - Updated all tests to use db.Begin() instead of custom wrappers - Fixed GeneratorConfig.DB field (was DBConn) - Updated documentation in doc.go with correct API usage Benefits: - Simpler API with fewer abstractions - Works directly with database/sql standard library - Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB) - Easier to understand and maintain - No unnecessary wrapper layers Breaking changes: - GeneratorConfig.DBConn renamed to GeneratorConfig.DB - Removed NewDBConnection() function - pass *sql.DB directly - ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction - Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
153
hws/server.go
153
hws/server.go
@@ -3,48 +3,98 @@ package hws
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/validation"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
GZIP bool
|
||||
server *http.Server
|
||||
logger *logger
|
||||
routes bool
|
||||
middleware bool
|
||||
gzip bool
|
||||
errorPage ErrorPage
|
||||
errorPage ErrorPageFunc
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
// NewServer returns a new hws.Server with the specified parameters.
|
||||
// The timeout options are specified in seconds
|
||||
func NewServer(
|
||||
host string,
|
||||
port string,
|
||||
readHeaderTimeout time.Duration,
|
||||
writeTimeout time.Duration,
|
||||
idleTimeout time.Duration,
|
||||
gzip bool,
|
||||
) (*Server, error) {
|
||||
// TODO: test that host and port are valid values
|
||||
httpServer := &http.Server{
|
||||
Addr: net.JoinHostPort(host, port),
|
||||
ReadHeaderTimeout: readHeaderTimeout * time.Second,
|
||||
WriteTimeout: writeTimeout * time.Second,
|
||||
IdleTimeout: idleTimeout * time.Second,
|
||||
// Ready returns a channel that is closed when the server is started
|
||||
func (server *Server) Ready() <-chan struct{} {
|
||||
return server.ready
|
||||
}
|
||||
|
||||
// IsReady checks if the server is running
|
||||
func (server *Server) IsReady() bool {
|
||||
select {
|
||||
case <-server.ready:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
func (server *Server) Addr() string {
|
||||
return server.server.Addr
|
||||
}
|
||||
|
||||
// Handler returns the server's HTTP handler for testing purposes
|
||||
func (server *Server) Handler() http.Handler {
|
||||
return server.server.Handler
|
||||
}
|
||||
|
||||
// NewServer returns a new hws.Server with the specified configuration.
|
||||
func NewServer(config *Config) (*Server, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("Config cannot be nil")
|
||||
}
|
||||
|
||||
// Apply defaults for undefined fields
|
||||
if config.Host == "" {
|
||||
config.Host = "127.0.0.1"
|
||||
}
|
||||
if config.Port == 0 {
|
||||
config.Port = 3000
|
||||
}
|
||||
if config.ReadHeaderTimeout == 0 {
|
||||
config.ReadHeaderTimeout = 2 * time.Second
|
||||
}
|
||||
if config.WriteTimeout == 0 {
|
||||
config.WriteTimeout = 10 * time.Second
|
||||
}
|
||||
if config.IdleTimeout == 0 {
|
||||
config.IdleTimeout = 120 * time.Second
|
||||
}
|
||||
|
||||
valid := isValidHostname(config.Host)
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%v", config.Host, config.Port),
|
||||
ReadHeaderTimeout: config.ReadHeaderTimeout,
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
server: httpServer,
|
||||
routes: false,
|
||||
gzip: gzip,
|
||||
GZIP: config.GZIP,
|
||||
ready: make(chan struct{}),
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (server *Server) Start() error {
|
||||
func (server *Server) Start(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
if !server.routes {
|
||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||
}
|
||||
@@ -65,20 +115,67 @@ func (server *Server) Start() error {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||
} else {
|
||||
server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error")
|
||||
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
server.waitUntilReady(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) Shutdown(ctx context.Context) {
|
||||
if err := server.server.Shutdown(ctx); err != nil {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error())
|
||||
} else {
|
||||
server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server")
|
||||
func (server *Server) Shutdown(ctx context.Context) error {
|
||||
if !server.IsReady() {
|
||||
return errors.New("Server isn't running")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
err := server.server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
||||
}
|
||||
server.ready = make(chan struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidHostname(host string) bool {
|
||||
// Validate as IP or hostname
|
||||
if errs := validation.IsDNS1123Subdomain(host); len(errs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check IPv4 / IPv6
|
||||
if errs := validation.IsValidIP(nil, host); len(errs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (server *Server) waitUntilReady(ctx context.Context) error {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
closeOnce := sync.Once{}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
case <-ticker.C:
|
||||
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
||||
if err != nil {
|
||||
continue // not accepting yet
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
closeOnce.Do(func() { close(server.ready) })
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user