Compare commits

...

11 Commits

Author SHA1 Message Date
bb6820f269 updated to use new ezconf 2026-02-25 22:23:59 +11:00
380e366891 updated to use new modules 2026-02-25 22:20:09 +11:00
e8ffec6b7e updated hws to use new hlog and ezconf 2026-02-25 22:17:25 +11:00
1745458a95 updated hlog to use new ezconf 2026-02-25 22:06:27 +11:00
f3d6a01105 added new way to integrate with ezconf 2026-02-25 22:01:25 +11:00
9179736c90 updated ezconf 2026-02-25 21:52:57 +11:00
05be28d7f3 fixed fatal bug after access token expires 2026-02-07 17:58:02 +11:00
8f7c87cef2 added extracheck to hwsauth 2026-02-07 16:42:08 +11:00
525b3b1396 updated to use new hws version 2026-02-03 19:11:59 +11:00
563908bbb4 updated hws.ThrowError to not return an error and log it to console instead
fixed errors_test

fixed tests
2026-02-03 18:43:31 +11:00
95a17597cf added glob matching to auth middleware 2026-02-01 19:55:04 +11:00
45 changed files with 917 additions and 855 deletions

View File

@@ -3,7 +3,7 @@
// //
// ezconf allows you to: // ezconf allows you to:
// - Load configurations from multiple packages using their ConfigFromEnv functions // - Load configurations from multiple packages using their ConfigFromEnv functions
// - Parse package source code to extract environment variable documentation // - Parse config struct tags to extract environment variable documentation
// - Generate and update .env files with all required environment variables // - Generate and update .env files with all required environment variables
// - Print environment variable lists with descriptions and current values // - Print environment variable lists with descriptions and current values
// - Track additional custom environment variables // - Track additional custom environment variables
@@ -40,16 +40,16 @@
// // Use configuration... // // Use configuration...
// } // }
// //
// Alternatively, you can manually register packages: // Alternatively, you can manually register config structs:
// //
// loader := ezconf.New() // loader := ezconf.New()
// //
// // Add package paths to parse for ENV comments // // Add config struct for tag parsing
// loader.AddPackagePath("/path/to/golib/hlog") // loader.AddConfigStruct(&mypackage.Config{}, "MyPackage")
// //
// // Add configuration loaders // // Add configuration loaders
// loader.AddConfigFunc("hlog", func() (interface{}, error) { // loader.AddConfigFunc("mypackage", func() (interface{}, error) {
// return hlog.ConfigFromEnv() // return mypackage.ConfigFromEnv()
// }) // })
// //
// loader.Load() // loader.Load()
@@ -94,27 +94,34 @@
// Default: "postgres://localhost/mydb", // Default: "postgres://localhost/mydb",
// }) // })
// //
// # ENV Comment Format // # Struct Tag Format
// //
// ezconf parses struct field comments in the following format: // ezconf uses struct tags to define environment variable metadata:
// //
// type Config struct { // type Config struct {
// // ENV LOG_LEVEL: Log level for the application (default: info) // LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
// LogLevel string // DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
// // LogDir string `ezconf:"LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file"`
// // ENV DATABASE_URL: Database connection string (required)
// DatabaseURL string
// } // }
// //
// The format is: // Tag components (comma-separated):
// - ENV ENV_VAR_NAME: Description (optional modifiers) // - First value: environment variable name (required)
// - (required) or (required if condition) - marks variable as required // - description:...: Description of the variable
// - (default: value) - specifies default value // - default:...: Default value
// - required: Marks the variable as required
// - required:condition: Marks as required with a condition description
// //
// # Integration // # Integration
// //
// Packages can implement the Integration interface to provide automatic
// registration with ezconf. The interface requires:
// - Name() string: Registration key for the config
// - ConfigPointer() any: Pointer to config struct for tag parsing
// - ConfigFunc() func() (any, error): Function to load config from env
// - GroupName() string: Display name for grouping env vars
//
// ezconf integrates with: // ezconf integrates with:
// - All golib packages that follow the ConfigFromEnv pattern // - All golib packages that follow the ConfigFromEnv pattern
// - Any custom configuration structs with ENV comments // - Any custom configuration structs with ezconf struct tags
// - Standard .env file format // - Standard .env file format
package ezconf package ezconf

View File

@@ -16,14 +16,19 @@ type EnvVar struct {
Group string // Group name for organizing variables (e.g., "Database", "Logging") Group string // Group name for organizing variables (e.g., "Database", "Logging")
} }
// configStruct holds a config struct pointer and its group name for parsing
type configStruct struct {
configPtr any
groupName string
}
// ConfigLoader manages configuration loading from multiple sources // ConfigLoader manages configuration loading from multiple sources
type ConfigLoader struct { type ConfigLoader struct {
configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions
packagePaths []string // Paths to packages to parse for ENV comments configStructs []configStruct // Config struct pointers for tag parsing
groupNames map[string]string // Map of package paths to group names extraEnvVars []EnvVar // Additional environment variables to track
extraEnvVars []EnvVar // Additional environment variables to track envVars []EnvVar // All extracted environment variables
envVars []EnvVar // All extracted environment variables configs map[string]any // Loaded configurations
configs map[string]any // Loaded configurations
} }
// ConfigFunc is a function that loads configuration from environment variables // ConfigFunc is a function that loads configuration from environment variables
@@ -32,12 +37,11 @@ type ConfigFunc func() (any, error)
// New creates a new ConfigLoader // New creates a new ConfigLoader
func New() *ConfigLoader { func New() *ConfigLoader {
return &ConfigLoader{ return &ConfigLoader{
configFuncs: make(map[string]ConfigFunc), configFuncs: make(map[string]ConfigFunc),
packagePaths: make([]string, 0), configStructs: make([]configStruct, 0),
groupNames: make(map[string]string), extraEnvVars: make([]EnvVar, 0),
extraEnvVars: make([]EnvVar, 0), envVars: make([]EnvVar, 0),
envVars: make([]EnvVar, 0), configs: make(map[string]any),
configs: make(map[string]any),
} }
} }
@@ -54,16 +58,20 @@ func (cl *ConfigLoader) AddConfigFunc(name string, fn ConfigFunc) error {
return nil return nil
} }
// AddPackagePath adds a package directory path to parse for ENV comments // AddConfigStruct adds a config struct pointer for parsing ezconf tags.
func (cl *ConfigLoader) AddPackagePath(path string) error { // The configPtr must be a pointer to a struct with ezconf struct tags.
if path == "" { // The groupName is used for organizing environment variables in output.
return errors.New("package path cannot be empty") func (cl *ConfigLoader) AddConfigStruct(configPtr any, groupName string) error {
if configPtr == nil {
return errors.New("config pointer cannot be nil")
} }
// Check if path exists if groupName == "" {
if _, err := os.Stat(path); os.IsNotExist(err) { groupName = "Other"
return errors.Errorf("package path does not exist: %s", path)
} }
cl.packagePaths = append(cl.packagePaths, path) cl.configStructs = append(cl.configStructs, configStruct{
configPtr: configPtr,
groupName: groupName,
})
return nil return nil
} }
@@ -72,27 +80,22 @@ func (cl *ConfigLoader) AddEnvVar(envVar EnvVar) {
cl.extraEnvVars = append(cl.extraEnvVars, envVar) cl.extraEnvVars = append(cl.extraEnvVars, envVar)
} }
// ParseEnvVars extracts environment variables from packages and extra vars // ParseEnvVars extracts environment variables from config struct tags and extra vars.
// This can be called without having actual environment variables set // This can be called without having actual environment variables set.
func (cl *ConfigLoader) ParseEnvVars() error { func (cl *ConfigLoader) ParseEnvVars() error {
// Clear existing env vars to prevent duplicates // Clear existing env vars to prevent duplicates
cl.envVars = make([]EnvVar, 0) cl.envVars = make([]EnvVar, 0)
// Parse packages for ENV comments // Parse config structs for ezconf tags
for _, pkgPath := range cl.packagePaths { for _, cs := range cl.configStructs {
envVars, err := ParseConfigPackage(pkgPath) envVars, err := ParseConfigStruct(cs.configPtr)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to parse package: %s", pkgPath) return errors.Wrap(err, "failed to parse config struct")
}
// Set group name for these variables from stored mapping
groupName := cl.groupNames[pkgPath]
if groupName == "" {
groupName = "Other"
} }
// Set group name for these variables
for i := range envVars { for i := range envVars {
envVars[i].Group = groupName envVars[i].Group = cs.groupName
} }
cl.envVars = append(cl.envVars, envVars...) cl.envVars = append(cl.envVars, envVars...)
@@ -109,8 +112,8 @@ func (cl *ConfigLoader) ParseEnvVars() error {
return nil return nil
} }
// LoadConfigs executes the config functions to load actual configurations // LoadConfigs executes the config functions to load actual configurations.
// This should be called after environment variables are properly set // This should be called after environment variables are properly set.
func (cl *ConfigLoader) LoadConfigs() error { func (cl *ConfigLoader) LoadConfigs() error {
// Load configurations // Load configurations
for name, fn := range cl.configFuncs { for name, fn := range cl.configFuncs {

View File

@@ -2,11 +2,17 @@ package ezconf
import ( import (
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
) )
// testConfig is a Config struct used by multiple tests
type testConfig struct {
LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
loader := New() loader := New()
if loader == nil { if loader == nil {
@@ -16,8 +22,8 @@ func TestNew(t *testing.T) {
if loader.configFuncs == nil { if loader.configFuncs == nil {
t.Error("configFuncs map is nil") t.Error("configFuncs map is nil")
} }
if loader.packagePaths == nil { if loader.configStructs == nil {
t.Error("packagePaths slice is nil") t.Error("configStructs slice is nil")
} }
if loader.extraEnvVars == nil { if loader.extraEnvVars == nil {
t.Error("extraEnvVars slice is nil") t.Error("extraEnvVars slice is nil")
@@ -66,35 +72,39 @@ func TestAddConfigFunc_EmptyName(t *testing.T) {
} }
} }
func TestAddPackagePath(t *testing.T) { func TestAddConfigStruct(t *testing.T) {
loader := New() loader := New()
// Use current directory as test path err := loader.AddConfigStruct(&testConfig{}, "Test")
err := loader.AddPackagePath(".")
if err != nil { if err != nil {
t.Errorf("AddPackagePath failed: %v", err) t.Errorf("AddConfigStruct failed: %v", err)
} }
if len(loader.packagePaths) != 1 { if len(loader.configStructs) != 1 {
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths)) t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
} }
} }
func TestAddPackagePath_InvalidPath(t *testing.T) { func TestAddConfigStruct_NilPointer(t *testing.T) {
loader := New() loader := New()
err := loader.AddPackagePath("/nonexistent/path") err := loader.AddConfigStruct(nil, "Test")
if err == nil { if err == nil {
t.Error("expected error for nonexistent path") t.Error("expected error for nil pointer")
} }
} }
func TestAddPackagePath_EmptyPath(t *testing.T) { func TestAddConfigStruct_EmptyGroupName(t *testing.T) {
loader := New() loader := New()
err := loader.AddPackagePath("") err := loader.AddConfigStruct(&testConfig{}, "")
if err == nil { if err != nil {
t.Error("expected error for empty path") t.Errorf("AddConfigStruct failed: %v", err)
}
// Should default to "Other"
if loader.configStructs[0].groupName != "Other" {
t.Errorf("expected group name 'Other', got %s", loader.configStructs[0].groupName)
} }
} }
@@ -131,8 +141,8 @@ func TestLoad(t *testing.T) {
return testCfg, nil return testCfg, nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -249,8 +259,8 @@ func TestParseEnvVars(t *testing.T) {
return "test config", nil return "test config", nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -353,8 +363,8 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
return testCfg, nil return testCfg, nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -398,63 +408,68 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
} }
} }
func TestLoad_Integration(t *testing.T) { func TestParseEnvVars_GroupName(t *testing.T) {
// Integration test with real hlog package
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New() loader := New()
// Add hlog package loader.AddConfigStruct(&testConfig{}, "MyGroup")
if err := loader.AddPackagePath(hlogPath); err != nil {
t.Fatalf("failed to add hlog package: %v", err)
}
// Load without config function (just parse) err := loader.ParseEnvVars()
if err := loader.Load(); err != nil { if err != nil {
t.Fatalf("Load failed: %v", err) t.Fatalf("ParseEnvVars failed: %v", err)
} }
envVars := loader.GetEnvVars() envVars := loader.GetEnvVars()
if len(envVars) == 0 {
t.Error("expected env vars from hlog package")
}
t.Logf("Found %d environment variables from hlog", len(envVars))
for _, ev := range envVars { for _, ev := range envVars {
t.Logf(" %s: %s (default: %s, required: %t)", ev.Name, ev.Description, ev.Default, ev.Required) if ev.Group != "MyGroup" {
t.Errorf("expected group 'MyGroup', got '%s' for var %s", ev.Group, ev.Name)
}
} }
} }
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) { func TestParseEnvVars_CurrentValues(t *testing.T) {
// Test the new separated ParseEnvVars functionality
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New() loader := New()
// Add hlog package loader.AddConfigStruct(&testConfig{}, "Test")
if err := loader.AddPackagePath(hlogPath); err != nil {
t.Fatalf("failed to add hlog package: %v", err) // Set an env var
t.Setenv("LOG_LEVEL", "debug")
err := loader.ParseEnvVars()
if err != nil {
t.Fatalf("ParseEnvVars failed: %v", err)
} }
// Parse env vars without loading configs (this should work even if required env vars are missing) envVars := loader.GetEnvVars()
for _, ev := range envVars {
if ev.Name == "LOG_LEVEL" {
if ev.CurrentValue != "debug" {
t.Errorf("expected CurrentValue 'debug', got '%s'", ev.CurrentValue)
}
return
}
}
t.Error("LOG_LEVEL not found in env vars")
}
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
loader := New()
// Add config struct for tag parsing
loader.AddConfigStruct(&testConfig{}, "Test")
// Parse env vars
if err := loader.ParseEnvVars(); err != nil { if err := loader.ParseEnvVars(); err != nil {
t.Fatalf("ParseEnvVars failed: %v", err) t.Fatalf("ParseEnvVars failed: %v", err)
} }
envVars := loader.GetEnvVars() envVars := loader.GetEnvVars()
if len(envVars) == 0 { if len(envVars) == 0 {
t.Error("expected env vars from hlog package") t.Error("expected env vars from config struct")
} }
// Now test that we can generate an env file without calling Load() // Now test that we can generate an env file without calling Load()
tempDir := t.TempDir() tempDir := t.TempDir()
envFile := filepath.Join(tempDir, "test-generated.env") envFile := tempDir + "/test-generated.env"
err := loader.GenerateEnvFile(envFile, false) err := loader.GenerateEnvFile(envFile, false)
if err != nil { if err != nil {
@@ -472,16 +487,16 @@ func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
t.Error("expected header in generated file") t.Error("expected header in generated file")
} }
// Should contain environment variables from hlog // Should contain environment variables from config struct
foundHlogVar := false foundVar := false
for _, ev := range envVars { for _, ev := range envVars {
if strings.Contains(output, ev.Name) { if strings.Contains(output, ev.Name) {
foundHlogVar = true foundVar = true
break break
} }
} }
if !foundHlogVar { if !foundVar {
t.Error("expected to find at least one hlog environment variable in generated file") t.Error("expected to find at least one environment variable in generated file")
} }
t.Logf("Successfully generated env file with %d variables", len(envVars)) t.Logf("Successfully generated env file with %d variables", len(envVars))

View File

@@ -1,31 +1,70 @@
package ezconf package ezconf
// Integration is an interface that packages can implement to provide type Integration struct {
Name string
ConfigPointer any
ConfigFunc func() (any, error)
GroupName string
}
func NewIntegration(name, groupname string, cfgptr any, cfgfunc func() (any, error)) *Integration {
return &Integration{
name,
cfgptr,
cfgfunc,
groupname,
}
}
// IntegrationDepr is an interface that packages can implement to provide
// easy integration with ezconf // easy integration with ezconf
type Integration interface { type IntegrationDepr interface {
// Name returns the name to use when registering the config // Name returns the name to use when registering the config
Name() string Name() string
// PackagePath returns the path to the package for source parsing // ConfigPointer returns a pointer to the config struct for tag parsing
PackagePath() string ConfigPointer() any
// ConfigFunc returns the ConfigFromEnv function // ConfigFunc returns the ConfigFromEnv function
ConfigFunc() func() (interface{}, error) ConfigFunc() func() (any, error)
// GroupName returns the display name for grouping environment variables // GroupName returns the display name for grouping environment variables
GroupName() string GroupName() string
} }
// RegisterIntegration registers a package that implements the Integration interface // AddIntegration registers a package using an Integration object returned by another package
func (cl *ConfigLoader) RegisterIntegration(integration Integration) error { func (cl *ConfigLoader) AddIntegration(integration *Integration) error {
// Add package path // Add config struct for tag parsing
pkgPath := integration.PackagePath() configPtr := integration.ConfigPointer
if err := cl.AddPackagePath(pkgPath); err != nil { if err := cl.AddConfigStruct(configPtr, integration.GroupName); err != nil {
return err return err
} }
// Store group name for this package // Add config function
cl.groupNames[pkgPath] = integration.GroupName() if err := cl.AddConfigFunc(integration.Name, integration.ConfigFunc); err != nil {
return err
}
return nil
}
// AddIntegrations registers multiple integrations at once
func (cl *ConfigLoader) AddIntegrations(integrations ...*Integration) error {
for _, integration := range integrations {
if err := cl.AddIntegration(integration); err != nil {
return err
}
}
return nil
}
// RegisterIntegration registers a package that implements the Integration interface
func (cl *ConfigLoader) RegisterIntegration(integration IntegrationDepr) error {
// Add config struct for tag parsing
configPtr := integration.ConfigPointer()
if err := cl.AddConfigStruct(configPtr, integration.GroupName()); err != nil {
return err
}
// Add config function // Add config function
if err := cl.AddConfigFunc(integration.Name(), integration.ConfigFunc()); err != nil { if err := cl.AddConfigFunc(integration.Name(), integration.ConfigFunc()); err != nil {
@@ -36,7 +75,7 @@ func (cl *ConfigLoader) RegisterIntegration(integration Integration) error {
} }
// RegisterIntegrations registers multiple integrations at once // RegisterIntegrations registers multiple integrations at once
func (cl *ConfigLoader) RegisterIntegrations(integrations ...Integration) error { func (cl *ConfigLoader) RegisterIntegrations(integrations ...IntegrationDepr) error {
for _, integration := range integrations { for _, integration := range integrations {
if err := cl.RegisterIntegration(integration); err != nil { if err := cl.RegisterIntegration(integration); err != nil {
return err return err

View File

@@ -1,24 +1,34 @@
package ezconf package ezconf
import ( import (
"os"
"path/filepath"
"testing" "testing"
) )
// Mock integration for testing // mockConfig is a test config struct with ezconf tags
type mockConfig struct {
Host string `ezconf:"MOCK_HOST,description:Host to connect to,default:localhost"`
Port int `ezconf:"MOCK_PORT,description:Port to connect to,default:8080"`
}
// mockConfig2 is a second test config struct
type mockConfig2 struct {
Token string `ezconf:"MOCK_TOKEN,description:API token,required"`
}
// mockIntegration implements the Integration interface for testing
type mockIntegration struct { type mockIntegration struct {
name string name string
packagePath string configPtr any
configFunc func() (interface{}, error) configFunc func() (interface{}, error)
groupName string
} }
func (m mockIntegration) Name() string { func (m mockIntegration) Name() string {
return m.name return m.name
} }
func (m mockIntegration) PackagePath() string { func (m mockIntegration) ConfigPointer() any {
return m.packagePath return m.configPtr
} }
func (m mockIntegration) ConfigFunc() func() (interface{}, error) { func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
@@ -26,15 +36,18 @@ func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
} }
func (m mockIntegration) GroupName() string { func (m mockIntegration) GroupName() string {
return "Test Group" if m.groupName == "" {
return "Test Group"
}
return m.groupName
} }
func TestRegisterIntegration(t *testing.T) { func TestRegisterIntegration(t *testing.T) {
loader := New() loader := New()
integration := mockIntegration{ integration := mockIntegration{
name: "test", name: "test",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "test config", nil return "test config", nil
}, },
@@ -45,9 +58,9 @@ func TestRegisterIntegration(t *testing.T) {
t.Fatalf("RegisterIntegration failed: %v", err) t.Fatalf("RegisterIntegration failed: %v", err)
} }
// Verify package path was added // Verify config struct was added
if len(loader.packagePaths) != 1 { if len(loader.configStructs) != 1 {
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths)) t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
} }
// Verify config func was added // Verify config func was added
@@ -68,14 +81,46 @@ func TestRegisterIntegration(t *testing.T) {
if cfg != "test config" { if cfg != "test config" {
t.Errorf("expected 'test config', got %v", cfg) t.Errorf("expected 'test config', got %v", cfg)
} }
// Verify env vars were parsed from struct tags
envVars := loader.GetEnvVars()
if len(envVars) != 2 {
t.Errorf("expected 2 env vars, got %d", len(envVars))
}
foundHost := false
foundPort := false
for _, ev := range envVars {
if ev.Name == "MOCK_HOST" {
foundHost = true
if ev.Default != "localhost" {
t.Errorf("expected default 'localhost', got '%s'", ev.Default)
}
if ev.Group != "Test Group" {
t.Errorf("expected group 'Test Group', got '%s'", ev.Group)
}
}
if ev.Name == "MOCK_PORT" {
foundPort = true
if ev.Default != "8080" {
t.Errorf("expected default '8080', got '%s'", ev.Default)
}
}
}
if !foundHost {
t.Error("MOCK_HOST not found in env vars")
}
if !foundPort {
t.Error("MOCK_PORT not found in env vars")
}
} }
func TestRegisterIntegration_InvalidPath(t *testing.T) { func TestRegisterIntegration_NilConfigPointer(t *testing.T) {
loader := New() loader := New()
integration := mockIntegration{ integration := mockIntegration{
name: "test", name: "test",
packagePath: "/nonexistent/path", configPtr: nil,
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "test config", nil return "test config", nil
}, },
@@ -83,7 +128,7 @@ func TestRegisterIntegration_InvalidPath(t *testing.T) {
err := loader.RegisterIntegration(integration) err := loader.RegisterIntegration(integration)
if err == nil { if err == nil {
t.Error("expected error for invalid package path") t.Error("expected error for nil config pointer")
} }
} }
@@ -91,16 +136,16 @@ func TestRegisterIntegrations(t *testing.T) {
loader := New() loader := New()
integration1 := mockIntegration{ integration1 := mockIntegration{
name: "test1", name: "test1",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config1", nil return "config1", nil
}, },
} }
integration2 := mockIntegration{ integration2 := mockIntegration{
name: "test2", name: "test2",
packagePath: ".", configPtr: &mockConfig2{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config2", nil return "config2", nil
}, },
@@ -130,22 +175,28 @@ func TestRegisterIntegrations(t *testing.T) {
if cfg1 != "config1" || cfg2 != "config2" { if cfg1 != "config1" || cfg2 != "config2" {
t.Error("config values mismatch") t.Error("config values mismatch")
} }
// Should have env vars from both structs
envVars := loader.GetEnvVars()
if len(envVars) != 3 {
t.Errorf("expected 3 env vars (2 from mockConfig + 1 from mockConfig2), got %d", len(envVars))
}
} }
func TestRegisterIntegrations_PartialFailure(t *testing.T) { func TestRegisterIntegrations_PartialFailure(t *testing.T) {
loader := New() loader := New()
integration1 := mockIntegration{ integration1 := mockIntegration{
name: "test1", name: "test1",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config1", nil return "config1", nil
}, },
} }
integration2 := mockIntegration{ integration2 := mockIntegration{
name: "test2", name: "test2",
packagePath: "/nonexistent", configPtr: nil, // This should cause failure
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config2", nil return "config2", nil
}, },
@@ -159,54 +210,5 @@ func TestRegisterIntegrations_PartialFailure(t *testing.T) {
func TestIntegration_Interface(t *testing.T) { func TestIntegration_Interface(t *testing.T) {
// Verify that mockIntegration implements Integration interface // Verify that mockIntegration implements Integration interface
var _ Integration = (*mockIntegration)(nil) var _ IntegrationDepr = (*mockIntegration)(nil)
}
func TestRegisterIntegration_RealPackage(t *testing.T) {
// Integration test with real hlog package if available
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New()
// Create a simple integration for testing
integration := mockIntegration{
name: "hlog",
packagePath: hlogPath,
configFunc: func() (interface{}, error) {
// Return a mock config instead of calling real ConfigFromEnv
return struct{ LogLevel string }{LogLevel: "info"}, nil
},
}
err := loader.RegisterIntegration(integration)
if err != nil {
t.Fatalf("RegisterIntegration with real package failed: %v", err)
}
if err := loader.Load(); err != nil {
t.Fatalf("Load failed: %v", err)
}
// Should have parsed env vars from hlog
envVars := loader.GetEnvVars()
if len(envVars) == 0 {
t.Error("expected env vars from hlog package")
}
// Check for known hlog variables
foundLogLevel := false
for _, ev := range envVars {
if ev.Name == "LOG_LEVEL" {
foundLogLevel = true
t.Logf("Found LOG_LEVEL: %s", ev.Description)
break
}
}
if !foundLogLevel {
t.Error("expected to find LOG_LEVEL from hlog")
}
} }

View File

@@ -1,146 +1,102 @@
package ezconf package ezconf
import ( import (
"go/ast" "reflect"
"go/parser"
"go/token"
"os"
"path/filepath"
"regexp"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// ParseConfigFile parses a Go source file and extracts ENV comments from struct fields // ParseConfigStruct extracts environment variable metadata from a config
func ParseConfigFile(filename string) ([]EnvVar, error) { // struct's ezconf struct tags using reflection.
content, err := os.ReadFile(filename) //
if err != nil { // The configPtr parameter must be a pointer to a struct. Each field with an
return nil, errors.Wrap(err, "failed to read file") // ezconf tag will be parsed to extract environment variable information.
//
// Tag format: `ezconf:"VAR_NAME,description:Description text,default:value,required"`
//
// Components:
// - First value: environment variable name (required)
// - description:...: Description of the variable
// - default:...: Default value
// - required: Marks the variable as required (optionally required:condition)
func ParseConfigStruct(configPtr any) ([]EnvVar, error) {
if configPtr == nil {
return nil, errors.New("config pointer cannot be nil")
} }
fset := token.NewFileSet() v := reflect.ValueOf(configPtr)
file, err := parser.ParseFile(fset, filename, content, parser.ParseComments) if v.Kind() != reflect.Ptr {
if err != nil { return nil, errors.New("config must be a pointer to a struct")
return nil, errors.Wrap(err, "failed to parse file")
} }
v = v.Elem()
if v.Kind() != reflect.Struct {
return nil, errors.New("config must be a pointer to a struct")
}
t := v.Type()
envVars := make([]EnvVar, 0) envVars := make([]EnvVar, 0)
// Walk through the AST for i := 0; i < t.NumField(); i++ {
ast.Inspect(file, func(n ast.Node) bool { field := t.Field(i)
// Look for struct type declarations tag := field.Tag.Get("ezconf")
typeSpec, ok := n.(*ast.TypeSpec) if tag == "" {
if !ok { continue
return true
} }
structType, ok := typeSpec.Type.(*ast.StructType) envVar, err := parseEzconfTag(tag)
if !ok { if err != nil {
return true return nil, errors.Wrapf(err, "failed to parse ezconf tag on field %s", field.Name)
} }
// Iterate through struct fields envVars = append(envVars, *envVar)
for _, field := range structType.Fields.List { }
var comment string
// Try to get from doc comment (comment before field)
if field.Doc != nil && len(field.Doc.List) > 0 {
comment = field.Doc.List[0].Text
comment = strings.TrimPrefix(comment, "//")
comment = strings.TrimSpace(comment)
}
// Try to get from inline comment (comment after field)
if comment == "" && field.Comment != nil && len(field.Comment.List) > 0 {
comment = field.Comment.List[0].Text
comment = strings.TrimPrefix(comment, "//")
comment = strings.TrimSpace(comment)
}
// Parse ENV comment
if strings.HasPrefix(comment, "ENV ") {
envVar, err := parseEnvComment(comment)
if err == nil {
envVars = append(envVars, *envVar)
}
}
}
return true
})
return envVars, nil return envVars, nil
} }
// ParseConfigPackage parses all Go files in a package directory and extracts ENV comments // parseEzconfTag parses an ezconf struct tag value to extract environment
func ParseConfigPackage(packagePath string) ([]EnvVar, error) { // variable information.
// Find all .go files in the package //
files, err := filepath.Glob(filepath.Join(packagePath, "*.go")) // Expected format: "VAR_NAME,description:Description text,default:value,required"
if err != nil { func parseEzconfTag(tag string) (*EnvVar, error) {
return nil, errors.Wrap(err, "failed to glob package files") if tag == "" {
return nil, errors.New("tag cannot be empty")
} }
allEnvVars := make([]EnvVar, 0) parts := strings.Split(tag, ",")
if len(parts) == 0 {
for _, file := range files { return nil, errors.New("tag cannot be empty")
// Skip test files
if strings.HasSuffix(file, "_test.go") {
continue
}
envVars, err := ParseConfigFile(file)
if err != nil {
// Log error but continue with other files
continue
}
allEnvVars = append(allEnvVars, envVars...)
}
return allEnvVars, nil
}
// parseEnvComment parses a field comment to extract environment variable information.
// Expected format: ENV ENV_NAME: Description (required <condition>) (default: <value>)
func parseEnvComment(comment string) (*EnvVar, error) {
// Check if comment starts with ENV
if !strings.HasPrefix(comment, "ENV ") {
return nil, errors.New("comment does not start with 'ENV '")
}
// Remove "ENV " prefix
comment = strings.TrimPrefix(comment, "ENV ")
// Extract env var name (everything before the first colon)
colonIdx := strings.Index(comment, ":")
if colonIdx == -1 {
return nil, errors.New("missing colon separator")
} }
envVar := &EnvVar{ envVar := &EnvVar{
Name: strings.TrimSpace(comment[:colonIdx]), Name: strings.TrimSpace(parts[0]),
} }
// Extract description and optional parts if envVar.Name == "" {
remainder := strings.TrimSpace(comment[colonIdx+1:]) return nil, errors.New("environment variable name cannot be empty")
// Check for (required ...) pattern
requiredPattern := regexp.MustCompile(`\(required[^)]*\)`)
if requiredPattern.MatchString(remainder) {
envVar.Required = true
remainder = requiredPattern.ReplaceAllString(remainder, "")
} }
// Check for (default: ...) pattern for _, part := range parts[1:] {
defaultPattern := regexp.MustCompile(`\(default:\s*([^)]*)\)`) part = strings.TrimSpace(part)
if matches := defaultPattern.FindStringSubmatch(remainder); len(matches) > 1 {
envVar.Default = strings.TrimSpace(matches[1])
remainder = defaultPattern.ReplaceAllString(remainder, "")
}
// What remains is the description switch {
envVar.Description = strings.TrimSpace(remainder) case strings.HasPrefix(part, "description:"):
envVar.Description = strings.TrimSpace(strings.TrimPrefix(part, "description:"))
case strings.HasPrefix(part, "default:"):
envVar.Default = strings.TrimSpace(strings.TrimPrefix(part, "default:"))
case part == "required":
envVar.Required = true
case strings.HasPrefix(part, "required:"):
envVar.Required = true
// Store the condition in the description if it adds context
condition := strings.TrimSpace(strings.TrimPrefix(part, "required:"))
if condition != "" && envVar.Description != "" {
envVar.Description = envVar.Description + " (required " + condition + ")"
}
}
}
return envVar, nil return envVar, nil
} }

View File

@@ -1,21 +1,19 @@
package ezconf package ezconf
import ( import (
"os"
"path/filepath"
"testing" "testing"
) )
func TestParseEnvComment(t *testing.T) { func TestParseEzconfTag(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
comment string tag string
wantEnvVar *EnvVar wantEnvVar *EnvVar
expectError bool expectError bool
}{ }{
{ {
name: "simple env variable", name: "simple env variable",
comment: "ENV LOG_LEVEL: Log level for the application", tag: "LOG_LEVEL,description:Log level for the application",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_LEVEL", Name: "LOG_LEVEL",
Description: "Log level for the application", Description: "Log level for the application",
@@ -25,8 +23,8 @@ func TestParseEnvComment(t *testing.T) {
expectError: false, expectError: false,
}, },
{ {
name: "env variable with default", name: "env variable with default",
comment: "ENV LOG_LEVEL: Log level for the application (default: info)", tag: "LOG_LEVEL,description:Log level for the application,default:info",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_LEVEL", Name: "LOG_LEVEL",
Description: "Log level for the application", Description: "Log level for the application",
@@ -36,8 +34,8 @@ func TestParseEnvComment(t *testing.T) {
expectError: false, expectError: false,
}, },
{ {
name: "required env variable", name: "required env variable",
comment: "ENV DATABASE_URL: Database connection string (required)", tag: "DATABASE_URL,description:Database connection string,required",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "DATABASE_URL", Name: "DATABASE_URL",
Description: "Database connection string", Description: "Database connection string",
@@ -47,25 +45,36 @@ func TestParseEnvComment(t *testing.T) {
expectError: false, expectError: false,
}, },
{ {
name: "required with condition and default", name: "required with condition and default",
comment: "ENV LOG_DIR: Directory for log files (required when LOG_OUTPUT is file) (default: /var/log)", tag: "LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file,default:/var/log",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_DIR", Name: "LOG_DIR",
Description: "Directory for log files", Description: "Directory for log files (required when LOG_OUTPUT is file)",
Required: true, Required: true,
Default: "/var/log", Default: "/var/log",
}, },
expectError: false, expectError: false,
}, },
{ {
name: "missing colon", name: "name only",
comment: "ENV LOG_LEVEL Log level", tag: "SIMPLE_VAR",
wantEnvVar: &EnvVar{
Name: "SIMPLE_VAR",
Description: "",
Required: false,
Default: "",
},
expectError: false,
},
{
name: "empty tag",
tag: "",
wantEnvVar: nil, wantEnvVar: nil,
expectError: true, expectError: true,
}, },
{ {
name: "not an ENV comment", name: "empty name",
comment: "This is a regular comment", tag: ",description:some desc",
wantEnvVar: nil, wantEnvVar: nil,
expectError: true, expectError: true,
}, },
@@ -73,7 +82,7 @@ func TestParseEnvComment(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
envVar, err := parseEnvComment(tt.comment) envVar, err := parseEzconfTag(tt.tag)
if tt.expectError { if tt.expectError {
if err == nil { if err == nil {
@@ -103,32 +112,17 @@ func TestParseEnvComment(t *testing.T) {
} }
} }
func TestParseConfigFile(t *testing.T) { func TestParseConfigStruct(t *testing.T) {
// Create a temporary test file type TestConfig struct {
tempDir := t.TempDir() LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
testFile := filepath.Join(tempDir, "config.go") LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
content := `package testpkg NoTag string
type Config struct {
// ENV LOG_LEVEL: Log level for the application (default: info)
LogLevel string
// ENV LOG_OUTPUT: Output destination (default: console)
LogOutput string
// ENV DATABASE_URL: Database connection string (required)
DatabaseURL string
}
`
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
} }
envVars, err := ParseConfigFile(testFile) envVars, err := ParseConfigStruct(&TestConfig{})
if err != nil { if err != nil {
t.Fatalf("ParseConfigFile failed: %v", err) t.Fatalf("ParseConfigStruct failed: %v", err)
} }
if len(envVars) != 3 { if len(envVars) != 3 {
@@ -152,51 +146,70 @@ type Config struct {
} }
} }
func TestParseConfigPackage(t *testing.T) { func TestParseConfigStruct_NilPointer(t *testing.T) {
// Test with actual hlog package _, err := ParseConfigStruct(nil)
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
envVars, err := ParseConfigPackage(hlogPath)
if err != nil {
t.Fatalf("ParseConfigPackage failed: %v", err)
}
if len(envVars) == 0 {
t.Error("expected at least one env var from hlog package")
}
// Check for known hlog variables
foundLogLevel := false
for _, envVar := range envVars {
if envVar.Name == "LOG_LEVEL" {
foundLogLevel = true
t.Logf("Found LOG_LEVEL: %s", envVar.Description)
}
}
if !foundLogLevel {
t.Error("expected to find LOG_LEVEL in hlog package")
}
}
func TestParseConfigFile_InvalidFile(t *testing.T) {
_, err := ParseConfigFile("/nonexistent/file.go")
if err == nil { if err == nil {
t.Error("expected error for nonexistent file") t.Error("expected error for nil pointer")
} }
} }
func TestParseConfigPackage_InvalidPath(t *testing.T) { func TestParseConfigStruct_NotPointer(t *testing.T) {
envVars, err := ParseConfigPackage("/nonexistent/package") type TestConfig struct {
Foo string `ezconf:"FOO,description:test"`
}
_, err := ParseConfigStruct(TestConfig{})
if err == nil {
t.Error("expected error for non-pointer")
}
}
func TestParseConfigStruct_NotStruct(t *testing.T) {
str := "not a struct"
_, err := ParseConfigStruct(&str)
if err == nil {
t.Error("expected error for non-struct pointer")
}
}
func TestParseConfigStruct_NoTags(t *testing.T) {
type EmptyConfig struct {
Foo string
Bar int
}
envVars, err := ParseConfigStruct(&EmptyConfig{})
if err != nil { if err != nil {
t.Fatalf("ParseConfigPackage should not error on invalid path: %v", err) t.Fatalf("ParseConfigStruct failed: %v", err)
} }
// Should return empty slice for invalid path
if len(envVars) != 0 { if len(envVars) != 0 {
t.Errorf("expected 0 env vars for invalid path, got %d", len(envVars)) t.Errorf("expected 0 env vars for struct with no tags, got %d", len(envVars))
}
}
func TestParseConfigStruct_UnexportedFields(t *testing.T) {
type TestConfig struct {
exported string `ezconf:"EXPORTED,description:An exported field"`
unexported string `ezconf:"UNEXPORTED,description:An unexported field"`
}
envVars, err := ParseConfigStruct(&TestConfig{})
if err != nil {
t.Fatalf("ParseConfigStruct failed: %v", err)
}
if len(envVars) != 2 {
t.Errorf("expected 2 env vars (both exported and unexported), got %d", len(envVars))
}
}
func TestParseConfigStruct_InvalidTag(t *testing.T) {
type TestConfig struct {
Bad string `ezconf:",description:missing name"`
}
_, err := ParseConfigStruct(&TestConfig{})
if err == nil {
t.Error("expected error for invalid tag")
} }
} }

View File

@@ -9,11 +9,11 @@ import (
// It can be populated from environment variables using ConfigFromEnv // It can be populated from environment variables using ConfigFromEnv
// or created programmatically. // or created programmatically.
type Config struct { type Config struct {
LogLevel Level // ENV LOG_LEVEL: Log level for the logger - trace, debug, info, warn, error, fatal, panic (default: info) LogLevel Level `ezconf:"LOG_LEVEL,description:Log level for the logger - trace debug info warn error fatal panic,default:info"`
LogOutput string // ENV LOG_OUTPUT: Output destination for logs - console, file, or both (default: console) LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination for logs - console file or both,default:console"`
LogDir string // ENV LOG_DIR: Directory path for log files (required when LOG_OUTPUT is "file" or "both") LogDir string `ezconf:"LOG_DIR,description:Directory path for log files,required:when LOG_OUTPUT is file or both"`
LogFileName string // ENV LOG_FILE_NAME: Name of the log file (required when LOG_OUTPUT is "file" or "both") LogFileName string `ezconf:"LOG_FILE_NAME,description:Name of the log file,required:when LOG_OUTPUT is file or both"`
LogAppend bool // ENV LOG_APPEND: Append to existing log file or overwrite (default: true) LogAppend bool `ezconf:"LOG_APPEND,description:Append to existing log file or overwrite,default:true"`
} }
// ConfigFromEnv loads logger configuration from environment variables. // ConfigFromEnv loads logger configuration from environment variables.

View File

@@ -1,35 +1,9 @@
package hlog package hlog
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("hlog", "HLog",
// PackagePath returns the path to the hlog package for source parsing &Config{}, func() (any, error) { return ConfigFromEnv() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hlog"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HLog"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -7,6 +7,8 @@ require (
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1
require ( require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect

View File

@@ -1,5 +1,7 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=

View File

@@ -7,13 +7,13 @@ import (
) )
type Config struct { type Config struct {
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1) Host string `ezconf:"HWS_HOST,description:Host to listen on,default:127.0.0.1"`
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000) Port uint64 `ezconf:"HWS_PORT,description:Port to listen on,default:3000"`
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false) GZIP bool `ezconf:"HWS_GZIP,description:Flag for GZIP compression on requests,default:false"`
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2) ReadHeaderTimeout time.Duration `ezconf:"HWS_READ_HEADER_TIMEOUT,description:Timeout for reading request headers in seconds,default:2"`
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10) WriteTimeout time.Duration `ezconf:"HWS_WRITE_TIMEOUT,description:Timeout for writing requests in seconds,default:10"`
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120) IdleTimeout time.Duration `ezconf:"HWS_IDLE_TIMEOUT,description:Timeout for idle connections in seconds,default:120"`
ShutdownDelay time.Duration // ENV HWS_SHUTDOWN_DELAY: Delay in seconds before server shutsdown when Shutdown is called (default: 5) ShutdownDelay time.Duration `ezconf:"HWS_SHUTDOWN_DELAY,description:Delay in seconds before server shuts down when Shutdown is called,default:5"`
} }
// ConfigFromEnv returns a Config struct loaded from the environment variables // ConfigFromEnv returns a Config struct loaded from the environment variables

View File

@@ -13,12 +13,12 @@ import (
func Test_ConfigFromEnv(t *testing.T) { func Test_ConfigFromEnv(t *testing.T) {
t.Run("Default values when no env vars set", func(t *testing.T) { t.Run("Default values when no env vars set", func(t *testing.T) {
// Clear any existing env vars // Clear any existing env vars
os.Unsetenv("HWS_HOST") _ = os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT") _ = os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP") _ = os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_IDLE_TIMEOUT")
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom host", func(t *testing.T) { t.Run("Custom host", func(t *testing.T) {
os.Setenv("HWS_HOST", "192.168.1.1") _ = os.Setenv("HWS_HOST", "192.168.1.1")
defer os.Unsetenv("HWS_HOST") defer func() {
_ = os.Unsetenv("HWS_HOST")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom port", func(t *testing.T) { t.Run("Custom port", func(t *testing.T) {
os.Setenv("HWS_PORT", "8080") _ = os.Setenv("HWS_PORT", "8080")
defer os.Unsetenv("HWS_PORT") defer func() {
_ = os.Unsetenv("HWS_PORT")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("GZIP enabled", func(t *testing.T) { t.Run("GZIP enabled", func(t *testing.T) {
os.Setenv("HWS_GZIP", "true") _ = os.Setenv("HWS_GZIP", "true")
defer os.Unsetenv("HWS_GZIP") defer func() {
_ = os.Unsetenv("HWS_GZIP")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom timeouts", func(t *testing.T) { t.Run("Custom timeouts", func(t *testing.T) {
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5") _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
os.Setenv("HWS_WRITE_TIMEOUT", "30") _ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
os.Setenv("HWS_IDLE_TIMEOUT", "300") _ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT") defer func() {
defer os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
defer os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("All custom values", func(t *testing.T) { t.Run("All custom values", func(t *testing.T) {
os.Setenv("HWS_HOST", "0.0.0.0") _ = os.Setenv("HWS_HOST", "0.0.0.0")
os.Setenv("HWS_PORT", "9000") _ = os.Setenv("HWS_PORT", "9000")
os.Setenv("HWS_GZIP", "true") _ = os.Setenv("HWS_GZIP", "true")
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3") _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
os.Setenv("HWS_WRITE_TIMEOUT", "15") _ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
os.Setenv("HWS_IDLE_TIMEOUT", "180") _ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
defer func() { defer func() {
os.Unsetenv("HWS_HOST") _ = os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT") _ = os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP") _ = os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}() }()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()

View File

@@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Error to use with Server.ThrowError // HWSError wraps an error with other information for use with HWS features
type HWSError struct { type HWSError struct {
StatusCode int // HTTP Status code StatusCode int // HTTP Status code
Message string // Error message Message string // Error message
@@ -41,7 +41,7 @@ type ErrorPage interface {
} }
// AddErrorPage registers a handler that returns an ErrorPage // AddErrorPage registers a handler that returns an ErrorPage
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError}) page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
return errors.New("Render method of the error page did not write anything to the response writer") return errors.New("Render method of the error page did not write anything to the response writer")
} }
server.errorPage = pageFunc s.errorPage = pageFunc
return nil return nil
} }
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
// the error with the level specified by the HWSError. // the error with the level specified by the HWSError.
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter // If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
// and the request chain should be terminated. // and the request chain should be terminated.
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error { func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
err := s.throwError(w, r, error)
if err != nil {
s.LogError(error)
s.LogError(HWSError{
Message: "Error occured during throwError",
Error: errors.Wrap(err, "s.throwError"),
Level: ErrorERROR,
})
}
}
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
if error.StatusCode <= 0 { if error.StatusCode <= 0 {
return errors.New("HWSError.StatusCode cannot be 0.") return errors.New("HWSError.StatusCode cannot be 0.")
} }
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
if r == nil { if r == nil {
return errors.New("Request cannot be nil") return errors.New("Request cannot be nil")
} }
if !server.IsReady() { if !s.IsReady() {
return errors.New("ThrowError called before server started") return errors.New("ThrowError called before server started")
} }
w.WriteHeader(error.StatusCode) w.WriteHeader(error.StatusCode)
server.LogError(error) s.LogError(error)
if server.errorPage == nil { if s.errorPage == nil {
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
return nil return nil
} }
if error.RenderErrorPage { if error.RenderErrorPage {
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
errPage, err := server.errorPage(error) errPage, err := s.errorPage(error)
if err != nil { if err != nil {
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err}) s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
} }
err = errPage.Render(r.Context(), w) err = errPage.Render(r.Context(), w)
if err != nil { if err != nil {
server.LogError(HWSError{Message: "Failed to render error page", Error: err}) s.LogError(HWSError{Message: "Failed to render error page", Error: err})
} }
} else { } else {
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
} }
return nil return nil
} }
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusInternalServerError)
server.LogFatal(err)
}

View File

@@ -14,22 +14,26 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type goodPage struct{} type (
type badPage struct{} goodPage struct{}
badPage struct{}
)
func goodRender(error hws.HWSError) (hws.ErrorPage, error) { func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
return goodPage{}, nil return goodPage{}, nil
} }
func badRender1(error hws.HWSError) (hws.ErrorPage, error) { func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
return badPage{}, nil return badPage{}, nil
} }
func badRender2(error hws.HWSError) (hws.ErrorPage, error) { func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
return nil, errors.New("I'm an error") return nil, errors.New("I'm an error")
} }
func (g goodPage) Render(ctx context.Context, w io.Writer) error { func (g goodPage) Render(ctx context.Context, w io.Writer) error {
w.Write([]byte("Test write to ResponseWriter")) _, err := w.Write([]byte("Test write to ResponseWriter"))
return nil return err
} }
func (b badPage) Render(ctx context.Context, w io.Writer) error { func (b badPage) Render(ctx context.Context, w io.Writer) error {
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
t.Run("Server not started", func(t *testing.T) { t.Run("Server not started", func(t *testing.T) {
err := server.ThrowError(rr, req, hws.HWSError{ buf.Reset()
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "Error", Message: "Error",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.Error(t, err) // ThrowError logs errors internally when validation fails
output := buf.String()
assert.Contains(t, output, "ThrowError called before server started")
}) })
startTestServer(t, server) startTestServer(t, server)
defer server.Shutdown(t.Context())
tests := []struct { tests := []struct {
name string name string
request *http.Request request *http.Request
error hws.HWSError error hws.HWSError
valid bool expectLogItem string
}{ }{
{ {
name: "No HWSError.Status code", name: "No HWSError.Status code",
request: nil, request: nil,
error: hws.HWSError{}, error: hws.HWSError{},
valid: false, expectLogItem: "HWSError.StatusCode cannot be 0",
}, },
{ {
name: "Negative HWSError.Status code", name: "Negative HWSError.Status code",
request: nil, request: nil,
error: hws.HWSError{StatusCode: -1}, error: hws.HWSError{StatusCode: -1},
valid: false, expectLogItem: "HWSError.StatusCode cannot be 0",
}, },
{ {
name: "No HWSError.Message", name: "No HWSError.Message",
request: nil, request: nil,
error: hws.HWSError{StatusCode: http.StatusInternalServerError}, error: hws.HWSError{StatusCode: http.StatusInternalServerError},
valid: false, expectLogItem: "HWSError.Message cannot be empty",
}, },
{ {
name: "No HWSError.Error", name: "No HWSError.Error",
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
}, },
valid: false, expectLogItem: "HWSError.Error cannot be nil",
}, },
{ {
name: "No request provided", name: "No request provided",
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}, },
valid: false, expectLogItem: "Request cannot be nil",
}, },
{ {
name: "Valid", name: "Valid",
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}, },
valid: true, expectLogItem: "An error occured",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
buf.Reset()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
err := server.ThrowError(rr, tt.request, tt.error) server.ThrowError(rr, tt.request, tt.error)
if tt.valid { // ThrowError no longer returns errors; check logs instead
assert.NoError(t, err) output := buf.String()
} else { assert.Contains(t, output, tt.expectLogItem)
t.Log(err)
assert.Error(t, err)
}
}) })
} }
t.Run("Log level set correctly", func(t *testing.T) { t.Run("Log level set correctly", func(t *testing.T) {
buf.Reset() buf.Reset()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
Level: hws.ErrorWARN, Level: hws.ErrorWARN,
}) })
assert.NoError(t, err) _, err := buf.ReadString([]byte(" ")[0])
_, err = buf.ReadString([]byte(" ")[0]) require.NoError(t, err)
loglvl, err := buf.ReadString([]byte(" ")[0]) loglvl, err := buf.ReadString([]byte(" ")[0])
assert.NoError(t, err) require.NoError(t, err)
if loglvl != "\x1b[33mWRN\x1b[0m " { assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
buf.Reset() buf.Reset()
err = server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0]) _, err = buf.ReadString([]byte(" ")[0])
require.NoError(t, err)
loglvl, err = buf.ReadString([]byte(" ")[0]) loglvl, err = buf.ReadString([]byte(" ")[0])
assert.NoError(t, err) require.NoError(t, err)
if loglvl != "\x1b[31mERR\x1b[0m " { assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
}) })
t.Run("Error page doesnt render if no error page set", func(t *testing.T) { t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
// Must be run before adding the error page to the test server // Must be run before adding the error page to the test server
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
RenderErrorPage: true, RenderErrorPage: true,
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body != "" { assert.Empty(t, body, "Error page should not render when no error page is set")
assert.Error(t, nil)
}
}) })
t.Run("Error page renders", func(t *testing.T) { t.Run("Error page renders", func(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
// Adding the error page will carry over to all future tests and cant be undone // Adding the error page will carry over to all future tests and cant be undone
server.AddErrorPage(goodRender) err := server.AddErrorPage(goodRender)
err := server.ThrowError(rr, req, hws.HWSError{ require.NoError(t, err)
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
RenderErrorPage: true, RenderErrorPage: true,
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body == "" { assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
assert.Error(t, nil)
}
}) })
t.Run("Error page doesnt render if no told to render", func(t *testing.T) { t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
// Error page already added to server // Error page already added to server
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body != "" { assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
assert.Error(t, nil)
}
}) })
server.Shutdown(t.Context()) err := server.Shutdown(t.Context())
require.NoError(t, err)
t.Run("Doesn't error if no logger added to server", func(t *testing.T) { t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: randomPort(), Port: randomPort(),
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
err = server.Start(t.Context()) err = server.Start(t.Context())
require.NoError(t, err) require.NoError(t, err)
<-server.Ready() <-server.Ready()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err = server.ThrowError(rr, req, hws.HWSError{ // Should not panic when no logger is present
StatusCode: http.StatusInternalServerError, assert.NotPanics(t, func() {
Message: "An error occured", server.ThrowError(rr, req, hws.HWSError{
Error: errors.New("Error"), StatusCode: http.StatusInternalServerError,
}) Message: "An error occured",
assert.NoError(t, err) Error: errors.New("Error"),
})
}, "ThrowError should not panic when no logger is present")
err = server.Shutdown(t.Context())
require.NoError(t, err)
}) })
} }

View File

@@ -1,35 +1,9 @@
package hws package hws
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("hws", "HWS",
// PackagePath returns the path to the hws package for source parsing &Config{}, func() (any, error) { return ConfigFromEnv() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hws"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HWS"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -4,22 +4,24 @@ go 1.25.5
require ( require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.9.0 git.haelnorr.com/h/golib/hlog v0.11.0
git.haelnorr.com/h/golib/notify v0.1.0 git.haelnorr.com/h/golib/notify v0.1.0
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
k8s.io/apimachinery v0.35.0 k8s.io/apimachinery v0.35.0
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/gobwas/glob v0.2.3 github.com/gobwas/glob v0.2.3
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.34.0 // indirect github.com/rs/zerolog v1.34.0 // indirect
golang.org/x/sys v0.12.0 // indirect golang.org/x/sys v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect

View File

@@ -1,7 +1,9 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE= git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk= git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc= git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@@ -12,11 +14,13 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -28,8 +32,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -17,6 +17,10 @@ import (
func Test_GZIP_Compression(t *testing.T) { func Test_GZIP_Compression(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
t.Run("GZIP enabled compresses response", func(t *testing.T) { t.Run("GZIP enabled compresses response", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
@@ -25,7 +29,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -80,7 +84,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -131,7 +135,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -179,20 +183,20 @@ func Test_GzipResponseWriter(t *testing.T) {
t.Run("Can write through gzip writer", func(t *testing.T) { t.Run("Can write through gzip writer", func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
gzWriter := gzip.NewWriter(&buf) gzWriter := gzip.NewWriter(&buf)
testData := []byte("Test data to compress") testData := []byte("Test data to compress")
n, err := gzWriter.Write(testData) n, err := gzWriter.Write(testData)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(testData), n) assert.Equal(t, len(testData), n)
err = gzWriter.Close() err = gzWriter.Close()
require.NoError(t, err) require.NoError(t, err)
// Decompress and verify // Decompress and verify
gzReader, err := gzip.NewReader(&buf) gzReader, err := gzip.NewReader(&buf)
require.NoError(t, err) require.NoError(t, err)
defer gzReader.Close() defer gzReader.Close()
decompressed, err := io.ReadAll(gzReader) decompressed, err := io.ReadAll(gzReader)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, testData, decompressed) assert.Equal(t, testData, decompressed)
@@ -215,9 +219,9 @@ func Test_GzipResponseWriter(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req) wrapped.ServeHTTP(rr, req)
// Note: This is a simplified test // Note: This is a simplified test
}) })
} }

View File

@@ -43,23 +43,12 @@ func (s *Server) LogError(err HWSError) {
} }
} }
func (server *Server) LogFatal(err error) {
if err == nil {
err = errors.New("LogFatal was called with a nil error")
}
if server.logger == nil {
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
return
}
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
}
// AddLogger adds a logger to the server to use for request logging. // AddLogger adds a logger to the server to use for request logging.
func (server *Server) AddLogger(hlogger *hlog.Logger) error { func (s *Server) AddLogger(hlogger *hlog.Logger) error {
if hlogger == nil { if hlogger == nil {
return errors.New("unable to add logger, no logger provided") return errors.New("unable to add logger, no logger provided")
} }
server.logger = &logger{ s.logger = &logger{
logger: hlogger, logger: hlogger,
} }
return nil return nil
@@ -68,7 +57,7 @@ func (server *Server) AddLogger(hlogger *hlog.Logger) error {
// LoggerIgnorePaths sets a list of URL paths to ignore logging for. // LoggerIgnorePaths sets a list of URL paths to ignore logging for.
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL // Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
// Useful for ignoring requests to CSS files or favicons // Useful for ignoring requests to CSS files or favicons
func (server *Server) LoggerIgnorePaths(paths ...string) error { func (s *Server) LoggerIgnorePaths(paths ...string) error {
for _, path := range paths { for _, path := range paths {
u, err := url.Parse(path) u, err := url.Parse(path)
valid := err == nil && valid := err == nil &&
@@ -80,7 +69,7 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
return fmt.Errorf("invalid path: '%s'", path) return fmt.Errorf("invalid path: '%s'", path)
} }
} }
server.logger.ignoredPaths = prepareGlobs(paths) s.logger.ignoredPaths = prepareGlobs(paths)
return nil return nil
} }

View File

@@ -25,6 +25,10 @@ func Test_AddLogger(t *testing.T) {
} }
func Test_LogError_AllLevels(t *testing.T) { func Test_LogError_AllLevels(t *testing.T) {
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
t.Run("DEBUG level", func(t *testing.T) { t.Run("DEBUG level", func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
// Create server with logger explicitly set to Debug level // Create server with logger explicitly set to Debug level
@@ -34,7 +38,7 @@ func Test_LogError_AllLevels(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -169,7 +173,7 @@ func Test_LogFatal(t *testing.T) {
// Note: We cannot actually test Fatal() as it calls os.Exit() // Note: We cannot actually test Fatal() as it calls os.Exit()
// Testing this would require subprocess testing which is overly complex // Testing this would require subprocess testing which is overly complex
// These tests document the expected behavior and verify the function signatures exist // These tests document the expected behavior and verify the function signatures exist
t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) { t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) {
_, err := hws.NewServer(&hws.Config{ _, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
@@ -197,7 +201,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("http://example.com/path") err := server.LoggerIgnorePaths("http://example.com/path")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Invalid path with host", func(t *testing.T) { t.Run("Invalid path with host", func(t *testing.T) {
@@ -207,7 +211,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("//example.com/path") err := server.LoggerIgnorePaths("//example.com/path")
assert.Error(t, err) assert.Error(t, err)
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
} }
}) })
@@ -217,7 +221,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path?query=value") err := server.LoggerIgnorePaths("/path?query=value")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Invalid path with fragment", func(t *testing.T) { t.Run("Invalid path with fragment", func(t *testing.T) {
@@ -226,7 +230,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path#fragment") err := server.LoggerIgnorePaths("/path#fragment")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Valid paths", func(t *testing.T) { t.Run("Valid paths", func(t *testing.T) {

View File

@@ -5,35 +5,37 @@ import (
"net/http" "net/http"
) )
type Middleware func(h http.Handler) http.Handler type (
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError) Middleware func(h http.Handler) http.Handler
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
)
// Server.AddMiddleware registers all the middleware. // AddMiddleware registers all the middleware.
// Middleware will be run in the order that they are provided. // Middleware will be run in the order that they are provided.
// Can only be called once // Can only be called once
func (server *Server) AddMiddleware(middleware ...Middleware) error { func (s *Server) AddMiddleware(middleware ...Middleware) error {
if !server.routes { if !s.routes {
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware") return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
} }
if server.middleware { if s.middleware {
return errors.New("Server.AddMiddleware already called") return errors.New("Server.AddMiddleware already called")
} }
// RUN LOGGING MIDDLEWARE FIRST // RUN LOGGING MIDDLEWARE FIRST
server.server.Handler = logging(server.server.Handler, server.logger) s.server.Handler = logging(s.server.Handler, s.logger)
// LOOP PROVIDED MIDDLEWARE IN REVERSE order // LOOP PROVIDED MIDDLEWARE IN REVERSE order
for i := len(middleware); i > 0; i-- { for i := len(middleware); i > 0; i-- {
server.server.Handler = middleware[i-1](server.server.Handler) s.server.Handler = middleware[i-1](s.server.Handler)
} }
// RUN GZIP // RUN GZIP
if server.GZIP { if s.GZIP {
server.server.Handler = addgzip(server.server.Handler) s.server.Handler = addgzip(s.server.Handler)
} }
// RUN TIMER MIDDLEWARE LAST // RUN TIMER MIDDLEWARE LAST
server.server.Handler = startTimer(server.server.Handler) s.server.Handler = startTimer(s.server.Handler)
server.middleware = true s.middleware = true
return nil return nil
} }
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
// and returns a new request and optional HWSError. // and returns a new request and optional HWSError.
// If a HWSError is returned, server.ThrowError will be called. // If a HWSError is returned, server.ThrowError will be called.
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered // If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
func (server *Server) NewMiddleware( func (s *Server) NewMiddleware(
middlewareFunc MiddlewareFunc, middlewareFunc MiddlewareFunc,
) Middleware { ) Middleware {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newReq, herr := middlewareFunc(w, r) newReq, herr := middlewareFunc(w, r)
if herr != nil { if herr != nil {
server.ThrowError(w, r, *herr) s.ThrowError(w, r, *herr)
if herr.RenderErrorPage { if herr.RenderErrorPage {
return return
} }

View File

@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
) )
} }
type contextKey string
func (c contextKey) String() string {
return "hws context key " + string(c)
}
var requestTimerCtxKey = contextKey("request-timer")
// Set the start time of the request // Set the start time of the request
func setStart(ctx context.Context, time time.Time) context.Context { func setStart(ctx context.Context, time time.Time) context.Context {
return context.WithValue(ctx, "hws context key request-timer", time) return context.WithValue(ctx, requestTimerCtxKey, time)
} }
// Get the start time of the request // Get the start time of the request
func getStartTime(ctx context.Context) (time.Time, error) { func getStartTime(ctx context.Context) (time.Time, error) {
start, ok := ctx.Value("hws context key request-timer").(time.Time) start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
if !ok { if !ok {
return time.Time{}, errors.New("Failed to get start time of request") return time.Time{}, errors.New("failed to get start time of request")
} }
return start, nil return start, nil
} }

View File

@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
} }
_, exists := s.notifier.clients.getClient(nt.Target) _, exists := s.notifier.clients.getClient(nt.Target)
if !exists { if !exists {
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target) err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return return
} }
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
clients, exists := s.notifier.clients.clientsIDMap[altID] clients, exists := s.notifier.clients.clientsIDMap[altID]
s.notifier.clients.lock.RUnlock() s.notifier.clients.lock.RUnlock()
if !exists { if !exists {
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID) err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return return
} }

View File

@@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server {
t.Helper() t.Helper()
cfg := &Config{ cfg := &Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 0, Port: 0,
ShutdownDelay: 0, // No delay for tests
} }
server, err := NewServer(cfg) server, err := NewServer(cfg)
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
done := make(chan bool) done := make(chan bool)
go func() { go func() {
for i := 0; i < 3; i++ { for range 3 {
<-ticker.C <-ticker.C
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
defer close(stop) defer close(stop)
// Send 10 notifications quickly (buffer is 10) // Send 10 notifications quickly (buffer is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Burst message", Message: "Burst message",
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
} }
// Client should receive all 10 // Client should receive all 10
for i := 0; i < 10; i++ { for i := range 10 {
select { select {
case <-notifications: case <-notifications:
// Received // Received
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
defer close(stop) defer close(stop)
// Fill buffer completely (buffer is 10) // Fill buffer completely (buffer is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Fill buffer", Message: "Fill buffer",
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
Message: "Timeout message", Message: "Timeout message",
}) })
// Wait for timeout // Wait for timeout (5s timeout + small buffer)
time.Sleep(6 * time.Second) time.Sleep(5100 * time.Millisecond)
// Check failure count (should be 1) // Check failure count (should be 1)
fails := atomic.LoadInt32(&client.consecutiveFails) fails := atomic.LoadInt32(&client.consecutiveFails)
require.Equal(t, int32(1), fails, "Should have 1 timeout") require.Equal(t, int32(1), fails, "Should have 1 timeout")
// Now read all buffered messages // Now read all buffered messages
for i := 0; i < 10; i++ { for range 10 {
<-notifications <-notifications
} }
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
defer close(stop) defer close(stop)
// Fill buffer and never read to cause 5 consecutive timeouts // Fill buffer and never read to cause 5 consecutive timeouts
for i := 0; i < 20; i++ { for range 20 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Timeout message", Message: "Timeout message",
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
clients := make([]*Client, 100) clients := make([]*Client, 100)
for i := 0; i < 100; i++ { for i := range 100 {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
messageCount := 50 messageCount := 50
// Send from multiple goroutines // Send from multiple goroutines
for i := 0; i < messageCount; i++ { for i := range messageCount {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
// This is expected behavior - we're testing thread safety, not guaranteed delivery // This is expected behavior - we're testing thread safety, not guaranteed delivery
// Just verify we receive at least some messages without panicking or deadlocking // Just verify we receive at least some messages without panicking or deadlocking
received := 0 received := 0
timeout := time.After(2 * time.Second) timeout := time.After(500 * time.Millisecond)
for received < messageCount { for received < messageCount {
select { select {
case <-notifications: case <-notifications:
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
server := newTestServerWithNotifier(t) server := newTestServerWithNotifier(t)
// Create some clients // Create some clients
for i := 0; i < 10; i++ { for i := range 10 {
client, _ := server.GetClient("", "") client, _ := server.GetClient("", "")
// Set some to be old // Set some to be old
if i%2 == 0 { if i%2 == 0 {
@@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
// Create a few clients and read from them // Create a few clients and read from them
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
client, _ := server.GetClient("", "") client, _ := server.GetClient("", "")
notifications, stop := client.Listen() notifications, stop := client.Listen()
defer close(stop) defer close(stop)
// Actively read messages // Actively read messages
timeout := time.After(2 * time.Second) timeout := time.After(200 * time.Millisecond)
for { for {
select { select {
case <-notifications: case <-notifications:
// Keep reading // Keep reading
case <-timeout: case <-timeout:
return return
} }
} }
}() })
} }
// Send a few notifications // Send a few notifications
wg.Add(1) wg.Go(func() {
go func() { for range 10 {
defer wg.Done()
for j := 0; j < 20; j++ {
server.NotifyAll(notify.Notification{ server.NotifyAll(notify.Notification{
Message: "Stress test", Message: "Stress test",
}) })
time.Sleep(50 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
}() })
wg.Wait() wg.Wait()
} }
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
require.NotNil(t, stop) require.NotNil(t, stop)
// notifications should be receive-only // notifications should be receive-only
_, ok := interface{}(notifications).(<-chan notify.Notification) _, ok := any(notifications).(<-chan notify.Notification)
require.True(t, ok, "notifications should be receive-only channel") require.True(t, ok, "notifications should be receive-only channel")
// stop should be closeable // stop should be closeable
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
defer close(stop) defer close(stop)
// Send 10 messages without reading (buffer size is 10) // Send 10 messages without reading (buffer size is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Buffered", Message: "Buffered",
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Read all 10 // Read all 10
for i := 0; i < 10; i++ { for i := range 10 {
select { select {
case <-notifications: case <-notifications:
// Success // Success

View File

@@ -30,13 +30,13 @@ const (
MethodPATCH Method = "PATCH" MethodPATCH Method = "PATCH"
) )
// Server.AddRoutes registers the page handlers for the server. // AddRoutes registers the page handlers for the server.
// At least one route must be provided. // At least one route must be provided.
// If any route patterns (path + method) are defined multiple times, the first // If any route patterns (path + method) are defined multiple times, the first
// instance will be added and any additional conflicts will be discarded. // instance will be added and any additional conflicts will be discarded.
func (server *Server) AddRoutes(routes ...Route) error { func (s *Server) AddRoutes(routes ...Route) error {
if len(routes) == 0 { if len(routes) == 0 {
return errors.New("No routes provided") return errors.New("no routes provided")
} }
patterns := []string{} patterns := []string{}
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
} }
for _, method := range route.Methods { for _, method := range route.Methods {
if !validMethod(method) { if !validMethod(method) {
return fmt.Errorf("Invalid method %s for path %s", method, route.Path) return fmt.Errorf("invalid method %s for path %s", method, route.Path)
} }
if route.Handler == nil { if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", method, route.Path) return fmt.Errorf("no handler provided for %s %s", method, route.Path)
} }
pattern := fmt.Sprintf("%s %s", method, route.Path) pattern := fmt.Sprintf("%s %s", method, route.Path)
if slices.Contains(patterns, pattern) { if slices.Contains(patterns, pattern) {
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
} }
} }
server.server.Handler = mux s.server.Handler = mux
server.routes = true s.routes = true
return nil return nil
} }

View File

@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
server := createTestServer(t, &buf) server := createTestServer(t, &buf)
err := server.AddRoutes() err := server.AddRoutes()
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "No routes provided") assert.Contains(t, err.Error(), "no routes provided")
}) })
t.Run("Single valid route", func(t *testing.T) { t.Run("Single valid route", func(t *testing.T) {
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: handler, Handler: handler,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method") assert.Contains(t, err.Error(), "invalid method")
}) })
t.Run("No handler provided", func(t *testing.T) { t.Run("No handler provided", func(t *testing.T) {
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: nil, Handler: nil,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "No handler provided") assert.Contains(t, err.Error(), "no handler provided")
}) })
t.Run("All HTTP methods are valid", func(t *testing.T) { t.Run("All HTTP methods are valid", func(t *testing.T) {
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
Handler: handler, Handler: handler,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method") assert.Contains(t, err.Error(), "invalid method")
}) })
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) { t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {

View File

@@ -26,14 +26,14 @@ type Server struct {
} }
// Ready returns a channel that is closed when the server is started // Ready returns a channel that is closed when the server is started
func (server *Server) Ready() <-chan struct{} { func (s *Server) Ready() <-chan struct{} {
return server.ready return s.ready
} }
// IsReady checks if the server is running // IsReady checks if the server is running
func (server *Server) IsReady() bool { func (s *Server) IsReady() bool {
select { select {
case <-server.ready: case <-s.ready:
return true return true
default: default:
return false return false
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
} }
// Addr returns the server's network address // Addr returns the server's network address
func (server *Server) Addr() string { func (s *Server) Addr() string {
return server.server.Addr return s.server.Addr
} }
// Handler returns the server's HTTP handler for testing purposes // Handler returns the server's HTTP handler for testing purposes
func (server *Server) Handler() http.Handler { func (s *Server) Handler() http.Handler {
return server.server.Handler return s.server.Handler
} }
// NewServer returns a new hws.Server with the specified configuration. // NewServer returns a new hws.Server with the specified configuration.
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
valid := isValidHostname(config.Host) valid := isValidHostname(config.Host)
if !valid { if !valid {
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host) return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
} }
httpServer := &http.Server{ httpServer := &http.Server{
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
return server, nil return server, nil
} }
func (server *Server) Start(ctx context.Context) error { func (s *Server) Start(ctx context.Context) error {
if ctx == nil { if ctx == nil {
return errors.New("Context cannot be nil") return errors.New("Context cannot be nil")
} }
if !server.routes { if !s.routes {
return errors.New("Server.AddRoutes must be run before starting the server") return errors.New("Server.AddRoutes must be run before starting the server")
} }
if !server.middleware { if !s.middleware {
err := server.AddMiddleware() err := s.AddMiddleware()
if err != nil { if err != nil {
return errors.Wrap(err, "server.AddMiddleware") return errors.Wrap(err, "server.AddMiddleware")
} }
} }
server.startNotifier() s.startNotifier()
go func() { go func() {
if server.logger == nil { if s.logger == nil {
fmt.Printf("Listening for requests on %s", server.server.Addr) fmt.Printf("Listening for requests on %s", s.server.Addr)
} else { } else {
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests") s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
} }
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if server.logger == nil { if s.logger == nil {
fmt.Printf("Server encountered a fatal error: %s", err.Error()) fmt.Printf("Server encountered a fatal error: %s", err.Error())
} else { } else {
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"}) s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
} }
} }
}() }()
server.waitUntilReady(ctx) s.waitUntilReady(ctx)
return nil return nil
} }
func (server *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down") if s.logger != nil {
server.NotifyAll(notify.Notification{ s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
}
s.NotifyAll(notify.Notification{
Title: "Shutting down", Title: "Shutting down",
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay), Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
Level: LevelShutdown, Level: LevelShutdown,
}) })
<-time.NewTimer(server.shutdowndelay).C <-time.NewTimer(s.shutdowndelay).C
if !server.IsReady() { if !s.IsReady() {
return errors.New("Server isn't running") return errors.New("Server isn't running")
} }
if ctx == nil { if ctx == nil {
return errors.New("Context cannot be nil") return errors.New("Context cannot be nil")
} }
err := server.server.Shutdown(ctx) err := s.server.Shutdown(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to shutdown the server gracefully") return errors.Wrap(err, "Failed to shutdown the server gracefully")
} }
server.closeNotifier() s.closeNotifier()
server.ready = make(chan struct{}) s.ready = make(chan struct{})
return nil return nil
} }
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
return false return false
} }
func (server *Server) waitUntilReady(ctx context.Context) error { func (s *Server) waitUntilReady(ctx context.Context) error {
ticker := time.NewTicker(50 * time.Millisecond) ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
resp, err := http.Get("http://" + server.server.Addr + "/healthz") resp, err := http.Get("http://" + s.server.Addr + "/healthz")
if err != nil { if err != nil {
continue // not accepting yet continue // not accepting yet
} }
resp.Body.Close() resp.Body.Close()
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
closeOnce.Do(func() { close(server.ready) }) closeOnce.Do(func() { close(s.ready) })
return nil return nil
} }
} }

View File

@@ -26,12 +26,17 @@ func randomPort() uint64 {
func createTestServer(t *testing.T, w io.Writer) *hws.Server { func createTestServer(t *testing.T, w io.Writer) *hws.Server {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: randomPort(), Port: randomPort(),
ShutdownDelay: 0, // No delay for tests
}) })
require.NoError(t, err) require.NoError(t, err)
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), w, nil, "") logger, err := hlog.NewLogger(logcfg, w)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -227,5 +232,4 @@ func Test_NewServer(t *testing.T) {
} }
}) })
} }
} }

View File

@@ -9,6 +9,7 @@ import (
"git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/jwt" "git.haelnorr.com/h/golib/jwt"
"github.com/gobwas/glob"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -16,7 +17,7 @@ type Authenticator[T Model, TX DBTransaction] struct {
tokenGenerator *jwt.TokenGenerator tokenGenerator *jwt.TokenGenerator
load LoadFunc[T, TX] load LoadFunc[T, TX]
beginTx BeginTX beginTx BeginTX
ignoredPaths []string ignoredPaths []glob.Glob
logger *hlog.Logger logger *hlog.Logger
server *hws.Server server *hws.Server
errorPage hws.ErrorPageFunc errorPage hws.ErrorPageFunc

View File

@@ -9,16 +9,16 @@ import (
// Config holds the configuration settings for the authenticator. // Config holds the configuration settings for the authenticator.
// All time-based settings are in minutes. // All time-based settings are in minutes.
type Config struct { type Config struct {
SSL bool // ENV HWSAUTH_SSL: Enable SSL secure cookies (default: false) SSL bool `ezconf:"HWSAUTH_SSL,description:Enable SSL secure cookies,default:false"`
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address for SSL (required if SSL is true) TrustedHost string `ezconf:"HWSAUTH_TRUSTED_HOST,description:Full server address for SSL,required:if SSL is true"`
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing JWT tokens (required) SecretKey string `ezconf:"HWSAUTH_SECRET_KEY,description:Secret key for signing JWT tokens,required"`
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5) AccessTokenExpiry int64 `ezconf:"HWSAUTH_ACCESS_TOKEN_EXPIRY,description:Access token expiry in minutes,default:5"`
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440) RefreshTokenExpiry int64 `ezconf:"HWSAUTH_REFRESH_TOKEN_EXPIRY,description:Refresh token expiry in minutes,default:1440"`
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Token fresh time in minutes (default: 5) TokenFreshTime int64 `ezconf:"HWSAUTH_TOKEN_FRESH_TIME,description:Token fresh time in minutes,default:5"`
LandingPage string // ENV HWSAUTH_LANDING_PAGE: Redirect destination for authenticated users (default: "/profile") LandingPage string `ezconf:"HWSAUTH_LANDING_PAGE,description:Redirect destination for authenticated users,default:/profile"`
DatabaseType string // ENV HWSAUTH_DATABASE_TYPE: Database type (postgres, mysql, sqlite, mariadb) (default: "postgres") DatabaseType string `ezconf:"HWSAUTH_DATABASE_TYPE,description:Database type (postgres mysql sqlite mariadb),default:postgres"`
DatabaseVersion string // ENV HWSAUTH_DATABASE_VERSION: Database version string (default: "15") DatabaseVersion string `ezconf:"HWSAUTH_DATABASE_VERSION,description:Database version string,default:15"`
JWTTableName string // ENV HWSAUTH_JWT_TABLE_NAME: Custom JWT blacklist table name (default: "jwtblacklist") JWTTableName string `ezconf:"HWSAUTH_JWT_TABLE_NAME,description:Custom JWT blacklist table name,default:jwtblacklist"`
} }
// ConfigFromEnv loads configuration from environment variables. // ConfigFromEnv loads configuration from environment variables.

View File

@@ -1,35 +1,9 @@
package hwsauth package hwsauth
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct{}
// PackagePath returns the path to the hwsauth package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hwsauth"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HWSAuth"
}
// NewEZConfIntegration creates a new EZConf integration helper // NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration { func NewEZConfIntegration() *ezconf.Integration {
return EZConfIntegration{} return ezconf.NewIntegration("hwsauth", "HWSAuth", &Config{},
func() (any, error) { return ConfigFromEnv() })
} }

View File

@@ -5,24 +5,28 @@ go 1.25.5
require ( require (
git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/ezconf v0.2.1
git.haelnorr.com/h/golib/hws v0.3.0 git.haelnorr.com/h/golib/hlog v0.11.0
git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/jwt v0.10.1 git.haelnorr.com/h/golib/jwt v0.10.1
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
) )
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/gobwas/glob v0.2.3
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.34.0 // indirect github.com/rs/zerolog v1.34.0 // indirect
golang.org/x/sys v0.40.0 // indirect golang.org/x/sys v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apimachinery v0.35.0 // indirect k8s.io/apimachinery v0.35.0 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect

View File

@@ -2,12 +2,16 @@ git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDO
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo= git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@@ -15,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
@@ -40,8 +46,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
return tm.ID return tm.ID
} }
type TestTransaction struct { type TestTransaction struct{}
}
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
return nil, nil return nil, nil
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
func TestConfigFromEnv_MissingSecretKey(t *testing.T) { func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
// Clear environment variables // Clear environment variables
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
os.Setenv("HWSAUTH_SECRET_KEY", "") _ = os.Setenv("HWSAUTH_SECRET_KEY", "")
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) defer func() {
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
}()
_, err := ConfigFromEnv() _, err := ConfigFromEnv()
assert.Error(t, err) assert.Error(t, err)
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
// This test mainly checks that the function doesn't panic and has right call signature // This test mainly checks that the function doesn't panic and has right call signature
// The actual JWT functionality is tested in jwt package itself // The actual JWT functionality is tested in jwt package itself
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
auth.Login(w, r, user, rememberMe) err := auth.Login(w, r, user, rememberMe)
require.NoError(t, err)
}) })
} }

View File

@@ -3,6 +3,8 @@ package hwsauth
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"github.com/gobwas/glob"
) )
// IgnorePaths excludes specified paths from authentication middleware. // IgnorePaths excludes specified paths from authentication middleware.
@@ -22,9 +24,22 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
u.RawQuery == "" && u.RawQuery == "" &&
u.Fragment == "" u.Fragment == ""
if !valid { if !valid {
return fmt.Errorf("Invalid path: '%s'", path) return fmt.Errorf("invalid path: '%s'", path)
} }
} }
auth.ignoredPaths = paths auth.ignoredPaths = prepareGlobs(paths)
return nil return nil
} }
func prepareGlobs(paths []string) []glob.Glob {
compiledGlobs := make([]glob.Glob, 0, len(paths))
for _, pattern := range paths {
g, err := glob.Compile(pattern)
if err != nil {
// If pattern fails to compile, skip it
continue
}
compiledGlobs = append(compiledGlobs, g)
}
return compiledGlobs
}

View File

@@ -33,13 +33,17 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return errors.Wrap(err, "auth.getTokens") return errors.Wrap(err, "auth.getTokens")
} }
err = aT.Revoke(jwt.DBTransaction(tx)) if aT != nil {
if err != nil { err = aT.Revoke(jwt.DBTransaction(tx))
return errors.Wrap(err, "aT.Revoke") if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
} }
err = rT.Revoke(jwt.DBTransaction(tx)) if rT != nil {
if err != nil { err = rT.Revoke(jwt.DBTransaction(tx))
return errors.Wrap(err, "rT.Revoke") if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
} }
cookies.DeleteCookie(w, "access", "/") cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/") cookies.DeleteCookie(w, "refresh", "/")

View File

@@ -3,10 +3,10 @@ package hwsauth
import ( import (
"context" "context"
"net/http" "net/http"
"slices"
"time" "time"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"github.com/gobwas/glob"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -16,14 +16,22 @@ import (
// //
// Example: // Example:
// //
// server.AddMiddleware(auth.Authenticate()) // server.AddMiddleware(auth.Authenticate(nil))
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware { //
return auth.server.NewMiddleware(auth.authenticate()) // If extraCheck is provided, it will run just before the user is added to the context,
// and the return will determine if the user will be added, or the request passed on
// without the user.
func (auth *Authenticator[T, TX]) Authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
} }
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { func (auth *Authenticator[T, TX]) authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) { return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if slices.Contains(auth.ignoredPaths, r.URL.Path) { if globTest(r.URL.Path, auth.ignoredPaths) {
return r, nil return r, nil
} }
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
@@ -38,6 +46,9 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
Error: errors.Wrap(err, "auth.beginTx"), Error: errors.Wrap(err, "auth.beginTx"),
} }
} }
defer func() {
_ = tx.Rollback()
}()
// Type assert to TX - safe because user's beginTx should return their TX type // Type assert to TX - safe because user's beginTx should return their TX type
txTyped, ok := tx.(TX) txTyped, ok := tx.(TX)
if !ok { if !ok {
@@ -49,16 +60,50 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
} }
model, err := auth.getAuthenticatedUser(txTyped, w, r) model, err := auth.getAuthenticatedUser(txTyped, w, r)
if err != nil { if err != nil {
tx.Rollback() rberr := tx.Rollback()
if rberr != nil {
return nil, &hws.HWSError{
Message: "Failed rolling back after error",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Rollback"),
}
}
auth.logger.Debug(). auth.logger.Debug().
Str("remote_addr", r.RemoteAddr). Str("remote_addr", r.RemoteAddr).
Err(err). Err(err).
Msg("Failed to authenticate user") Msg("Failed to authenticate user")
return r, nil return r, nil
} }
tx.Commit() var check bool
if extraCheck != nil {
var err *hws.HWSError
check, err = extraCheck(ctx, model.model, txTyped, w, r)
if err != nil {
return nil, err
}
}
err = tx.Commit()
if err != nil {
return nil, &hws.HWSError{
Message: "Failed to commit transaction",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Commit"),
}
}
authContext := setAuthenticatedModel(r.Context(), model) authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext) newReq := r.WithContext(authContext)
return newReq, nil if extraCheck == nil || check {
return newReq, nil
}
return r, nil
} }
} }
func globTest(testPath string, globs []glob.Glob) bool {
for _, g := range globs {
if g.Match(testPath) {
return true
}
}
return false
}

View File

@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
// } // }
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error) type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
type contextKey string
func (c contextKey) String() string {
return "hwsauth context key" + string(c)
}
var authenticatedModelContextKey = contextKey("authenticated-model")
// Return a new context with the user added in // Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context { func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m) return context.WithValue(ctx, authenticatedModelContextKey, m)
} }
// Retrieve a user from the given context. Returns nil if not set // Retrieve a user from the given context. Returns nil if not set
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
model = authenticatedModel[T]{} model = authenticatedModel[T]{}
} }
}() }()
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T]) model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
if !cok { if !cok {
return authenticatedModel[T]{}, false return authenticatedModel[T]{}, false
} }

View File

@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := getAuthorizedModel[T](r.Context()) _, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model, ok := getAuthorizedModel[T](r.Context()) model, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
isFresh := time.Now().Before(time.Unix(model.fresh, 0)) isFresh := time.Now().Before(time.Unix(model.fresh, 0))

View File

@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
rememberMe := map[string]bool{ rememberMe := map[string]bool{
"session": false, "session": false,
"exp": true, "exp": true,
}[aT.TTL] }[rT.TTL]
// issue new tokens for the user // issue new tokens for the user
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL) err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
if err != nil { if err != nil {
@@ -55,13 +55,20 @@ func (auth *Authenticator[T, TX]) getTokens(
) (*jwt.AccessToken, *jwt.RefreshToken, error) { ) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies // get the existing tokens from the cookies
atStr, rtStr := jwt.GetTokenCookies(r) atStr, rtStr := jwt.GetTokenCookies(r)
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr) var aT *jwt.AccessToken
if err != nil { var rT *jwt.RefreshToken
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess") var err error
if atStr != "" {
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
}
} }
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr) if rtStr != "" {
if err != nil { rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh") if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
}
} }
return aT, rT, nil return aT, rT, nil
} }
@@ -72,13 +79,17 @@ func revokeTokenPair(
aT *jwt.AccessToken, aT *jwt.AccessToken,
rT *jwt.RefreshToken, rT *jwt.RefreshToken,
) error { ) error {
err := aT.Revoke(tx) if aT != nil {
if err != nil { err := aT.Revoke(tx)
return errors.Wrap(err, "aT.Revoke") if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
} }
err = rT.Revoke(tx) if rT != nil {
if err != nil { err := rT.Revoke(tx)
return errors.Wrap(err, "rT.Revoke") if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
} }
return nil return nil
} }

View File

@@ -7,7 +7,7 @@ import (
type API struct { type API struct {
*Config *Config
token string // ENV TMDB_TOKEN: API token for TMDB (required) token string `ezconf:"TMDB_TOKEN,description:API token for TMDB,required"`
} }
func NewAPIConnection() (*API, error) { func NewAPIConnection() (*API, error) {

View File

@@ -1,36 +1,9 @@
package tmdb package tmdb
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("tmdb", "TMDB", &Config{},
// PackagePath returns the path to the tmdb package for source parsing func() (any, error) { return NewAPIConnection() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the NewAPIConnection function for ezconf
// Note: tmdb uses NewAPIConnection instead of ConfigFromEnv
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return NewAPIConnection()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "tmdb"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "TMDB"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -6,3 +6,5 @@ require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1

View File

@@ -1,4 +1,6 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=