Files
golib/hws/server.go
Haelnorr 1b25e2f0a5 Refactor database interface to use *sql.DB directly
Simplified the database layer by removing custom interface wrappers
and using standard library *sql.DB and *sql.Tx types directly.

Changes:
- Removed DBConnection and DBTransaction interfaces from database.go
- Removed NewDBConnection() wrapper function
- Updated TokenGenerator to use *sql.DB instead of DBConnection
- Updated all validation and revocation methods to accept *sql.Tx
- Updated TableManager to work with *sql.DB directly
- Updated all tests to use db.Begin() instead of custom wrappers
- Fixed GeneratorConfig.DB field (was DBConn)
- Updated documentation in doc.go with correct API usage

Benefits:
- Simpler API with fewer abstractions
- Works directly with database/sql standard library
- Compatible with GORM (via gormDB.DB()) and Bun (share same *sql.DB)
- Easier to understand and maintain
- No unnecessary wrapper layers

Breaking changes:
- GeneratorConfig.DBConn renamed to GeneratorConfig.DB
- Removed NewDBConnection() function - pass *sql.DB directly
- ValidateAccess/ValidateRefresh now accept *sql.Tx instead of DBTransaction
- Token.Revoke/CheckNotRevoked now accept *sql.Tx instead of DBTransaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-11 17:39:30 +11:00

182 lines
3.9 KiB
Go

package hws
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/pkg/errors"
	"k8s.io/apimachinery/pkg/util/validation"
)
// Server wraps a net/http server together with the state hws needs to start,
// probe and stop it. Construct it with NewServer: the zero value has a nil
// underlying http.Server and a nil ready channel and is not usable.
type Server struct {
// GZIP is copied from Config.GZIP by NewServer.
// NOTE(review): presumably enables gzip response compression; the code that
// reads it is not visible in this section.
GZIP bool
// server is the underlying http.Server built by NewServer (address and timeouts).
server *http.Server
// logger, when non-nil, receives the startup and fatal-error messages emitted
// by Start; when nil, Start falls back to fmt.Printf.
logger *logger
// routes must be true before Start will run. The Start error message suggests
// Server.AddRoutes sets it — TODO confirm.
routes bool
// middleware records whether middleware has been installed; Start calls
// AddMiddleware itself when it is false.
middleware bool
// errorPage is not referenced in this section.
// NOTE(review): presumably used to render error responses — confirm against its users.
errorPage ErrorPageFunc
// ready is closed by waitUntilReady once GET /healthz answers 200 OK, and is
// replaced with a fresh, open channel by Shutdown.
ready chan struct{}
}
// Ready hands back a receive-only channel that is closed once the readiness
// probe run by Start has seen GET /healthz answer 200 OK.
//
// Shutdown swaps in a new, open channel, so a channel obtained before a
// Shutdown stays closed forever; call Ready again after restarting.
func (server *Server) Ready() <-chan struct{} {
	readyCh := server.ready
	return readyCh
}
// IsReady reports, without blocking, whether the ready channel has been
// closed — that is, whether Start observed a healthy /healthz response and
// Shutdown has not reset the channel since.
func (server *Server) IsReady() bool {
	ready := false
	select {
	case <-server.ready:
		ready = true
	default:
		// Channel still open: the server has not been confirmed ready.
	}
	return ready
}
// Addr reports the listen address configured on the underlying http.Server,
// in the host:port form assembled by NewServer.
func (server *Server) Addr() string {
	httpServer := server.server
	return httpServer.Addr
}
// Handler exposes the http.Handler installed on the underlying http.Server so
// tests can drive it directly without opening a socket. NewServer does not
// set a handler, so this is nil until one is installed (presumably when routes
// are added — confirm).
func (server *Server) Handler() http.Handler {
	httpServer := server.server
	return httpServer.Handler
}
// NewServer returns a new hws.Server with the specified configuration.
//
// Zero-valued fields of config are replaced with defaults before use. The
// defaults are written back into *config:
//   - Host:              "127.0.0.1"
//   - Port:              3000
//   - ReadHeaderTimeout: 2s
//   - WriteTimeout:      10s
//   - IdleTimeout:       120s
//
// Host must be a DNS-1123 subdomain or a literal IPv4/IPv6 address. The listen
// address is built with net.JoinHostPort, so IPv6 literals are bracketed
// (e.g. "[::1]:3000") and stay parseable by the net package.
//
// An error is returned if config is nil or Host is not valid. The returned
// Server has no HTTP handler yet; routes must be added before calling Start.
func NewServer(config *Config) (*Server, error) {
	if config == nil {
		return nil, errors.New("Config cannot be nil")
	}
	// Apply defaults for undefined fields.
	if config.Host == "" {
		config.Host = "127.0.0.1"
	}
	if config.Port == 0 {
		config.Port = 3000
	}
	if config.ReadHeaderTimeout == 0 {
		config.ReadHeaderTimeout = 2 * time.Second
	}
	if config.WriteTimeout == 0 {
		config.WriteTimeout = 10 * time.Second
	}
	if config.IdleTimeout == 0 {
		config.IdleTimeout = 120 * time.Second
	}
	if !isValidHostname(config.Host) {
		return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
	}
	httpServer := &http.Server{
		// JoinHostPort brackets IPv6 literals; a plain "%s:%v" would yield an
		// unparseable address such as "::1:3000". fmt.Sprint matches the
		// previous %v formatting of Port for any numeric type.
		Addr:              net.JoinHostPort(config.Host, fmt.Sprint(config.Port)),
		ReadHeaderTimeout: config.ReadHeaderTimeout,
		WriteTimeout:      config.WriteTimeout,
		IdleTimeout:       config.IdleTimeout,
	}
	return &Server{
		server: httpServer,
		GZIP:   config.GZIP,
		ready:  make(chan struct{}),
	}, nil
}
// Start launches the HTTP server in a background goroutine and blocks until
// the server answers GET /healthz with 200 OK (at which point the Ready
// channel is closed), ctx is done, or ListenAndServe exits early.
//
// Routes must have been added beforehand. If middleware has not been
// installed yet, Start installs it via AddMiddleware.
//
// Start returns an error when:
//   - ctx is nil, routes are missing, or AddMiddleware fails;
//   - ListenAndServe exits before the server became ready (for example the
//     address is already in use, or the underlying http.Server was shut down
//     earlier — net/http does not allow reusing a server after Shutdown);
//   - ctx is done before the server became ready.
//
// NOTE(review): in the ctx-done case the listener goroutine is left running,
// as it was before; a caller that gives up should treat the server state as
// unknown.
func (server *Server) Start(ctx context.Context) error {
	if ctx == nil {
		return errors.New("Context cannot be nil")
	}
	if !server.routes {
		return errors.New("Server.AddRoutes must be run before starting the server")
	}
	if !server.middleware {
		if err := server.AddMiddleware(); err != nil {
			return errors.Wrap(err, "server.AddMiddleware")
		}
	}

	// readyCtx ends when either the caller's ctx ends or ListenAndServe
	// returns, so the readiness poll cannot spin forever against a server
	// that never came up.
	readyCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Buffered so the goroutine's single send never blocks, even after Start
	// has returned and nobody is receiving.
	serveErr := make(chan error, 1)

	go func() {
		if server.logger == nil {
			fmt.Printf("Listening for requests on %s\n", server.server.Addr)
		} else {
			server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
		}
		err := server.server.ListenAndServe()
		if err != nil && err != http.ErrServerClosed {
			if server.logger == nil {
				fmt.Printf("Server encountered a fatal error: %s\n", err.Error())
			} else {
				server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
			}
		}
		// Publish the result before cancelling, so that Start finds it in
		// serveErr as soon as it observes readyCtx being done.
		serveErr <- err
		cancel()
	}()

	if err := server.waitUntilReady(readyCtx); err != nil {
		select {
		case sErr := <-serveErr:
			if sErr == nil {
				// ListenAndServe always returns non-nil; guard anyway.
				sErr = errors.New("server stopped before becoming ready")
			}
			return errors.Wrap(sErr, "server.ListenAndServe")
		default:
			return errors.Wrap(err, "server.waitUntilReady")
		}
	}
	return nil
}
// Shutdown gracefully stops a running server through http.Server.Shutdown,
// bounded by ctx, and then resets the ready channel so IsReady reports false.
//
// It fails if the server was never confirmed ready, if ctx is nil (checked in
// that order), or if the graceful shutdown itself fails.
//
// NOTE(review): net/http documents that an http.Server may not be reused
// after Shutdown, so a later Start on the same Server will likely fail with
// http.ErrServerClosed.
func (server *Server) Shutdown(ctx context.Context) error {
	switch {
	case !server.IsReady():
		return errors.New("Server isn't running")
	case ctx == nil:
		return errors.New("Context cannot be nil")
	}
	if err := server.server.Shutdown(ctx); err != nil {
		return errors.Wrap(err, "Failed to shutdown the server gracefully")
	}
	// A fresh, still-open channel marks the server as not running.
	server.ready = make(chan struct{})
	return nil
}
// isValidHostname reports whether host may be used as the listen host. It
// accepts either an RFC 1123 DNS subdomain (validation.IsDNS1123Subdomain) or
// a literal IPv4/IPv6 address (validation.IsValidIP). The IP check is only
// evaluated when the DNS check fails.
func isValidHostname(host string) bool {
	return len(validation.IsDNS1123Subdomain(host)) == 0 ||
		len(validation.IsValidIP(nil, host)) == 0
}
// waitUntilReady polls http://<Addr>/healthz every 50ms until it answers with
// 200 OK, then closes server.ready and returns nil. Connection errors and
// non-200 responses are treated as "not ready yet" and retried.
//
// It returns ctx.Err() if ctx is done first, or the request-construction
// error if the probe URL cannot be built.
func (server *Server) waitUntilReady(ctx context.Context) error {
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	// Each probe is bounded by a client timeout and tied to ctx. A peer that
	// accepts the connection but never replies therefore cannot stall the
	// loop, and cancellation aborts an in-flight probe.
	client := &http.Client{Timeout: time.Second}
	probeURL := "http://" + server.server.Addr + "/healthz"

	closeOnce := sync.Once{}
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, probeURL, nil)
			if err != nil {
				return err
			}
			resp, err := client.Do(req)
			if err != nil {
				continue // not accepting yet (or the probe was cancelled)
			}
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				closeOnce.Do(func() {
					// closeOnce only guards this call. The select also skips
					// the close when an earlier Start already closed the
					// channel; closing it twice would panic.
					select {
					case <-server.ready:
					default:
						close(server.ready)
					}
				})
				return nil
			}
		}
	}
}