Compare commits

...

14 Commits

55 changed files with 1766 additions and 872 deletions

173
AGENTS.md Normal file
View File

@@ -0,0 +1,173 @@
# AGENTS.md - Coding Agent Guidelines for golib
## Project Overview
This is a Go library repository containing multiple independent packages:
- **cookies**: HTTP cookie utilities
- **env**: Environment variable helpers
- **ezconf**: Configuration loader with ENV parsing
- **hlog**: Logging with zerolog
- **hws**: HTTP web server
- **hwsauth**: Authentication middleware for hws
- **jwt**: JWT token generation and validation
- **tmdb**: The Movie Database API client
Each package has its own `go.mod` and can be used independently.
## Interactive Questions
All questions in plan mode should use the opencode interactive question prompter for user interaction.
## Build, Test, and Lint Commands
### Running Tests
```bash
# Test all packages from repo root
go test ./...
# Test a specific package
cd <package> && go test
# Run a single test function
cd <package> && go test -run TestFunctionName
# Run tests with verbose output
cd <package> && go test -v
# Run tests matching a pattern
cd <package> && go test -run "TestName.*"
```
### Building
```bash
# Each package is a library - no build needed
# Verify code compiles:
go build ./...
# Or for specific package:
cd <package> && go build
```
### Linting
```bash
# Use standard go tools
go vet ./...
go fmt ./...
# Check formatting without changing files
gofmt -l .
```
## Code Style Guidelines
### Package Structure
- Each package must have its own `go.mod` with module path: `git.haelnorr.com/h/golib/<package>`
- Go version should be current (1.23.4+)
- Each package should have a `doc.go` file with package documentation
### Imports
- Use standard library imports first
- Then third-party imports
- Then local imports from this repo (e.g., `git.haelnorr.com/h/golib/hlog`)
- Group imports with blank lines between groups
- Example:
```go
import (
"context"
"net/http"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"git.haelnorr.com/h/golib/hlog"
)
```
### Formatting
- Use `gofmt` standard formatting
- No tabs for alignment, use spaces inside structs
- Line length: no hard limit, but prefer readability
### Types
- Use explicit types for struct fields
- Config structs must have ENV comments (see below)
- Prefer named return values for complex functions
- Use generics where appropriate (see `hwsauth.Authenticator[T Model, TX DBTransaction]`)
### Naming Conventions
- Packages: lowercase, single word (e.g., `cookies`, `ezconf`)
- Exported functions: PascalCase (e.g., `NewServer`, `ConfigFromEnv`)
- Unexported functions: camelCase (e.g., `isValidHostname`, `waitUntilReady`)
- Test functions: `Test<FunctionName>` or `Test<FunctionName>_<Case>` (underscore for sub-cases)
- Variables: camelCase, descriptive names
- Constants: PascalCase or UPPER_CASE depending on scope
### Error Handling
- Use `github.com/pkg/errors` for error wrapping
- Wrap errors with context: `errors.Wrap(err, "context message")`
- Return errors, don't panic (except in truly exceptional cases)
- Validate inputs and return descriptive errors
- Example:
```go
if config == nil {
return nil, errors.New("Config cannot be nil")
}
```
### Configuration Pattern
- Each package with config should have a `Config` struct
- Provide `ConfigFromEnv() (*Config, error)` function
- ENV comment format for Config struct fields:
```go
type Config struct {
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
SSL bool // ENV HWS_SSL: Enable SSL (required when using production)
}
```
- Format: `// ENV ENV_NAME: Description (required <condition>) (default: <value>)`
- Include "required" only if no default
- Include "default" only if one exists
### Testing
- Use `testing` package from standard library
- Use `github.com/stretchr/testify` for assertions (`require`, `assert`)
- Table-driven tests for multiple cases:
```go
tests := []struct {
name string
input string
wantErr bool
}{
{"valid case", "input", false},
{"error case", "", true},
}
```
- Test files use `<package>_test` for black-box tests or `<package>` for white-box
- Helper functions should use `t.Helper()`
### Documentation
- All exported functions, types, and constants must have godoc comments
- Comments should start with the name being documented
- Example: `// NewServer returns a new hws.Server with the specified configuration.`
- Keep doc.go files up to date with package overview
- Follow RULES.md for README and wiki documentation
## Version Control (from RULES.md)
- Do NOT make changes to master branch
- Checkout a branch for new features
- Version numbers use git tags - do NOT change manually
- When updating docs, append branch name to version
- Changes to golib-wiki repo should use same branch name
## Testing Requirements (from RULES.md)
- All features MUST have tests
- Update existing tests when modifying features
- New features require new tests
## Documentation Requirements (from RULES.md)
- Document via: docstrings, README.md, doc.go, wiki
- README structure: Title+version, Features (NO EMOTICONS), Installation, Quick Start, Docs links, Additional info, License, Contributing, Related projects
- Wiki location: `~/projects/golib-wiki`
- Docstrings must conform to godoc standards
## License
- All modules use MIT License

View File

@@ -45,3 +45,6 @@ Do not make any changes to master. Checkout a branch to work on new features
Version numbers are specified using git tags. Version numbers are specified using git tags.
Do not change version numbers. When updating documentation, append the branch name to the version number. Do not change version numbers. When updating documentation, append the branch name to the version number.
Changes made to the golib-wiki repo should be made under the same branch name as the changes made in this repo Changes made to the golib-wiki repo should be made under the same branch name as the changes made in this repo
4. Licencing
All modules should have an MIT License

21
cookies/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 haelnorr
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

61
cookies/README.md Normal file
View File

@@ -0,0 +1,61 @@
# cookies v1.0.0
HTTP cookie utilities for Go web applications with security best practices.
## Features
- Secure cookie setting with HttpOnly flag
- Cookie deletion with proper expiration
- Pagefrom tracking for post-login redirects
- Host validation for referer-based redirects
- Full test coverage
## Installation
```bash
go get git.haelnorr.com/h/golib/cookies
```
## Quick Start
```go
package main
import (
"net/http"
"git.haelnorr.com/h/golib/cookies"
)
func handler(w http.ResponseWriter, r *http.Request) {
// Set a secure cookie
cookies.SetCookie(w, "session", "/", "abc123", 3600)
// Delete a cookie
cookies.DeleteCookie(w, "old_session", "/")
// Handle pagefrom for redirects
if r.URL.Path == "/login" {
cookies.SetPageFrom(w, r, "example.com")
}
// Check pagefrom after login
redirectTo := cookies.CheckPageFrom(w, r)
http.Redirect(w, r, redirectTo, http.StatusFound)
}
```
## Documentation
See the [wiki documentation](../golib/wiki/cookies.md) for detailed usage information and examples.
## License
MIT License
## Contributing
Please see the main golib repository for contributing guidelines.
## Related Projects
This package is part of the golib collection of utilities for Go applications and integrates well with other golib packages.

405
cookies/cookies_test.go Normal file
View File

@@ -0,0 +1,405 @@
package cookies
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestSetCookie(t *testing.T) {
tests := []struct {
name string
cookie string
path string
value string
maxAge int
expected string
}{
{
name: "basic cookie",
cookie: "test",
path: "/",
value: "value",
maxAge: 3600,
expected: "test=value; Path=/; Max-Age=3600; HttpOnly",
},
{
name: "zero max age",
cookie: "session",
path: "/api",
value: "abc123",
maxAge: 0,
expected: "session=abc123; Path=/api; HttpOnly",
},
{
name: "negative max age",
cookie: "temp",
path: "/",
value: "temp",
maxAge: -1,
expected: "temp=temp; Path=/; Max-Age=0; HttpOnly",
},
{
name: "empty value",
cookie: "empty",
path: "/",
value: "",
maxAge: 3600,
expected: "empty=; Path=/; Max-Age=3600; HttpOnly",
},
{
name: "special characters in value",
cookie: "data",
path: "/",
value: "test@123!#$%",
maxAge: 7200,
expected: "data=test@123!#$%; Path=/; Max-Age=7200; HttpOnly",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
SetCookie(w, tt.cookie, tt.path, tt.value, tt.maxAge)
headers := w.Header()["Set-Cookie"]
if len(headers) != 1 {
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
return
}
// Parse the cookie header to check individual components
cookieHeader := headers[0]
// Check that all expected components are present
if !strings.Contains(cookieHeader, tt.cookie+"="+tt.value) {
t.Errorf("Expected cookie name/value not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "Path="+tt.path) {
t.Errorf("Expected path not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "HttpOnly") {
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
}
if tt.maxAge != 0 {
expectedMaxAge := fmt.Sprintf("Max-Age=%d", tt.maxAge)
if tt.maxAge < 0 {
expectedMaxAge = "Max-Age=0" // Go normalizes negative Max-Age to 0
}
if !strings.Contains(cookieHeader, expectedMaxAge) {
t.Errorf("Expected Max-Age not found in: %s", cookieHeader)
}
}
})
}
}
func TestDeleteCookie(t *testing.T) {
tests := []struct {
name string
cookie string
path string
expected string
}{
{
name: "basic deletion",
cookie: "test",
path: "/",
expected: "test=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
},
{
name: "delete with specific path",
cookie: "session",
path: "/api",
expected: "session=; Path=/api; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
DeleteCookie(w, tt.cookie, tt.path)
headers := w.Header()["Set-Cookie"]
if len(headers) != 1 {
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
return
}
cookieHeader := headers[0]
// Check deletion-specific components
if !strings.Contains(cookieHeader, tt.cookie+"=") {
t.Errorf("Expected cookie name not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "Path="+tt.path) {
t.Errorf("Expected path not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "Max-Age=0") {
t.Errorf("Expected Max-Age=0 not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "Expires=") {
t.Errorf("Expected Expires not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "HttpOnly") {
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
}
})
}
}
func TestCheckPageFrom(t *testing.T) {
tests := []struct {
name string
cookieValue string
cookiePath string
expectedResult string
shouldSet bool
}{
{
name: "valid pagefrom cookie",
cookieValue: "/dashboard",
cookiePath: "/",
expectedResult: "/dashboard",
shouldSet: true,
},
{
name: "no pagefrom cookie",
cookieValue: "",
cookiePath: "",
expectedResult: "/",
shouldSet: false,
},
{
name: "empty pagefrom cookie",
cookieValue: "",
cookiePath: "/",
expectedResult: "",
shouldSet: true,
},
{
name: "pagefrom with query params",
cookieValue: "/search?q=test",
cookiePath: "/",
expectedResult: "/search?q=test",
shouldSet: true,
},
{
name: "pagefrom with special path",
cookieValue: "/api/v1/users",
cookiePath: "/api",
expectedResult: "/api/v1/users",
shouldSet: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
r := &http.Request{
Header: make(http.Header),
}
if tt.shouldSet {
cookie := &http.Cookie{
Name: "pagefrom",
Value: tt.cookieValue,
Path: tt.cookiePath,
}
r.AddCookie(cookie)
}
result := CheckPageFrom(w, r)
if result != tt.expectedResult {
t.Errorf("CheckPageFrom() = %v, want %v", result, tt.expectedResult)
}
// Verify that the cookie was deleted
if tt.shouldSet {
headers := w.Header()["Set-Cookie"]
if len(headers) != 1 {
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers))
return
}
cookieHeader := headers[0]
if !strings.Contains(cookieHeader, "pagefrom=") {
t.Errorf("Expected pagefrom cookie deletion not found in: %s", cookieHeader)
}
if !strings.Contains(cookieHeader, "Max-Age=0") {
t.Errorf("Expected Max-Age=0 for deletion not found in: %s", cookieHeader)
}
}
})
}
}
func TestSetPageFrom(t *testing.T) {
tests := []struct {
name string
referer string
trustedHost string
expectedSet bool
expectedValue string
}{
{
name: "valid trusted host referer",
referer: "http://example.com/dashboard",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/dashboard",
},
{
name: "valid trusted host with https",
referer: "https://example.com/profile",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/profile",
},
{
name: "untrusted host",
referer: "http://evil.com/dashboard",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/",
},
{
name: "empty path",
referer: "http://example.com",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/",
},
{
name: "login path - should not set",
referer: "http://example.com/login",
trustedHost: "example.com",
expectedSet: false,
expectedValue: "",
},
{
name: "register path - should not set",
referer: "http://example.com/register",
trustedHost: "example.com",
expectedSet: false,
expectedValue: "",
},
{
name: "invalid referer URL",
referer: "not-a-url",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/",
},
{
name: "empty referer",
referer: "",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/",
},
{
name: "root path",
referer: "http://example.com/",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/",
},
{
name: "path with query string",
referer: "http://example.com/search?q=test",
trustedHost: "example.com",
expectedSet: true,
expectedValue: "/search",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
r := &http.Request{
Header: make(http.Header),
}
if tt.referer != "" {
r.Header.Set("Referer", tt.referer)
}
SetPageFrom(w, r, tt.trustedHost)
headers := w.Header()["Set-Cookie"]
if tt.expectedSet {
if len(headers) != 1 {
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
return
}
cookieHeader := headers[0]
if !strings.Contains(cookieHeader, "pagefrom="+tt.expectedValue) {
t.Errorf("Expected pagefrom=%s not found in: %s", tt.expectedValue, cookieHeader)
}
} else {
if len(headers) != 0 {
t.Errorf("Expected no Set-Cookie header, got %d", len(headers))
}
}
})
}
}
func TestIntegration(t *testing.T) {
// Test the complete flow: SetPageFrom -> CheckPageFrom
t.Run("complete flow", func(t *testing.T) {
// Step 1: Set pagefrom cookie
w1 := httptest.NewRecorder()
r1 := &http.Request{
Header: make(http.Header),
}
r1.Header.Set("Referer", "http://example.com/dashboard")
SetPageFrom(w1, r1, "example.com")
// Extract the cookie from the response
headers1 := w1.Header()["Set-Cookie"]
if len(headers1) != 1 {
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers1))
return
}
// Verify the cookie was set correctly
cookieHeader := headers1[0]
if !strings.Contains(cookieHeader, "pagefrom=/dashboard") {
t.Errorf("Expected pagefrom=/dashboard not found in: %s", cookieHeader)
}
// Step 2: Check pagefrom cookie (should delete it)
w2 := httptest.NewRecorder()
r2 := &http.Request{
Header: make(http.Header),
}
r2.AddCookie(&http.Cookie{
Name: "pagefrom",
Value: "/dashboard",
Path: "/",
})
result := CheckPageFrom(w2, r2)
if result != "/dashboard" {
t.Errorf("Expected result /dashboard, got %s", result)
}
// Verify the cookie was deleted
headers2 := w2.Header()["Set-Cookie"]
if len(headers2) != 1 {
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers2))
return
}
cookieHeader2 := headers2[0]
// Check for deletion indicators (Max-Age=0 with Expires in the past)
if !(strings.Contains(cookieHeader2, "Max-Age=0") && strings.Contains(cookieHeader2, "Expires=Thu, 01 Jan 1970")) {
t.Errorf("Expected cookie deletion, got: %s", cookieHeader2)
}
})
}

26
cookies/doc.go Normal file
View File

@@ -0,0 +1,26 @@
// Package cookies provides utilities for handling HTTP cookies in Go web applications.
// It includes functions for setting secure cookies, deleting cookies, and managing
// pagefrom tracking for post-login redirects.
//
// The package follows security best practices by setting the HttpOnly flag on all
// cookies to prevent XSS attacks. The SetCookie function allows you to specify the
// name, path, value, and max-age for cookies.
//
// The pagefrom functionality helps with user experience by remembering where a user
// was before being redirected to login/register pages, then redirecting them back
// after successful authentication.
//
// Example usage:
//
// // Set a session cookie
// cookies.SetCookie(w, "session", "/", "abc123", 3600)
//
// // Delete a cookie
// cookies.DeleteCookie(w, "old_session", "/")
//
// // Handle pagefrom tracking
// cookies.SetPageFrom(w, r, "example.com")
// redirectTo := cookies.CheckPageFrom(w, r)
//
// All functions are designed to be safe and handle edge cases gracefully.
package cookies

21
env/LICENSE vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 haelnorr
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

67
env/README.md vendored Normal file
View File

@@ -0,0 +1,67 @@
# env v1.0.0
Environment variable utilities for Go applications with type safety and default values.
## Features
- Type-safe environment variable parsing
- Support for all basic Go types (string, int variants, uint variants, bool, time.Duration)
- Graceful fallback to default values
- Comprehensive boolean parsing with multiple truthy/falsy values
- Full test coverage
## Installation
```bash
go get git.haelnorr.com/h/golib/env
```
## Quick Start
```go
package main
import (
"fmt"
"time"
"git.haelnorr.com/h/golib/env"
)
func main() {
// String values
host := env.String("HOST", "localhost")
// Integer values (all sizes supported)
port := env.Int("PORT", 8080)
timeout := env.Int64("TIMEOUT_SECONDS", 30)
// Unsigned integer values
maxConnections := env.UInt("MAX_CONNECTIONS", 100)
// Boolean values (supports many formats)
debug := env.Bool("DEBUG", false)
// Duration values
requestTimeout := env.Duration("REQUEST_TIMEOUT", 30*time.Second)
fmt.Printf("Server: %s:%d\n", host, port)
fmt.Printf("Debug: %v\n", debug)
fmt.Printf("Timeout: %v\n", requestTimeout)
}
```
## Documentation
See the [wiki documentation](../golib/wiki/env.md) for detailed usage information and examples.
## License
MIT License
## Contributing
Please see the main golib repository for contributing guidelines.
## Related Projects
This package is part of the golib collection of utilities for Go applications.

18
env/doc.go vendored Normal file
View File

@@ -0,0 +1,18 @@
// Package env provides utilities for reading environment variables with type safety
// and default values. It supports common Go types including strings, integers (all sizes),
// unsigned integers (all sizes), booleans, and time.Duration values.
//
// The package follows a simple pattern where each function takes a key name and a
// default value, returning the parsed environment variable or the default if the
// variable is not set or cannot be parsed.
//
// Example usage:
//
// port := env.Int("PORT", 8080)
// debug := env.Bool("DEBUG", false)
// timeout := env.Duration("TIMEOUT", 30*time.Second)
//
// All functions gracefully handle missing environment variables by returning the
// provided default value. They also handle parsing errors by falling back to the
// default value.
package env

View File

@@ -3,7 +3,7 @@
// //
// ezconf allows you to: // ezconf allows you to:
// - Load configurations from multiple packages using their ConfigFromEnv functions // - Load configurations from multiple packages using their ConfigFromEnv functions
// - Parse package source code to extract environment variable documentation // - Parse config struct tags to extract environment variable documentation
// - Generate and update .env files with all required environment variables // - Generate and update .env files with all required environment variables
// - Print environment variable lists with descriptions and current values // - Print environment variable lists with descriptions and current values
// - Track additional custom environment variables // - Track additional custom environment variables
@@ -40,16 +40,16 @@
// // Use configuration... // // Use configuration...
// } // }
// //
// Alternatively, you can manually register packages: // Alternatively, you can manually register config structs:
// //
// loader := ezconf.New() // loader := ezconf.New()
// //
// // Add package paths to parse for ENV comments // // Add config struct for tag parsing
// loader.AddPackagePath("/path/to/golib/hlog") // loader.AddConfigStruct(&mypackage.Config{}, "MyPackage")
// //
// // Add configuration loaders // // Add configuration loaders
// loader.AddConfigFunc("hlog", func() (interface{}, error) { // loader.AddConfigFunc("mypackage", func() (interface{}, error) {
// return hlog.ConfigFromEnv() // return mypackage.ConfigFromEnv()
// }) // })
// //
// loader.Load() // loader.Load()
@@ -94,27 +94,34 @@
// Default: "postgres://localhost/mydb", // Default: "postgres://localhost/mydb",
// }) // })
// //
// # ENV Comment Format // # Struct Tag Format
// //
// ezconf parses struct field comments in the following format: // ezconf uses struct tags to define environment variable metadata:
// //
// type Config struct { // type Config struct {
// // ENV LOG_LEVEL: Log level for the application (default: info) // LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
// LogLevel string // DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
// // LogDir string `ezconf:"LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file"`
// // ENV DATABASE_URL: Database connection string (required)
// DatabaseURL string
// } // }
// //
// The format is: // Tag components (comma-separated):
// - ENV ENV_VAR_NAME: Description (optional modifiers) // - First value: environment variable name (required)
// - (required) or (required if condition) - marks variable as required // - description:...: Description of the variable
// - (default: value) - specifies default value // - default:...: Default value
// - required: Marks the variable as required
// - required:condition: Marks as required with a condition description
// //
// # Integration // # Integration
// //
// Packages can implement the Integration interface to provide automatic
// registration with ezconf. The interface requires:
// - Name() string: Registration key for the config
// - ConfigPointer() any: Pointer to config struct for tag parsing
// - ConfigFunc() func() (any, error): Function to load config from env
// - GroupName() string: Display name for grouping env vars
//
// ezconf integrates with: // ezconf integrates with:
// - All golib packages that follow the ConfigFromEnv pattern // - All golib packages that follow the ConfigFromEnv pattern
// - Any custom configuration structs with ENV comments // - Any custom configuration structs with ezconf struct tags
// - Standard .env file format // - Standard .env file format
package ezconf package ezconf

View File

@@ -16,11 +16,16 @@ type EnvVar struct {
Group string // Group name for organizing variables (e.g., "Database", "Logging") Group string // Group name for organizing variables (e.g., "Database", "Logging")
} }
// configStruct holds a config struct pointer and its group name for parsing
type configStruct struct {
configPtr any
groupName string
}
// ConfigLoader manages configuration loading from multiple sources // ConfigLoader manages configuration loading from multiple sources
type ConfigLoader struct { type ConfigLoader struct {
configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions
packagePaths []string // Paths to packages to parse for ENV comments configStructs []configStruct // Config struct pointers for tag parsing
groupNames map[string]string // Map of package paths to group names
extraEnvVars []EnvVar // Additional environment variables to track extraEnvVars []EnvVar // Additional environment variables to track
envVars []EnvVar // All extracted environment variables envVars []EnvVar // All extracted environment variables
configs map[string]any // Loaded configurations configs map[string]any // Loaded configurations
@@ -33,8 +38,7 @@ type ConfigFunc func() (any, error)
func New() *ConfigLoader { func New() *ConfigLoader {
return &ConfigLoader{ return &ConfigLoader{
configFuncs: make(map[string]ConfigFunc), configFuncs: make(map[string]ConfigFunc),
packagePaths: make([]string, 0), configStructs: make([]configStruct, 0),
groupNames: make(map[string]string),
extraEnvVars: make([]EnvVar, 0), extraEnvVars: make([]EnvVar, 0),
envVars: make([]EnvVar, 0), envVars: make([]EnvVar, 0),
configs: make(map[string]any), configs: make(map[string]any),
@@ -54,16 +58,20 @@ func (cl *ConfigLoader) AddConfigFunc(name string, fn ConfigFunc) error {
return nil return nil
} }
// AddPackagePath adds a package directory path to parse for ENV comments // AddConfigStruct adds a config struct pointer for parsing ezconf tags.
func (cl *ConfigLoader) AddPackagePath(path string) error { // The configPtr must be a pointer to a struct with ezconf struct tags.
if path == "" { // The groupName is used for organizing environment variables in output.
return errors.New("package path cannot be empty") func (cl *ConfigLoader) AddConfigStruct(configPtr any, groupName string) error {
if configPtr == nil {
return errors.New("config pointer cannot be nil")
} }
// Check if path exists if groupName == "" {
if _, err := os.Stat(path); os.IsNotExist(err) { groupName = "Other"
return errors.Errorf("package path does not exist: %s", path)
} }
cl.packagePaths = append(cl.packagePaths, path) cl.configStructs = append(cl.configStructs, configStruct{
configPtr: configPtr,
groupName: groupName,
})
return nil return nil
} }
@@ -72,27 +80,22 @@ func (cl *ConfigLoader) AddEnvVar(envVar EnvVar) {
cl.extraEnvVars = append(cl.extraEnvVars, envVar) cl.extraEnvVars = append(cl.extraEnvVars, envVar)
} }
// ParseEnvVars extracts environment variables from packages and extra vars // ParseEnvVars extracts environment variables from config struct tags and extra vars.
// This can be called without having actual environment variables set // This can be called without having actual environment variables set.
func (cl *ConfigLoader) ParseEnvVars() error { func (cl *ConfigLoader) ParseEnvVars() error {
// Clear existing env vars to prevent duplicates // Clear existing env vars to prevent duplicates
cl.envVars = make([]EnvVar, 0) cl.envVars = make([]EnvVar, 0)
// Parse packages for ENV comments // Parse config structs for ezconf tags
for _, pkgPath := range cl.packagePaths { for _, cs := range cl.configStructs {
envVars, err := ParseConfigPackage(pkgPath) envVars, err := ParseConfigStruct(cs.configPtr)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to parse package: %s", pkgPath) return errors.Wrap(err, "failed to parse config struct")
}
// Set group name for these variables from stored mapping
groupName := cl.groupNames[pkgPath]
if groupName == "" {
groupName = "Other"
} }
// Set group name for these variables
for i := range envVars { for i := range envVars {
envVars[i].Group = groupName envVars[i].Group = cs.groupName
} }
cl.envVars = append(cl.envVars, envVars...) cl.envVars = append(cl.envVars, envVars...)
@@ -109,8 +112,8 @@ func (cl *ConfigLoader) ParseEnvVars() error {
return nil return nil
} }
// LoadConfigs executes the config functions to load actual configurations // LoadConfigs executes the config functions to load actual configurations.
// This should be called after environment variables are properly set // This should be called after environment variables are properly set.
func (cl *ConfigLoader) LoadConfigs() error { func (cl *ConfigLoader) LoadConfigs() error {
// Load configurations // Load configurations
for name, fn := range cl.configFuncs { for name, fn := range cl.configFuncs {

View File

@@ -2,11 +2,17 @@ package ezconf
import ( import (
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
) )
// testConfig is a Config struct used by multiple tests
type testConfig struct {
LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
}
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
loader := New() loader := New()
if loader == nil { if loader == nil {
@@ -16,8 +22,8 @@ func TestNew(t *testing.T) {
if loader.configFuncs == nil { if loader.configFuncs == nil {
t.Error("configFuncs map is nil") t.Error("configFuncs map is nil")
} }
if loader.packagePaths == nil { if loader.configStructs == nil {
t.Error("packagePaths slice is nil") t.Error("configStructs slice is nil")
} }
if loader.extraEnvVars == nil { if loader.extraEnvVars == nil {
t.Error("extraEnvVars slice is nil") t.Error("extraEnvVars slice is nil")
@@ -66,35 +72,39 @@ func TestAddConfigFunc_EmptyName(t *testing.T) {
} }
} }
func TestAddPackagePath(t *testing.T) { func TestAddConfigStruct(t *testing.T) {
loader := New() loader := New()
// Use current directory as test path err := loader.AddConfigStruct(&testConfig{}, "Test")
err := loader.AddPackagePath(".")
if err != nil { if err != nil {
t.Errorf("AddPackagePath failed: %v", err) t.Errorf("AddConfigStruct failed: %v", err)
} }
if len(loader.packagePaths) != 1 { if len(loader.configStructs) != 1 {
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths)) t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
} }
} }
func TestAddPackagePath_InvalidPath(t *testing.T) { func TestAddConfigStruct_NilPointer(t *testing.T) {
loader := New() loader := New()
err := loader.AddPackagePath("/nonexistent/path") err := loader.AddConfigStruct(nil, "Test")
if err == nil { if err == nil {
t.Error("expected error for nonexistent path") t.Error("expected error for nil pointer")
} }
} }
func TestAddPackagePath_EmptyPath(t *testing.T) { func TestAddConfigStruct_EmptyGroupName(t *testing.T) {
loader := New() loader := New()
err := loader.AddPackagePath("") err := loader.AddConfigStruct(&testConfig{}, "")
if err == nil { if err != nil {
t.Error("expected error for empty path") t.Errorf("AddConfigStruct failed: %v", err)
}
// Should default to "Other"
if loader.configStructs[0].groupName != "Other" {
t.Errorf("expected group name 'Other', got %s", loader.configStructs[0].groupName)
} }
} }
@@ -131,8 +141,8 @@ func TestLoad(t *testing.T) {
return testCfg, nil return testCfg, nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -249,8 +259,8 @@ func TestParseEnvVars(t *testing.T) {
return "test config", nil return "test config", nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -353,8 +363,8 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
return testCfg, nil return testCfg, nil
}) })
// Add current package path // Add config struct for tag parsing
loader.AddPackagePath(".") loader.AddConfigStruct(&testConfig{}, "Test")
// Add an extra env var // Add an extra env var
loader.AddEnvVar(EnvVar{ loader.AddEnvVar(EnvVar{
@@ -398,63 +408,68 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
} }
} }
func TestLoad_Integration(t *testing.T) { func TestParseEnvVars_GroupName(t *testing.T) {
// Integration test with real hlog package
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New() loader := New()
// Add hlog package loader.AddConfigStruct(&testConfig{}, "MyGroup")
if err := loader.AddPackagePath(hlogPath); err != nil {
t.Fatalf("failed to add hlog package: %v", err)
}
// Load without config function (just parse) err := loader.ParseEnvVars()
if err := loader.Load(); err != nil { if err != nil {
t.Fatalf("Load failed: %v", err) t.Fatalf("ParseEnvVars failed: %v", err)
} }
envVars := loader.GetEnvVars() envVars := loader.GetEnvVars()
if len(envVars) == 0 {
t.Error("expected env vars from hlog package")
}
t.Logf("Found %d environment variables from hlog", len(envVars))
for _, ev := range envVars { for _, ev := range envVars {
t.Logf(" %s: %s (default: %s, required: %t)", ev.Name, ev.Description, ev.Default, ev.Required) if ev.Group != "MyGroup" {
t.Errorf("expected group 'MyGroup', got '%s' for var %s", ev.Group, ev.Name)
}
} }
} }
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) { func TestParseEnvVars_CurrentValues(t *testing.T) {
// Test the new separated ParseEnvVars functionality
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New() loader := New()
// Add hlog package loader.AddConfigStruct(&testConfig{}, "Test")
if err := loader.AddPackagePath(hlogPath); err != nil {
t.Fatalf("failed to add hlog package: %v", err) // Set an env var
t.Setenv("LOG_LEVEL", "debug")
err := loader.ParseEnvVars()
if err != nil {
t.Fatalf("ParseEnvVars failed: %v", err)
} }
// Parse env vars without loading configs (this should work even if required env vars are missing) envVars := loader.GetEnvVars()
for _, ev := range envVars {
if ev.Name == "LOG_LEVEL" {
if ev.CurrentValue != "debug" {
t.Errorf("expected CurrentValue 'debug', got '%s'", ev.CurrentValue)
}
return
}
}
t.Error("LOG_LEVEL not found in env vars")
}
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
loader := New()
// Add config struct for tag parsing
loader.AddConfigStruct(&testConfig{}, "Test")
// Parse env vars
if err := loader.ParseEnvVars(); err != nil { if err := loader.ParseEnvVars(); err != nil {
t.Fatalf("ParseEnvVars failed: %v", err) t.Fatalf("ParseEnvVars failed: %v", err)
} }
envVars := loader.GetEnvVars() envVars := loader.GetEnvVars()
if len(envVars) == 0 { if len(envVars) == 0 {
t.Error("expected env vars from hlog package") t.Error("expected env vars from config struct")
} }
// Now test that we can generate an env file without calling Load() // Now test that we can generate an env file without calling Load()
tempDir := t.TempDir() tempDir := t.TempDir()
envFile := filepath.Join(tempDir, "test-generated.env") envFile := tempDir + "/test-generated.env"
err := loader.GenerateEnvFile(envFile, false) err := loader.GenerateEnvFile(envFile, false)
if err != nil { if err != nil {
@@ -472,16 +487,16 @@ func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
t.Error("expected header in generated file") t.Error("expected header in generated file")
} }
// Should contain environment variables from hlog // Should contain environment variables from config struct
foundHlogVar := false foundVar := false
for _, ev := range envVars { for _, ev := range envVars {
if strings.Contains(output, ev.Name) { if strings.Contains(output, ev.Name) {
foundHlogVar = true foundVar = true
break break
} }
} }
if !foundHlogVar { if !foundVar {
t.Error("expected to find at least one hlog environment variable in generated file") t.Error("expected to find at least one environment variable in generated file")
} }
t.Logf("Successfully generated env file with %d variables", len(envVars)) t.Logf("Successfully generated env file with %d variables", len(envVars))

View File

@@ -1,31 +1,70 @@
package ezconf package ezconf
// Integration is an interface that packages can implement to provide type Integration struct {
Name string
ConfigPointer any
ConfigFunc func() (any, error)
GroupName string
}
func NewIntegration(name, groupname string, cfgptr any, cfgfunc func() (any, error)) *Integration {
return &Integration{
name,
cfgptr,
cfgfunc,
groupname,
}
}
// IntegrationDepr is an interface that packages can implement to provide
// easy integration with ezconf // easy integration with ezconf
type Integration interface { type IntegrationDepr interface {
// Name returns the name to use when registering the config // Name returns the name to use when registering the config
Name() string Name() string
// PackagePath returns the path to the package for source parsing // ConfigPointer returns a pointer to the config struct for tag parsing
PackagePath() string ConfigPointer() any
// ConfigFunc returns the ConfigFromEnv function // ConfigFunc returns the ConfigFromEnv function
ConfigFunc() func() (interface{}, error) ConfigFunc() func() (any, error)
// GroupName returns the display name for grouping environment variables // GroupName returns the display name for grouping environment variables
GroupName() string GroupName() string
} }
// RegisterIntegration registers a package that implements the Integration interface // AddIntegration registers a package using an Integration object returned by another package
func (cl *ConfigLoader) RegisterIntegration(integration Integration) error { func (cl *ConfigLoader) AddIntegration(integration *Integration) error {
// Add package path // Add config struct for tag parsing
pkgPath := integration.PackagePath() configPtr := integration.ConfigPointer
if err := cl.AddPackagePath(pkgPath); err != nil { if err := cl.AddConfigStruct(configPtr, integration.GroupName); err != nil {
return err return err
} }
// Store group name for this package // Add config function
cl.groupNames[pkgPath] = integration.GroupName() if err := cl.AddConfigFunc(integration.Name, integration.ConfigFunc); err != nil {
return err
}
return nil
}
// AddIntegrations registers multiple integrations at once
func (cl *ConfigLoader) AddIntegrations(integrations ...*Integration) error {
for _, integration := range integrations {
if err := cl.AddIntegration(integration); err != nil {
return err
}
}
return nil
}
// RegisterIntegration registers a package that implements the Integration interface
func (cl *ConfigLoader) RegisterIntegration(integration IntegrationDepr) error {
// Add config struct for tag parsing
configPtr := integration.ConfigPointer()
if err := cl.AddConfigStruct(configPtr, integration.GroupName()); err != nil {
return err
}
// Add config function // Add config function
if err := cl.AddConfigFunc(integration.Name(), integration.ConfigFunc()); err != nil { if err := cl.AddConfigFunc(integration.Name(), integration.ConfigFunc()); err != nil {
@@ -36,7 +75,7 @@ func (cl *ConfigLoader) RegisterIntegration(integration Integration) error {
} }
// RegisterIntegrations registers multiple integrations at once // RegisterIntegrations registers multiple integrations at once
func (cl *ConfigLoader) RegisterIntegrations(integrations ...Integration) error { func (cl *ConfigLoader) RegisterIntegrations(integrations ...IntegrationDepr) error {
for _, integration := range integrations { for _, integration := range integrations {
if err := cl.RegisterIntegration(integration); err != nil { if err := cl.RegisterIntegration(integration); err != nil {
return err return err

View File

@@ -1,24 +1,34 @@
package ezconf package ezconf
import ( import (
"os"
"path/filepath"
"testing" "testing"
) )
// Mock integration for testing // mockConfig is a test config struct with ezconf tags
type mockConfig struct {
Host string `ezconf:"MOCK_HOST,description:Host to connect to,default:localhost"`
Port int `ezconf:"MOCK_PORT,description:Port to connect to,default:8080"`
}
// mockConfig2 is a second test config struct
type mockConfig2 struct {
Token string `ezconf:"MOCK_TOKEN,description:API token,required"`
}
// mockIntegration implements the Integration interface for testing
type mockIntegration struct { type mockIntegration struct {
name string name string
packagePath string configPtr any
configFunc func() (interface{}, error) configFunc func() (interface{}, error)
groupName string
} }
func (m mockIntegration) Name() string { func (m mockIntegration) Name() string {
return m.name return m.name
} }
func (m mockIntegration) PackagePath() string { func (m mockIntegration) ConfigPointer() any {
return m.packagePath return m.configPtr
} }
func (m mockIntegration) ConfigFunc() func() (interface{}, error) { func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
@@ -26,7 +36,10 @@ func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
} }
func (m mockIntegration) GroupName() string { func (m mockIntegration) GroupName() string {
if m.groupName == "" {
return "Test Group" return "Test Group"
}
return m.groupName
} }
func TestRegisterIntegration(t *testing.T) { func TestRegisterIntegration(t *testing.T) {
@@ -34,7 +47,7 @@ func TestRegisterIntegration(t *testing.T) {
integration := mockIntegration{ integration := mockIntegration{
name: "test", name: "test",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "test config", nil return "test config", nil
}, },
@@ -45,9 +58,9 @@ func TestRegisterIntegration(t *testing.T) {
t.Fatalf("RegisterIntegration failed: %v", err) t.Fatalf("RegisterIntegration failed: %v", err)
} }
// Verify package path was added // Verify config struct was added
if len(loader.packagePaths) != 1 { if len(loader.configStructs) != 1 {
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths)) t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
} }
// Verify config func was added // Verify config func was added
@@ -68,14 +81,46 @@ func TestRegisterIntegration(t *testing.T) {
if cfg != "test config" { if cfg != "test config" {
t.Errorf("expected 'test config', got %v", cfg) t.Errorf("expected 'test config', got %v", cfg)
} }
// Verify env vars were parsed from struct tags
envVars := loader.GetEnvVars()
if len(envVars) != 2 {
t.Errorf("expected 2 env vars, got %d", len(envVars))
}
foundHost := false
foundPort := false
for _, ev := range envVars {
if ev.Name == "MOCK_HOST" {
foundHost = true
if ev.Default != "localhost" {
t.Errorf("expected default 'localhost', got '%s'", ev.Default)
}
if ev.Group != "Test Group" {
t.Errorf("expected group 'Test Group', got '%s'", ev.Group)
}
}
if ev.Name == "MOCK_PORT" {
foundPort = true
if ev.Default != "8080" {
t.Errorf("expected default '8080', got '%s'", ev.Default)
}
}
}
if !foundHost {
t.Error("MOCK_HOST not found in env vars")
}
if !foundPort {
t.Error("MOCK_PORT not found in env vars")
}
} }
func TestRegisterIntegration_InvalidPath(t *testing.T) { func TestRegisterIntegration_NilConfigPointer(t *testing.T) {
loader := New() loader := New()
integration := mockIntegration{ integration := mockIntegration{
name: "test", name: "test",
packagePath: "/nonexistent/path", configPtr: nil,
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "test config", nil return "test config", nil
}, },
@@ -83,7 +128,7 @@ func TestRegisterIntegration_InvalidPath(t *testing.T) {
err := loader.RegisterIntegration(integration) err := loader.RegisterIntegration(integration)
if err == nil { if err == nil {
t.Error("expected error for invalid package path") t.Error("expected error for nil config pointer")
} }
} }
@@ -92,7 +137,7 @@ func TestRegisterIntegrations(t *testing.T) {
integration1 := mockIntegration{ integration1 := mockIntegration{
name: "test1", name: "test1",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config1", nil return "config1", nil
}, },
@@ -100,7 +145,7 @@ func TestRegisterIntegrations(t *testing.T) {
integration2 := mockIntegration{ integration2 := mockIntegration{
name: "test2", name: "test2",
packagePath: ".", configPtr: &mockConfig2{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config2", nil return "config2", nil
}, },
@@ -130,6 +175,12 @@ func TestRegisterIntegrations(t *testing.T) {
if cfg1 != "config1" || cfg2 != "config2" { if cfg1 != "config1" || cfg2 != "config2" {
t.Error("config values mismatch") t.Error("config values mismatch")
} }
// Should have env vars from both structs
envVars := loader.GetEnvVars()
if len(envVars) != 3 {
t.Errorf("expected 3 env vars (2 from mockConfig + 1 from mockConfig2), got %d", len(envVars))
}
} }
func TestRegisterIntegrations_PartialFailure(t *testing.T) { func TestRegisterIntegrations_PartialFailure(t *testing.T) {
@@ -137,7 +188,7 @@ func TestRegisterIntegrations_PartialFailure(t *testing.T) {
integration1 := mockIntegration{ integration1 := mockIntegration{
name: "test1", name: "test1",
packagePath: ".", configPtr: &mockConfig{},
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config1", nil return "config1", nil
}, },
@@ -145,7 +196,7 @@ func TestRegisterIntegrations_PartialFailure(t *testing.T) {
integration2 := mockIntegration{ integration2 := mockIntegration{
name: "test2", name: "test2",
packagePath: "/nonexistent", configPtr: nil, // This should cause failure
configFunc: func() (interface{}, error) { configFunc: func() (interface{}, error) {
return "config2", nil return "config2", nil
}, },
@@ -159,54 +210,5 @@ func TestRegisterIntegrations_PartialFailure(t *testing.T) {
func TestIntegration_Interface(t *testing.T) { func TestIntegration_Interface(t *testing.T) {
// Verify that mockIntegration implements Integration interface // Verify that mockIntegration implements Integration interface
var _ Integration = (*mockIntegration)(nil) var _ IntegrationDepr = (*mockIntegration)(nil)
}
func TestRegisterIntegration_RealPackage(t *testing.T) {
// Integration test with real hlog package if available
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
loader := New()
// Create a simple integration for testing
integration := mockIntegration{
name: "hlog",
packagePath: hlogPath,
configFunc: func() (interface{}, error) {
// Return a mock config instead of calling real ConfigFromEnv
return struct{ LogLevel string }{LogLevel: "info"}, nil
},
}
err := loader.RegisterIntegration(integration)
if err != nil {
t.Fatalf("RegisterIntegration with real package failed: %v", err)
}
if err := loader.Load(); err != nil {
t.Fatalf("Load failed: %v", err)
}
// Should have parsed env vars from hlog
envVars := loader.GetEnvVars()
if len(envVars) == 0 {
t.Error("expected env vars from hlog package")
}
// Check for known hlog variables
foundLogLevel := false
for _, ev := range envVars {
if ev.Name == "LOG_LEVEL" {
foundLogLevel = true
t.Logf("Found LOG_LEVEL: %s", ev.Description)
break
}
}
if !foundLogLevel {
t.Error("expected to find LOG_LEVEL from hlog")
}
} }

View File

@@ -1,146 +1,102 @@
package ezconf package ezconf
import ( import (
"go/ast" "reflect"
"go/parser"
"go/token"
"os"
"path/filepath"
"regexp"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// ParseConfigFile parses a Go source file and extracts ENV comments from struct fields // ParseConfigStruct extracts environment variable metadata from a config
func ParseConfigFile(filename string) ([]EnvVar, error) { // struct's ezconf struct tags using reflection.
content, err := os.ReadFile(filename) //
if err != nil { // The configPtr parameter must be a pointer to a struct. Each field with an
return nil, errors.Wrap(err, "failed to read file") // ezconf tag will be parsed to extract environment variable information.
//
// Tag format: `ezconf:"VAR_NAME,description:Description text,default:value,required"`
//
// Components:
// - First value: environment variable name (required)
// - description:...: Description of the variable
// - default:...: Default value
// - required: Marks the variable as required (optionally required:condition)
func ParseConfigStruct(configPtr any) ([]EnvVar, error) {
if configPtr == nil {
return nil, errors.New("config pointer cannot be nil")
} }
fset := token.NewFileSet() v := reflect.ValueOf(configPtr)
file, err := parser.ParseFile(fset, filename, content, parser.ParseComments) if v.Kind() != reflect.Ptr {
if err != nil { return nil, errors.New("config must be a pointer to a struct")
return nil, errors.Wrap(err, "failed to parse file")
} }
v = v.Elem()
if v.Kind() != reflect.Struct {
return nil, errors.New("config must be a pointer to a struct")
}
t := v.Type()
envVars := make([]EnvVar, 0) envVars := make([]EnvVar, 0)
// Walk through the AST for i := 0; i < t.NumField(); i++ {
ast.Inspect(file, func(n ast.Node) bool { field := t.Field(i)
// Look for struct type declarations tag := field.Tag.Get("ezconf")
typeSpec, ok := n.(*ast.TypeSpec) if tag == "" {
if !ok { continue
return true
} }
structType, ok := typeSpec.Type.(*ast.StructType) envVar, err := parseEzconfTag(tag)
if !ok { if err != nil {
return true return nil, errors.Wrapf(err, "failed to parse ezconf tag on field %s", field.Name)
} }
// Iterate through struct fields
for _, field := range structType.Fields.List {
var comment string
// Try to get from doc comment (comment before field)
if field.Doc != nil && len(field.Doc.List) > 0 {
comment = field.Doc.List[0].Text
comment = strings.TrimPrefix(comment, "//")
comment = strings.TrimSpace(comment)
}
// Try to get from inline comment (comment after field)
if comment == "" && field.Comment != nil && len(field.Comment.List) > 0 {
comment = field.Comment.List[0].Text
comment = strings.TrimPrefix(comment, "//")
comment = strings.TrimSpace(comment)
}
// Parse ENV comment
if strings.HasPrefix(comment, "ENV ") {
envVar, err := parseEnvComment(comment)
if err == nil {
envVars = append(envVars, *envVar) envVars = append(envVars, *envVar)
} }
}
}
return true
})
return envVars, nil return envVars, nil
} }
// ParseConfigPackage parses all Go files in a package directory and extracts ENV comments // parseEzconfTag parses an ezconf struct tag value to extract environment
func ParseConfigPackage(packagePath string) ([]EnvVar, error) { // variable information.
// Find all .go files in the package //
files, err := filepath.Glob(filepath.Join(packagePath, "*.go")) // Expected format: "VAR_NAME,description:Description text,default:value,required"
if err != nil { func parseEzconfTag(tag string) (*EnvVar, error) {
return nil, errors.Wrap(err, "failed to glob package files") if tag == "" {
return nil, errors.New("tag cannot be empty")
} }
allEnvVars := make([]EnvVar, 0) parts := strings.Split(tag, ",")
if len(parts) == 0 {
for _, file := range files { return nil, errors.New("tag cannot be empty")
// Skip test files
if strings.HasSuffix(file, "_test.go") {
continue
}
envVars, err := ParseConfigFile(file)
if err != nil {
// Log error but continue with other files
continue
}
allEnvVars = append(allEnvVars, envVars...)
}
return allEnvVars, nil
}
// parseEnvComment parses a field comment to extract environment variable information.
// Expected format: ENV ENV_NAME: Description (required <condition>) (default: <value>)
func parseEnvComment(comment string) (*EnvVar, error) {
// Check if comment starts with ENV
if !strings.HasPrefix(comment, "ENV ") {
return nil, errors.New("comment does not start with 'ENV '")
}
// Remove "ENV " prefix
comment = strings.TrimPrefix(comment, "ENV ")
// Extract env var name (everything before the first colon)
colonIdx := strings.Index(comment, ":")
if colonIdx == -1 {
return nil, errors.New("missing colon separator")
} }
envVar := &EnvVar{ envVar := &EnvVar{
Name: strings.TrimSpace(comment[:colonIdx]), Name: strings.TrimSpace(parts[0]),
} }
// Extract description and optional parts if envVar.Name == "" {
remainder := strings.TrimSpace(comment[colonIdx+1:]) return nil, errors.New("environment variable name cannot be empty")
}
// Check for (required ...) pattern for _, part := range parts[1:] {
requiredPattern := regexp.MustCompile(`\(required[^)]*\)`) part = strings.TrimSpace(part)
if requiredPattern.MatchString(remainder) {
switch {
case strings.HasPrefix(part, "description:"):
envVar.Description = strings.TrimSpace(strings.TrimPrefix(part, "description:"))
case strings.HasPrefix(part, "default:"):
envVar.Default = strings.TrimSpace(strings.TrimPrefix(part, "default:"))
case part == "required":
envVar.Required = true envVar.Required = true
remainder = requiredPattern.ReplaceAllString(remainder, "") case strings.HasPrefix(part, "required:"):
envVar.Required = true
// Store the condition in the description if it adds context
condition := strings.TrimSpace(strings.TrimPrefix(part, "required:"))
if condition != "" && envVar.Description != "" {
envVar.Description = envVar.Description + " (required " + condition + ")"
}
} }
// Check for (default: ...) pattern
defaultPattern := regexp.MustCompile(`\(default:\s*([^)]*)\)`)
if matches := defaultPattern.FindStringSubmatch(remainder); len(matches) > 1 {
envVar.Default = strings.TrimSpace(matches[1])
remainder = defaultPattern.ReplaceAllString(remainder, "")
} }
// What remains is the description
envVar.Description = strings.TrimSpace(remainder)
return envVar, nil return envVar, nil
} }

View File

@@ -1,21 +1,19 @@
package ezconf package ezconf
import ( import (
"os"
"path/filepath"
"testing" "testing"
) )
func TestParseEnvComment(t *testing.T) { func TestParseEzconfTag(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
comment string tag string
wantEnvVar *EnvVar wantEnvVar *EnvVar
expectError bool expectError bool
}{ }{
{ {
name: "simple env variable", name: "simple env variable",
comment: "ENV LOG_LEVEL: Log level for the application", tag: "LOG_LEVEL,description:Log level for the application",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_LEVEL", Name: "LOG_LEVEL",
Description: "Log level for the application", Description: "Log level for the application",
@@ -26,7 +24,7 @@ func TestParseEnvComment(t *testing.T) {
}, },
{ {
name: "env variable with default", name: "env variable with default",
comment: "ENV LOG_LEVEL: Log level for the application (default: info)", tag: "LOG_LEVEL,description:Log level for the application,default:info",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_LEVEL", Name: "LOG_LEVEL",
Description: "Log level for the application", Description: "Log level for the application",
@@ -37,7 +35,7 @@ func TestParseEnvComment(t *testing.T) {
}, },
{ {
name: "required env variable", name: "required env variable",
comment: "ENV DATABASE_URL: Database connection string (required)", tag: "DATABASE_URL,description:Database connection string,required",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "DATABASE_URL", Name: "DATABASE_URL",
Description: "Database connection string", Description: "Database connection string",
@@ -48,24 +46,35 @@ func TestParseEnvComment(t *testing.T) {
}, },
{ {
name: "required with condition and default", name: "required with condition and default",
comment: "ENV LOG_DIR: Directory for log files (required when LOG_OUTPUT is file) (default: /var/log)", tag: "LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file,default:/var/log",
wantEnvVar: &EnvVar{ wantEnvVar: &EnvVar{
Name: "LOG_DIR", Name: "LOG_DIR",
Description: "Directory for log files", Description: "Directory for log files (required when LOG_OUTPUT is file)",
Required: true, Required: true,
Default: "/var/log", Default: "/var/log",
}, },
expectError: false, expectError: false,
}, },
{ {
name: "missing colon", name: "name only",
comment: "ENV LOG_LEVEL Log level", tag: "SIMPLE_VAR",
wantEnvVar: &EnvVar{
Name: "SIMPLE_VAR",
Description: "",
Required: false,
Default: "",
},
expectError: false,
},
{
name: "empty tag",
tag: "",
wantEnvVar: nil, wantEnvVar: nil,
expectError: true, expectError: true,
}, },
{ {
name: "not an ENV comment", name: "empty name",
comment: "This is a regular comment", tag: ",description:some desc",
wantEnvVar: nil, wantEnvVar: nil,
expectError: true, expectError: true,
}, },
@@ -73,7 +82,7 @@ func TestParseEnvComment(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
envVar, err := parseEnvComment(tt.comment) envVar, err := parseEzconfTag(tt.tag)
if tt.expectError { if tt.expectError {
if err == nil { if err == nil {
@@ -103,32 +112,17 @@ func TestParseEnvComment(t *testing.T) {
} }
} }
func TestParseConfigFile(t *testing.T) { func TestParseConfigStruct(t *testing.T) {
// Create a temporary test file type TestConfig struct {
tempDir := t.TempDir() LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
testFile := filepath.Join(tempDir, "config.go") LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
content := `package testpkg NoTag string
type Config struct {
// ENV LOG_LEVEL: Log level for the application (default: info)
LogLevel string
// ENV LOG_OUTPUT: Output destination (default: console)
LogOutput string
// ENV DATABASE_URL: Database connection string (required)
DatabaseURL string
}
`
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
} }
envVars, err := ParseConfigFile(testFile) envVars, err := ParseConfigStruct(&TestConfig{})
if err != nil { if err != nil {
t.Fatalf("ParseConfigFile failed: %v", err) t.Fatalf("ParseConfigStruct failed: %v", err)
} }
if len(envVars) != 3 { if len(envVars) != 3 {
@@ -152,51 +146,70 @@ type Config struct {
} }
} }
func TestParseConfigPackage(t *testing.T) { func TestParseConfigStruct_NilPointer(t *testing.T) {
// Test with actual hlog package _, err := ParseConfigStruct(nil)
hlogPath := filepath.Join("..", "hlog")
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
t.Skip("hlog package not found, skipping integration test")
}
envVars, err := ParseConfigPackage(hlogPath)
if err != nil {
t.Fatalf("ParseConfigPackage failed: %v", err)
}
if len(envVars) == 0 {
t.Error("expected at least one env var from hlog package")
}
// Check for known hlog variables
foundLogLevel := false
for _, envVar := range envVars {
if envVar.Name == "LOG_LEVEL" {
foundLogLevel = true
t.Logf("Found LOG_LEVEL: %s", envVar.Description)
}
}
if !foundLogLevel {
t.Error("expected to find LOG_LEVEL in hlog package")
}
}
func TestParseConfigFile_InvalidFile(t *testing.T) {
_, err := ParseConfigFile("/nonexistent/file.go")
if err == nil { if err == nil {
t.Error("expected error for nonexistent file") t.Error("expected error for nil pointer")
} }
} }
func TestParseConfigPackage_InvalidPath(t *testing.T) { func TestParseConfigStruct_NotPointer(t *testing.T) {
envVars, err := ParseConfigPackage("/nonexistent/package") type TestConfig struct {
Foo string `ezconf:"FOO,description:test"`
}
_, err := ParseConfigStruct(TestConfig{})
if err == nil {
t.Error("expected error for non-pointer")
}
}
func TestParseConfigStruct_NotStruct(t *testing.T) {
str := "not a struct"
_, err := ParseConfigStruct(&str)
if err == nil {
t.Error("expected error for non-struct pointer")
}
}
func TestParseConfigStruct_NoTags(t *testing.T) {
type EmptyConfig struct {
Foo string
Bar int
}
envVars, err := ParseConfigStruct(&EmptyConfig{})
if err != nil { if err != nil {
t.Fatalf("ParseConfigPackage should not error on invalid path: %v", err) t.Fatalf("ParseConfigStruct failed: %v", err)
} }
// Should return empty slice for invalid path
if len(envVars) != 0 { if len(envVars) != 0 {
t.Errorf("expected 0 env vars for invalid path, got %d", len(envVars)) t.Errorf("expected 0 env vars for struct with no tags, got %d", len(envVars))
}
}
func TestParseConfigStruct_UnexportedFields(t *testing.T) {
type TestConfig struct {
exported string `ezconf:"EXPORTED,description:An exported field"`
unexported string `ezconf:"UNEXPORTED,description:An unexported field"`
}
envVars, err := ParseConfigStruct(&TestConfig{})
if err != nil {
t.Fatalf("ParseConfigStruct failed: %v", err)
}
if len(envVars) != 2 {
t.Errorf("expected 2 env vars (both exported and unexported), got %d", len(envVars))
}
}
func TestParseConfigStruct_InvalidTag(t *testing.T) {
type TestConfig struct {
Bad string `ezconf:",description:missing name"`
}
_, err := ParseConfigStruct(&TestConfig{})
if err == nil {
t.Error("expected error for invalid tag")
} }
} }

View File

@@ -9,11 +9,11 @@ import (
// It can be populated from environment variables using ConfigFromEnv // It can be populated from environment variables using ConfigFromEnv
// or created programmatically. // or created programmatically.
type Config struct { type Config struct {
LogLevel Level // ENV LOG_LEVEL: Log level for the logger - trace, debug, info, warn, error, fatal, panic (default: info) LogLevel Level `ezconf:"LOG_LEVEL,description:Log level for the logger - trace debug info warn error fatal panic,default:info"`
LogOutput string // ENV LOG_OUTPUT: Output destination for logs - console, file, or both (default: console) LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination for logs - console file or both,default:console"`
LogDir string // ENV LOG_DIR: Directory path for log files (required when LOG_OUTPUT is "file" or "both") LogDir string `ezconf:"LOG_DIR,description:Directory path for log files,required:when LOG_OUTPUT is file or both"`
LogFileName string // ENV LOG_FILE_NAME: Name of the log file (required when LOG_OUTPUT is "file" or "both") LogFileName string `ezconf:"LOG_FILE_NAME,description:Name of the log file,required:when LOG_OUTPUT is file or both"`
LogAppend bool // ENV LOG_APPEND: Append to existing log file or overwrite (default: true) LogAppend bool `ezconf:"LOG_APPEND,description:Append to existing log file or overwrite,default:true"`
} }
// ConfigFromEnv loads logger configuration from environment variables. // ConfigFromEnv loads logger configuration from environment variables.

View File

@@ -1,35 +1,9 @@
package hlog package hlog
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("hlog", "HLog",
// PackagePath returns the path to the hlog package for source parsing &Config{}, func() (any, error) { return ConfigFromEnv() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hlog"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HLog"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -7,6 +7,8 @@ require (
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1
require ( require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect

View File

@@ -1,5 +1,7 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=

View File

@@ -7,13 +7,13 @@ import (
) )
type Config struct { type Config struct {
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1) Host string `ezconf:"HWS_HOST,description:Host to listen on,default:127.0.0.1"`
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000) Port uint64 `ezconf:"HWS_PORT,description:Port to listen on,default:3000"`
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false) GZIP bool `ezconf:"HWS_GZIP,description:Flag for GZIP compression on requests,default:false"`
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2) ReadHeaderTimeout time.Duration `ezconf:"HWS_READ_HEADER_TIMEOUT,description:Timeout for reading request headers in seconds,default:2"`
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10) WriteTimeout time.Duration `ezconf:"HWS_WRITE_TIMEOUT,description:Timeout for writing requests in seconds,default:10"`
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120) IdleTimeout time.Duration `ezconf:"HWS_IDLE_TIMEOUT,description:Timeout for idle connections in seconds,default:120"`
ShutdownDelay time.Duration // ENV HWS_SHUTDOWN_DELAY: Delay in seconds before server shutsdown when Shutdown is called (default: 5) ShutdownDelay time.Duration `ezconf:"HWS_SHUTDOWN_DELAY,description:Delay in seconds before server shuts down when Shutdown is called,default:5"`
} }
// ConfigFromEnv returns a Config struct loaded from the environment variables // ConfigFromEnv returns a Config struct loaded from the environment variables

View File

@@ -13,12 +13,12 @@ import (
func Test_ConfigFromEnv(t *testing.T) { func Test_ConfigFromEnv(t *testing.T) {
t.Run("Default values when no env vars set", func(t *testing.T) { t.Run("Default values when no env vars set", func(t *testing.T) {
// Clear any existing env vars // Clear any existing env vars
os.Unsetenv("HWS_HOST") _ = os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT") _ = os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP") _ = os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_IDLE_TIMEOUT")
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom host", func(t *testing.T) { t.Run("Custom host", func(t *testing.T) {
os.Setenv("HWS_HOST", "192.168.1.1") _ = os.Setenv("HWS_HOST", "192.168.1.1")
defer os.Unsetenv("HWS_HOST") defer func() {
_ = os.Unsetenv("HWS_HOST")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom port", func(t *testing.T) { t.Run("Custom port", func(t *testing.T) {
os.Setenv("HWS_PORT", "8080") _ = os.Setenv("HWS_PORT", "8080")
defer os.Unsetenv("HWS_PORT") defer func() {
_ = os.Unsetenv("HWS_PORT")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("GZIP enabled", func(t *testing.T) { t.Run("GZIP enabled", func(t *testing.T) {
os.Setenv("HWS_GZIP", "true") _ = os.Setenv("HWS_GZIP", "true")
defer os.Unsetenv("HWS_GZIP") defer func() {
_ = os.Unsetenv("HWS_GZIP")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("Custom timeouts", func(t *testing.T) { t.Run("Custom timeouts", func(t *testing.T) {
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5") _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
os.Setenv("HWS_WRITE_TIMEOUT", "30") _ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
os.Setenv("HWS_IDLE_TIMEOUT", "300") _ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT") defer func() {
defer os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
defer os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()
require.NoError(t, err) require.NoError(t, err)
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
}) })
t.Run("All custom values", func(t *testing.T) { t.Run("All custom values", func(t *testing.T) {
os.Setenv("HWS_HOST", "0.0.0.0") _ = os.Setenv("HWS_HOST", "0.0.0.0")
os.Setenv("HWS_PORT", "9000") _ = os.Setenv("HWS_PORT", "9000")
os.Setenv("HWS_GZIP", "true") _ = os.Setenv("HWS_GZIP", "true")
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3") _ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
os.Setenv("HWS_WRITE_TIMEOUT", "15") _ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
os.Setenv("HWS_IDLE_TIMEOUT", "180") _ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
defer func() { defer func() {
os.Unsetenv("HWS_HOST") _ = os.Unsetenv("HWS_HOST")
os.Unsetenv("HWS_PORT") _ = os.Unsetenv("HWS_PORT")
os.Unsetenv("HWS_GZIP") _ = os.Unsetenv("HWS_GZIP")
os.Unsetenv("HWS_READ_HEADER_TIMEOUT") _ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
os.Unsetenv("HWS_WRITE_TIMEOUT") _ = os.Unsetenv("HWS_WRITE_TIMEOUT")
os.Unsetenv("HWS_IDLE_TIMEOUT") _ = os.Unsetenv("HWS_IDLE_TIMEOUT")
}() }()
config, err := hws.ConfigFromEnv() config, err := hws.ConfigFromEnv()

View File

@@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Error to use with Server.ThrowError // HWSError wraps an error with other information for use with HWS features
type HWSError struct { type HWSError struct {
StatusCode int // HTTP Status code StatusCode int // HTTP Status code
Message string // Error message Message string // Error message
@@ -41,7 +41,7 @@ type ErrorPage interface {
} }
// AddErrorPage registers a handler that returns an ErrorPage // AddErrorPage registers a handler that returns an ErrorPage
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error { func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError}) page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
return errors.New("Render method of the error page did not write anything to the response writer") return errors.New("Render method of the error page did not write anything to the response writer")
} }
server.errorPage = pageFunc s.errorPage = pageFunc
return nil return nil
} }
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
// the error with the level specified by the HWSError. // the error with the level specified by the HWSError.
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter // If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
// and the request chain should be terminated. // and the request chain should be terminated.
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error { func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
err := s.throwError(w, r, error)
if err != nil {
s.LogError(error)
s.LogError(HWSError{
Message: "Error occured during throwError",
Error: errors.Wrap(err, "s.throwError"),
Level: ErrorERROR,
})
}
}
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
if error.StatusCode <= 0 { if error.StatusCode <= 0 {
return errors.New("HWSError.StatusCode cannot be 0.") return errors.New("HWSError.StatusCode cannot be 0.")
} }
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
if r == nil { if r == nil {
return errors.New("Request cannot be nil") return errors.New("Request cannot be nil")
} }
if !server.IsReady() { if !s.IsReady() {
return errors.New("ThrowError called before server started") return errors.New("ThrowError called before server started")
} }
w.WriteHeader(error.StatusCode) w.WriteHeader(error.StatusCode)
server.LogError(error) s.LogError(error)
if server.errorPage == nil { if s.errorPage == nil {
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
return nil return nil
} }
if error.RenderErrorPage { if error.RenderErrorPage {
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
errPage, err := server.errorPage(error) errPage, err := s.errorPage(error)
if err != nil { if err != nil {
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err}) s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
} }
err = errPage.Render(r.Context(), w) err = errPage.Render(r.Context(), w)
if err != nil { if err != nil {
server.LogError(HWSError{Message: "Failed to render error page", Error: err}) s.LogError(HWSError{Message: "Failed to render error page", Error: err})
} }
} else { } else {
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG}) s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
} }
return nil return nil
} }
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
w.WriteHeader(http.StatusInternalServerError)
server.LogFatal(err)
}

View File

@@ -14,22 +14,26 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type goodPage struct{} type (
type badPage struct{} goodPage struct{}
badPage struct{}
)
func goodRender(error hws.HWSError) (hws.ErrorPage, error) { func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
return goodPage{}, nil return goodPage{}, nil
} }
func badRender1(error hws.HWSError) (hws.ErrorPage, error) { func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
return badPage{}, nil return badPage{}, nil
} }
func badRender2(error hws.HWSError) (hws.ErrorPage, error) { func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
return nil, errors.New("I'm an error") return nil, errors.New("I'm an error")
} }
func (g goodPage) Render(ctx context.Context, w io.Writer) error { func (g goodPage) Render(ctx context.Context, w io.Writer) error {
w.Write([]byte("Test write to ResponseWriter")) _, err := w.Write([]byte("Test write to ResponseWriter"))
return nil return err
} }
func (b badPage) Render(ctx context.Context, w io.Writer) error { func (b badPage) Render(ctx context.Context, w io.Writer) error {
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
t.Run("Server not started", func(t *testing.T) { t.Run("Server not started", func(t *testing.T) {
err := server.ThrowError(rr, req, hws.HWSError{ buf.Reset()
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "Error", Message: "Error",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.Error(t, err) // ThrowError logs errors internally when validation fails
output := buf.String()
assert.Contains(t, output, "ThrowError called before server started")
}) })
startTestServer(t, server) startTestServer(t, server)
defer server.Shutdown(t.Context())
tests := []struct { tests := []struct {
name string name string
request *http.Request request *http.Request
error hws.HWSError error hws.HWSError
valid bool expectLogItem string
}{ }{
{ {
name: "No HWSError.Status code", name: "No HWSError.Status code",
request: nil, request: nil,
error: hws.HWSError{}, error: hws.HWSError{},
valid: false, expectLogItem: "HWSError.StatusCode cannot be 0",
}, },
{ {
name: "Negative HWSError.Status code", name: "Negative HWSError.Status code",
request: nil, request: nil,
error: hws.HWSError{StatusCode: -1}, error: hws.HWSError{StatusCode: -1},
valid: false, expectLogItem: "HWSError.StatusCode cannot be 0",
}, },
{ {
name: "No HWSError.Message", name: "No HWSError.Message",
request: nil, request: nil,
error: hws.HWSError{StatusCode: http.StatusInternalServerError}, error: hws.HWSError{StatusCode: http.StatusInternalServerError},
valid: false, expectLogItem: "HWSError.Message cannot be empty",
}, },
{ {
name: "No HWSError.Error", name: "No HWSError.Error",
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
}, },
valid: false, expectLogItem: "HWSError.Error cannot be nil",
}, },
{ {
name: "No request provided", name: "No request provided",
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}, },
valid: false, expectLogItem: "Request cannot be nil",
}, },
{ {
name: "Valid", name: "Valid",
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}, },
valid: true, expectLogItem: "An error occured",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
buf.Reset()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
err := server.ThrowError(rr, tt.request, tt.error) server.ThrowError(rr, tt.request, tt.error)
if tt.valid { // ThrowError no longer returns errors; check logs instead
assert.NoError(t, err) output := buf.String()
} else { assert.Contains(t, output, tt.expectLogItem)
t.Log(err)
assert.Error(t, err)
}
}) })
} }
t.Run("Log level set correctly", func(t *testing.T) { t.Run("Log level set correctly", func(t *testing.T) {
buf.Reset() buf.Reset()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
Level: hws.ErrorWARN, Level: hws.ErrorWARN,
}) })
assert.NoError(t, err) _, err := buf.ReadString([]byte(" ")[0])
_, err = buf.ReadString([]byte(" ")[0]) require.NoError(t, err)
loglvl, err := buf.ReadString([]byte(" ")[0]) loglvl, err := buf.ReadString([]byte(" ")[0])
assert.NoError(t, err) require.NoError(t, err)
if loglvl != "\x1b[33mWRN\x1b[0m " { assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
buf.Reset() buf.Reset()
err = server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.NoError(t, err)
_, err = buf.ReadString([]byte(" ")[0]) _, err = buf.ReadString([]byte(" ")[0])
require.NoError(t, err)
loglvl, err = buf.ReadString([]byte(" ")[0]) loglvl, err = buf.ReadString([]byte(" ")[0])
assert.NoError(t, err) require.NoError(t, err)
if loglvl != "\x1b[31mERR\x1b[0m " { assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
err = errors.New("Log level not set correctly")
}
assert.NoError(t, err)
}) })
t.Run("Error page doesnt render if no error page set", func(t *testing.T) { t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
// Must be run before adding the error page to the test server // Must be run before adding the error page to the test server
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
RenderErrorPage: true, RenderErrorPage: true,
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body != "" { assert.Empty(t, body, "Error page should not render when no error page is set")
assert.Error(t, nil)
}
}) })
t.Run("Error page renders", func(t *testing.T) { t.Run("Error page renders", func(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
// Adding the error page will carry over to all future tests and cant be undone // Adding the error page will carry over to all future tests and cant be undone
server.AddErrorPage(goodRender) err := server.AddErrorPage(goodRender)
err := server.ThrowError(rr, req, hws.HWSError{ require.NoError(t, err)
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
RenderErrorPage: true, RenderErrorPage: true,
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body == "" { assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
assert.Error(t, nil)
}
}) })
t.Run("Error page doesnt render if no told to render", func(t *testing.T) { t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
// Error page already added to server // Error page already added to server
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err := server.ThrowError(rr, req, hws.HWSError{ server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.NoError(t, err)
body := rr.Body.String() body := rr.Body.String()
if body != "" { assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
assert.Error(t, nil)
}
}) })
server.Shutdown(t.Context()) err := server.Shutdown(t.Context())
require.NoError(t, err)
t.Run("Doesn't error if no logger added to server", func(t *testing.T) { t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: randomPort(), Port: randomPort(),
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
err = server.Start(t.Context()) err = server.Start(t.Context())
require.NoError(t, err) require.NoError(t, err)
<-server.Ready() <-server.Ready()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
err = server.ThrowError(rr, req, hws.HWSError{ // Should not panic when no logger is present
assert.NotPanics(t, func() {
server.ThrowError(rr, req, hws.HWSError{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Message: "An error occured", Message: "An error occured",
Error: errors.New("Error"), Error: errors.New("Error"),
}) })
assert.NoError(t, err) }, "ThrowError should not panic when no logger is present")
err = server.Shutdown(t.Context())
require.NoError(t, err)
}) })
} }

View File

@@ -1,35 +1,9 @@
package hws package hws
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("hws", "HWS",
// PackagePath returns the path to the hws package for source parsing &Config{}, func() (any, error) { return ConfigFromEnv() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hws"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HWS"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -4,21 +4,24 @@ go 1.25.5
require ( require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.9.0 git.haelnorr.com/h/golib/hlog v0.11.0
git.haelnorr.com/h/golib/notify v0.1.0 git.haelnorr.com/h/golib/notify v0.1.0
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
k8s.io/apimachinery v0.35.0 k8s.io/apimachinery v0.35.0
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/gobwas/glob v0.2.3
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.34.0 // indirect github.com/rs/zerolog v1.34.0 // indirect
golang.org/x/sys v0.12.0 // indirect golang.org/x/sys v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect

View File

@@ -1,7 +1,9 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE= git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk= git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc= git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@@ -9,12 +11,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -26,8 +32,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -17,6 +17,10 @@ import (
func Test_GZIP_Compression(t *testing.T) { func Test_GZIP_Compression(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
t.Run("GZIP enabled compresses response", func(t *testing.T) { t.Run("GZIP enabled compresses response", func(t *testing.T) {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
@@ -25,7 +29,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -80,7 +84,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -131,7 +135,7 @@ func Test_GZIP_Compression(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)

View File

@@ -6,68 +6,58 @@ import (
"net/url" "net/url"
"git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hlog"
"github.com/gobwas/glob"
) )
type logger struct { type logger struct {
logger *hlog.Logger logger *hlog.Logger
ignoredPaths []string ignoredPaths []glob.Glob
} }
// TODO: add tests to make sure all the fields are correctly set // LogError uses the attached logger to log a HWSError
func (s *Server) LogError(err HWSError) { func (s *Server) LogError(err HWSError) {
if s.logger == nil { if s.logger == nil {
return return
} }
switch err.Level { switch err.Level {
case ErrorDEBUG: case ErrorDEBUG:
s.logger.logger.Debug().Msg(err.Message) s.logger.logger.Debug().Err(err.Error).Msg(err.Message)
return return
case ErrorINFO: case ErrorINFO:
s.logger.logger.Info().Msg(err.Message) s.logger.logger.Info().Err(err.Error).Msg(err.Message)
return return
case ErrorWARN: case ErrorWARN:
s.logger.logger.Warn().Err(err.Error).Msg(err.Message) s.logger.logger.Warn().Err(err.Error).Msg(err.Message)
return return
case ErrorERROR: case ErrorERROR:
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err)).Err(err.Error).Msg(err.Message) s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
return return
case ErrorFATAL: case ErrorFATAL:
s.logger.logger.Fatal().Str("stacktrace", fmt.Sprintf("%+v", err)).Err(err.Error).Msg(err.Message) s.logger.logger.Fatal().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
return return
case ErrorPANIC: case ErrorPANIC:
s.logger.logger.Panic().Str("stacktrace", fmt.Sprintf("%+v", err)).Err(err.Error).Msg(err.Message) s.logger.logger.Panic().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
return return
default: default:
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err)).Err(err.Error).Msg(err.Message) s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
} }
} }
func (server *Server) LogFatal(err error) { // AddLogger adds a logger to the server to use for request logging.
if err == nil { func (s *Server) AddLogger(hlogger *hlog.Logger) error {
err = errors.New("LogFatal was called with a nil error")
}
if server.logger == nil {
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
return
}
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
}
// Server.AddLogger adds a logger to the server to use for request logging.
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
if hlogger == nil { if hlogger == nil {
return errors.New("Unable to add logger, no logger provided") return errors.New("unable to add logger, no logger provided")
} }
server.logger = &logger{ s.logger = &logger{
logger: hlogger, logger: hlogger,
} }
return nil return nil
} }
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for. // LoggerIgnorePaths sets a list of URL paths to ignore logging for.
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL // Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
// Useful for ignoring requests to CSS files or favicons // Useful for ignoring requests to CSS files or favicons
func (server *Server) LoggerIgnorePaths(paths ...string) error { func (s *Server) LoggerIgnorePaths(paths ...string) error {
for _, path := range paths { for _, path := range paths {
u, err := url.Parse(path) u, err := url.Parse(path)
valid := err == nil && valid := err == nil &&
@@ -76,9 +66,22 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
u.RawQuery == "" && u.RawQuery == "" &&
u.Fragment == "" u.Fragment == ""
if !valid { if !valid {
return fmt.Errorf("Invalid path: '%s'", path) return fmt.Errorf("invalid path: '%s'", path)
} }
} }
server.logger.ignoredPaths = paths s.logger.ignoredPaths = prepareGlobs(paths)
return nil return nil
} }
func prepareGlobs(paths []string) []glob.Glob {
compiledGlobs := make([]glob.Glob, 0, len(paths))
for _, pattern := range paths {
g, err := glob.Compile(pattern)
if err != nil {
// If pattern fails to compile, skip it
continue
}
compiledGlobs = append(compiledGlobs, g)
}
return compiledGlobs
}

View File

@@ -25,6 +25,10 @@ func Test_AddLogger(t *testing.T) {
} }
func Test_LogError_AllLevels(t *testing.T) { func Test_LogError_AllLevels(t *testing.T) {
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
t.Run("DEBUG level", func(t *testing.T) { t.Run("DEBUG level", func(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
// Create server with logger explicitly set to Debug level // Create server with logger explicitly set to Debug level
@@ -34,7 +38,7 @@ func Test_LogError_AllLevels(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
logger, err := hlog.NewLogger(hlog.LogLevel("debug"), &buf, nil, "") logger, err := hlog.NewLogger(logcfg, &buf)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -197,7 +201,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("http://example.com/path") err := server.LoggerIgnorePaths("http://example.com/path")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Invalid path with host", func(t *testing.T) { t.Run("Invalid path with host", func(t *testing.T) {
@@ -207,7 +211,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("//example.com/path") err := server.LoggerIgnorePaths("//example.com/path")
assert.Error(t, err) assert.Error(t, err)
if err != nil { if err != nil {
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
} }
}) })
@@ -217,7 +221,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path?query=value") err := server.LoggerIgnorePaths("/path?query=value")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Invalid path with fragment", func(t *testing.T) { t.Run("Invalid path with fragment", func(t *testing.T) {
@@ -226,7 +230,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
err := server.LoggerIgnorePaths("/path#fragment") err := server.LoggerIgnorePaths("/path#fragment")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid path") assert.Contains(t, err.Error(), "invalid path")
}) })
t.Run("Valid paths", func(t *testing.T) { t.Run("Valid paths", func(t *testing.T) {

View File

@@ -5,35 +5,37 @@ import (
"net/http" "net/http"
) )
type Middleware func(h http.Handler) http.Handler type (
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError) Middleware func(h http.Handler) http.Handler
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
)
// Server.AddMiddleware registers all the middleware. // AddMiddleware registers all the middleware.
// Middleware will be run in the order that they are provided. // Middleware will be run in the order that they are provided.
// Can only be called once // Can only be called once
func (server *Server) AddMiddleware(middleware ...Middleware) error { func (s *Server) AddMiddleware(middleware ...Middleware) error {
if !server.routes { if !s.routes {
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware") return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
} }
if server.middleware { if s.middleware {
return errors.New("Server.AddMiddleware already called") return errors.New("Server.AddMiddleware already called")
} }
// RUN LOGGING MIDDLEWARE FIRST // RUN LOGGING MIDDLEWARE FIRST
server.server.Handler = logging(server.server.Handler, server.logger) s.server.Handler = logging(s.server.Handler, s.logger)
// LOOP PROVIDED MIDDLEWARE IN REVERSE order // LOOP PROVIDED MIDDLEWARE IN REVERSE order
for i := len(middleware); i > 0; i-- { for i := len(middleware); i > 0; i-- {
server.server.Handler = middleware[i-1](server.server.Handler) s.server.Handler = middleware[i-1](s.server.Handler)
} }
// RUN GZIP // RUN GZIP
if server.GZIP { if s.GZIP {
server.server.Handler = addgzip(server.server.Handler) s.server.Handler = addgzip(s.server.Handler)
} }
// RUN TIMER MIDDLEWARE LAST // RUN TIMER MIDDLEWARE LAST
server.server.Handler = startTimer(server.server.Handler) s.server.Handler = startTimer(s.server.Handler)
server.middleware = true s.middleware = true
return nil return nil
} }
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
// and returns a new request and optional HWSError. // and returns a new request and optional HWSError.
// If a HWSError is returned, server.ThrowError will be called. // If a HWSError is returned, server.ThrowError will be called.
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered // If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
func (server *Server) NewMiddleware( func (s *Server) NewMiddleware(
middlewareFunc MiddlewareFunc, middlewareFunc MiddlewareFunc,
) Middleware { ) Middleware {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newReq, herr := middlewareFunc(w, r) newReq, herr := middlewareFunc(w, r)
if herr != nil { if herr != nil {
server.ThrowError(w, r, *herr) s.ThrowError(w, r, *herr)
if herr.RenderErrorPage { if herr.RenderErrorPage {
return return
} }

View File

@@ -2,8 +2,9 @@ package hws
import ( import (
"net/http" "net/http"
"slices"
"time" "time"
"github.com/gobwas/glob"
) )
// Middleware to add logs to console with details of the request // Middleware to add logs to console with details of the request
@@ -13,7 +14,7 @@ func logging(next http.Handler, logger *logger) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
if slices.Contains(logger.ignoredPaths, r.URL.Path) { if globTest(r.URL.Path, logger.ignoredPaths) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
@@ -36,3 +37,12 @@ func logging(next http.Handler, logger *logger) http.Handler {
Msg("Served") Msg("Served")
}) })
} }
func globTest(testPath string, globs []glob.Glob) bool {
for _, g := range globs {
if g.Match(testPath) {
return true
}
}
return false
}

View File

@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
) )
} }
type contextKey string
func (c contextKey) String() string {
return "hws context key " + string(c)
}
var requestTimerCtxKey = contextKey("request-timer")
// Set the start time of the request // Set the start time of the request
func setStart(ctx context.Context, time time.Time) context.Context { func setStart(ctx context.Context, time time.Time) context.Context {
return context.WithValue(ctx, "hws context key request-timer", time) return context.WithValue(ctx, requestTimerCtxKey, time)
} }
// Get the start time of the request // Get the start time of the request
func getStartTime(ctx context.Context) (time.Time, error) { func getStartTime(ctx context.Context) (time.Time, error) {
start, ok := ctx.Value("hws context key request-timer").(time.Time) start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
if !ok { if !ok {
return time.Time{}, errors.New("Failed to get start time of request") return time.Time{}, errors.New("failed to get start time of request")
} }
return start, nil return start, nil
} }

View File

@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
} }
_, exists := s.notifier.clients.getClient(nt.Target) _, exists := s.notifier.clients.getClient(nt.Target)
if !exists { if !exists {
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target) err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return return
} }
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
clients, exists := s.notifier.clients.clientsIDMap[altID] clients, exists := s.notifier.clients.clientsIDMap[altID]
s.notifier.clients.lock.RUnlock() s.notifier.clients.lock.RUnlock()
if !exists { if !exists {
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID) err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err}) s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
return return
} }

View File

@@ -17,6 +17,7 @@ func newTestServerWithNotifier(t *testing.T) *Server {
cfg := &Config{ cfg := &Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 0, Port: 0,
ShutdownDelay: 0, // No delay for tests
} }
server, err := NewServer(cfg) server, err := NewServer(cfg)
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
done := make(chan bool) done := make(chan bool)
go func() { go func() {
for i := 0; i < 3; i++ { for range 3 {
<-ticker.C <-ticker.C
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
defer close(stop) defer close(stop)
// Send 10 notifications quickly (buffer is 10) // Send 10 notifications quickly (buffer is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Burst message", Message: "Burst message",
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
} }
// Client should receive all 10 // Client should receive all 10
for i := 0; i < 10; i++ { for i := range 10 {
select { select {
case <-notifications: case <-notifications:
// Received // Received
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
defer close(stop) defer close(stop)
// Fill buffer completely (buffer is 10) // Fill buffer completely (buffer is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Fill buffer", Message: "Fill buffer",
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
Message: "Timeout message", Message: "Timeout message",
}) })
// Wait for timeout // Wait for timeout (5s timeout + small buffer)
time.Sleep(6 * time.Second) time.Sleep(5100 * time.Millisecond)
// Check failure count (should be 1) // Check failure count (should be 1)
fails := atomic.LoadInt32(&client.consecutiveFails) fails := atomic.LoadInt32(&client.consecutiveFails)
require.Equal(t, int32(1), fails, "Should have 1 timeout") require.Equal(t, int32(1), fails, "Should have 1 timeout")
// Now read all buffered messages // Now read all buffered messages
for i := 0; i < 10; i++ { for range 10 {
<-notifications <-notifications
} }
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
defer close(stop) defer close(stop)
// Fill buffer and never read to cause 5 consecutive timeouts // Fill buffer and never read to cause 5 consecutive timeouts
for i := 0; i < 20; i++ { for range 20 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Timeout message", Message: "Timeout message",
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
clients := make([]*Client, 100) clients := make([]*Client, 100)
for i := 0; i < 100; i++ { for i := range 100 {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
messageCount := 50 messageCount := 50
// Send from multiple goroutines // Send from multiple goroutines
for i := 0; i < messageCount; i++ { for i := range messageCount {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
// This is expected behavior - we're testing thread safety, not guaranteed delivery // This is expected behavior - we're testing thread safety, not guaranteed delivery
// Just verify we receive at least some messages without panicking or deadlocking // Just verify we receive at least some messages without panicking or deadlocking
received := 0 received := 0
timeout := time.After(2 * time.Second) timeout := time.After(500 * time.Millisecond)
for received < messageCount { for received < messageCount {
select { select {
case <-notifications: case <-notifications:
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
server := newTestServerWithNotifier(t) server := newTestServerWithNotifier(t)
// Create some clients // Create some clients
for i := 0; i < 10; i++ { for i := range 10 {
client, _ := server.GetClient("", "") client, _ := server.GetClient("", "")
// Set some to be old // Set some to be old
if i%2 == 0 { if i%2 == 0 {
@@ -790,16 +791,14 @@ func Test_NoRaceConditions(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
// Create a few clients and read from them // Create a few clients and read from them
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
client, _ := server.GetClient("", "") client, _ := server.GetClient("", "")
notifications, stop := client.Listen() notifications, stop := client.Listen()
defer close(stop) defer close(stop)
// Actively read messages // Actively read messages
timeout := time.After(2 * time.Second) timeout := time.After(200 * time.Millisecond)
for { for {
select { select {
case <-notifications: case <-notifications:
@@ -808,21 +807,18 @@ func Test_NoRaceConditions(t *testing.T) {
return return
} }
} }
}() })
} }
// Send a few notifications // Send a few notifications
wg.Add(1) wg.Go(func() {
go func() { for range 10 {
defer wg.Done()
for j := 0; j < 20; j++ {
server.NotifyAll(notify.Notification{ server.NotifyAll(notify.Notification{
Message: "Stress test", Message: "Stress test",
}) })
time.Sleep(50 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
}() })
wg.Wait() wg.Wait()
} }
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
require.NotNil(t, stop) require.NotNil(t, stop)
// notifications should be receive-only // notifications should be receive-only
_, ok := interface{}(notifications).(<-chan notify.Notification) _, ok := any(notifications).(<-chan notify.Notification)
require.True(t, ok, "notifications should be receive-only channel") require.True(t, ok, "notifications should be receive-only channel")
// stop should be closeable // stop should be closeable
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
defer close(stop) defer close(stop)
// Send 10 messages without reading (buffer size is 10) // Send 10 messages without reading (buffer size is 10)
for i := 0; i < 10; i++ { for range 10 {
server.NotifySub(notify.Notification{ server.NotifySub(notify.Notification{
Target: client.sub.ID, Target: client.sub.ID,
Message: "Buffered", Message: "Buffered",
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Read all 10 // Read all 10
for i := 0; i < 10; i++ { for i := range 10 {
select { select {
case <-notifications: case <-notifications:
// Success // Success

View File

@@ -30,13 +30,13 @@ const (
MethodPATCH Method = "PATCH" MethodPATCH Method = "PATCH"
) )
// Server.AddRoutes registers the page handlers for the server. // AddRoutes registers the page handlers for the server.
// At least one route must be provided. // At least one route must be provided.
// If any route patterns (path + method) are defined multiple times, the first // If any route patterns (path + method) are defined multiple times, the first
// instance will be added and any additional conflicts will be discarded. // instance will be added and any additional conflicts will be discarded.
func (server *Server) AddRoutes(routes ...Route) error { func (s *Server) AddRoutes(routes ...Route) error {
if len(routes) == 0 { if len(routes) == 0 {
return errors.New("No routes provided") return errors.New("no routes provided")
} }
patterns := []string{} patterns := []string{}
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
} }
for _, method := range route.Methods { for _, method := range route.Methods {
if !validMethod(method) { if !validMethod(method) {
return fmt.Errorf("Invalid method %s for path %s", method, route.Path) return fmt.Errorf("invalid method %s for path %s", method, route.Path)
} }
if route.Handler == nil { if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", method, route.Path) return fmt.Errorf("no handler provided for %s %s", method, route.Path)
} }
pattern := fmt.Sprintf("%s %s", method, route.Path) pattern := fmt.Sprintf("%s %s", method, route.Path)
if slices.Contains(patterns, pattern) { if slices.Contains(patterns, pattern) {
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
} }
} }
server.server.Handler = mux s.server.Handler = mux
server.routes = true s.routes = true
return nil return nil
} }

View File

@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
server := createTestServer(t, &buf) server := createTestServer(t, &buf)
err := server.AddRoutes() err := server.AddRoutes()
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "No routes provided") assert.Contains(t, err.Error(), "no routes provided")
}) })
t.Run("Single valid route", func(t *testing.T) { t.Run("Single valid route", func(t *testing.T) {
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: handler, Handler: handler,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method") assert.Contains(t, err.Error(), "invalid method")
}) })
t.Run("No handler provided", func(t *testing.T) { t.Run("No handler provided", func(t *testing.T) {
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
Handler: nil, Handler: nil,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "No handler provided") assert.Contains(t, err.Error(), "no handler provided")
}) })
t.Run("All HTTP methods are valid", func(t *testing.T) { t.Run("All HTTP methods are valid", func(t *testing.T) {
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
Handler: handler, Handler: handler,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid method") assert.Contains(t, err.Error(), "invalid method")
}) })
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) { t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {

View File

@@ -26,14 +26,14 @@ type Server struct {
} }
// Ready returns a channel that is closed when the server is started // Ready returns a channel that is closed when the server is started
func (server *Server) Ready() <-chan struct{} { func (s *Server) Ready() <-chan struct{} {
return server.ready return s.ready
} }
// IsReady checks if the server is running // IsReady checks if the server is running
func (server *Server) IsReady() bool { func (s *Server) IsReady() bool {
select { select {
case <-server.ready: case <-s.ready:
return true return true
default: default:
return false return false
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
} }
// Addr returns the server's network address // Addr returns the server's network address
func (server *Server) Addr() string { func (s *Server) Addr() string {
return server.server.Addr return s.server.Addr
} }
// Handler returns the server's HTTP handler for testing purposes // Handler returns the server's HTTP handler for testing purposes
func (server *Server) Handler() http.Handler { func (s *Server) Handler() http.Handler {
return server.server.Handler return s.server.Handler
} }
// NewServer returns a new hws.Server with the specified configuration. // NewServer returns a new hws.Server with the specified configuration.
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
valid := isValidHostname(config.Host) valid := isValidHostname(config.Host)
if !valid { if !valid {
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host) return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
} }
httpServer := &http.Server{ httpServer := &http.Server{
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
return server, nil return server, nil
} }
func (server *Server) Start(ctx context.Context) error { func (s *Server) Start(ctx context.Context) error {
if ctx == nil { if ctx == nil {
return errors.New("Context cannot be nil") return errors.New("Context cannot be nil")
} }
if !server.routes { if !s.routes {
return errors.New("Server.AddRoutes must be run before starting the server") return errors.New("Server.AddRoutes must be run before starting the server")
} }
if !server.middleware { if !s.middleware {
err := server.AddMiddleware() err := s.AddMiddleware()
if err != nil { if err != nil {
return errors.Wrap(err, "server.AddMiddleware") return errors.Wrap(err, "server.AddMiddleware")
} }
} }
server.startNotifier() s.startNotifier()
go func() { go func() {
if server.logger == nil { if s.logger == nil {
fmt.Printf("Listening for requests on %s", server.server.Addr) fmt.Printf("Listening for requests on %s", s.server.Addr)
} else { } else {
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests") s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
} }
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if server.logger == nil { if s.logger == nil {
fmt.Printf("Server encountered a fatal error: %s", err.Error()) fmt.Printf("Server encountered a fatal error: %s", err.Error())
} else { } else {
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"}) s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
} }
} }
}() }()
server.waitUntilReady(ctx) s.waitUntilReady(ctx)
return nil return nil
} }
func (server *Server) Shutdown(ctx context.Context) error { func (s *Server) Shutdown(ctx context.Context) error {
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down") if s.logger != nil {
server.NotifyAll(notify.Notification{ s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
}
s.NotifyAll(notify.Notification{
Title: "Shutting down", Title: "Shutting down",
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay), Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
Level: LevelShutdown, Level: LevelShutdown,
}) })
<-time.NewTimer(server.shutdowndelay).C <-time.NewTimer(s.shutdowndelay).C
if !server.IsReady() { if !s.IsReady() {
return errors.New("Server isn't running") return errors.New("Server isn't running")
} }
if ctx == nil { if ctx == nil {
return errors.New("Context cannot be nil") return errors.New("Context cannot be nil")
} }
err := server.server.Shutdown(ctx) err := s.server.Shutdown(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to shutdown the server gracefully") return errors.Wrap(err, "Failed to shutdown the server gracefully")
} }
server.closeNotifier() s.closeNotifier()
server.ready = make(chan struct{}) s.ready = make(chan struct{})
return nil return nil
} }
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
return false return false
} }
func (server *Server) waitUntilReady(ctx context.Context) error { func (s *Server) waitUntilReady(ctx context.Context) error {
ticker := time.NewTicker(50 * time.Millisecond) ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
resp, err := http.Get("http://" + server.server.Addr + "/healthz") resp, err := http.Get("http://" + s.server.Addr + "/healthz")
if err != nil { if err != nil {
continue // not accepting yet continue // not accepting yet
} }
resp.Body.Close() resp.Body.Close()
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
closeOnce.Do(func() { close(server.ready) }) closeOnce.Do(func() { close(s.ready) })
return nil return nil
} }
} }

View File

@@ -28,10 +28,15 @@ func createTestServer(t *testing.T, w io.Writer) *hws.Server {
server, err := hws.NewServer(&hws.Config{ server, err := hws.NewServer(&hws.Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: randomPort(), Port: randomPort(),
ShutdownDelay: 0, // No delay for tests
}) })
require.NoError(t, err) require.NoError(t, err)
dbg, _ := hlog.LogLevel("debug")
logcfg := &hlog.Config{
LogLevel: dbg,
}
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), w, nil, "") logger, err := hlog.NewLogger(logcfg, w)
require.NoError(t, err) require.NoError(t, err)
err = server.AddLogger(logger) err = server.AddLogger(logger)
@@ -227,5 +232,4 @@ func Test_NewServer(t *testing.T) {
} }
}) })
} }
} }

View File

@@ -9,6 +9,7 @@ import (
"git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/jwt" "git.haelnorr.com/h/golib/jwt"
"github.com/gobwas/glob"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -16,7 +17,7 @@ type Authenticator[T Model, TX DBTransaction] struct {
tokenGenerator *jwt.TokenGenerator tokenGenerator *jwt.TokenGenerator
load LoadFunc[T, TX] load LoadFunc[T, TX]
beginTx BeginTX beginTx BeginTX
ignoredPaths []string ignoredPaths []glob.Glob
logger *hlog.Logger logger *hlog.Logger
server *hws.Server server *hws.Server
errorPage hws.ErrorPageFunc errorPage hws.ErrorPageFunc

View File

@@ -9,16 +9,16 @@ import (
// Config holds the configuration settings for the authenticator. // Config holds the configuration settings for the authenticator.
// All time-based settings are in minutes. // All time-based settings are in minutes.
type Config struct { type Config struct {
SSL bool // ENV HWSAUTH_SSL: Enable SSL secure cookies (default: false) SSL bool `ezconf:"HWSAUTH_SSL,description:Enable SSL secure cookies,default:false"`
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address for SSL (required if SSL is true) TrustedHost string `ezconf:"HWSAUTH_TRUSTED_HOST,description:Full server address for SSL,required:if SSL is true"`
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing JWT tokens (required) SecretKey string `ezconf:"HWSAUTH_SECRET_KEY,description:Secret key for signing JWT tokens,required"`
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5) AccessTokenExpiry int64 `ezconf:"HWSAUTH_ACCESS_TOKEN_EXPIRY,description:Access token expiry in minutes,default:5"`
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440) RefreshTokenExpiry int64 `ezconf:"HWSAUTH_REFRESH_TOKEN_EXPIRY,description:Refresh token expiry in minutes,default:1440"`
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Token fresh time in minutes (default: 5) TokenFreshTime int64 `ezconf:"HWSAUTH_TOKEN_FRESH_TIME,description:Token fresh time in minutes,default:5"`
LandingPage string // ENV HWSAUTH_LANDING_PAGE: Redirect destination for authenticated users (default: "/profile") LandingPage string `ezconf:"HWSAUTH_LANDING_PAGE,description:Redirect destination for authenticated users,default:/profile"`
DatabaseType string // ENV HWSAUTH_DATABASE_TYPE: Database type (postgres, mysql, sqlite, mariadb) (default: "postgres") DatabaseType string `ezconf:"HWSAUTH_DATABASE_TYPE,description:Database type (postgres mysql sqlite mariadb),default:postgres"`
DatabaseVersion string // ENV HWSAUTH_DATABASE_VERSION: Database version string (default: "15") DatabaseVersion string `ezconf:"HWSAUTH_DATABASE_VERSION,description:Database version string,default:15"`
JWTTableName string // ENV HWSAUTH_JWT_TABLE_NAME: Custom JWT blacklist table name (default: "jwtblacklist") JWTTableName string `ezconf:"HWSAUTH_JWT_TABLE_NAME,description:Custom JWT blacklist table name,default:jwtblacklist"`
} }
// ConfigFromEnv loads configuration from environment variables. // ConfigFromEnv loads configuration from environment variables.

View File

@@ -1,35 +1,9 @@
package hwsauth package hwsauth
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct{}
// PackagePath returns the path to the hwsauth package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return ConfigFromEnv()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "hwsauth"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "HWSAuth"
}
// NewEZConfIntegration creates a new EZConf integration helper // NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration { func NewEZConfIntegration() *ezconf.Integration {
return EZConfIntegration{} return ezconf.NewIntegration("hwsauth", "HWSAuth", &Config{},
func() (any, error) { return ConfigFromEnv() })
} }

View File

@@ -5,24 +5,28 @@ go 1.25.5
require ( require (
git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/ezconf v0.2.1
git.haelnorr.com/h/golib/hws v0.3.0 git.haelnorr.com/h/golib/hlog v0.11.0
git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/jwt v0.10.1 git.haelnorr.com/h/golib/jwt v0.10.1
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
) )
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/gobwas/glob v0.2.3
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.34.0 // indirect github.com/rs/zerolog v1.34.0 // indirect
golang.org/x/sys v0.40.0 // indirect golang.org/x/sys v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apimachinery v0.35.0 // indirect k8s.io/apimachinery v0.35.0 // indirect
k8s.io/klog/v2 v2.130.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect

View File

@@ -2,12 +2,16 @@ git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDO
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo= git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@@ -15,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
@@ -40,8 +46,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
return tm.ID return tm.ID
} }
type TestTransaction struct { type TestTransaction struct{}
}
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) { func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
return nil, nil return nil, nil
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
func TestConfigFromEnv_MissingSecretKey(t *testing.T) { func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
// Clear environment variables // Clear environment variables
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY") originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
os.Setenv("HWSAUTH_SECRET_KEY", "") _ = os.Setenv("HWSAUTH_SECRET_KEY", "")
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret) defer func() {
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
}()
_, err := ConfigFromEnv() _, err := ConfigFromEnv()
assert.Error(t, err) assert.Error(t, err)
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
db, _, err := createMockDB() db, _, err := createMockDB()
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer func() {
_ = db.Close()
}()
auth, err := NewAuthenticator( auth, err := NewAuthenticator(
cfg, cfg,
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
// This test mainly checks that the function doesn't panic and has right call signature // This test mainly checks that the function doesn't panic and has right call signature
// The actual JWT functionality is tested in jwt package itself // The actual JWT functionality is tested in jwt package itself
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
auth.Login(w, r, user, rememberMe) err := auth.Login(w, r, user, rememberMe)
require.NoError(t, err)
}) })
} }

View File

@@ -3,6 +3,8 @@ package hwsauth
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"github.com/gobwas/glob"
) )
// IgnorePaths excludes specified paths from authentication middleware. // IgnorePaths excludes specified paths from authentication middleware.
@@ -22,9 +24,22 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
u.RawQuery == "" && u.RawQuery == "" &&
u.Fragment == "" u.Fragment == ""
if !valid { if !valid {
return fmt.Errorf("Invalid path: '%s'", path) return fmt.Errorf("invalid path: '%s'", path)
} }
} }
auth.ignoredPaths = paths auth.ignoredPaths = prepareGlobs(paths)
return nil return nil
} }
func prepareGlobs(paths []string) []glob.Glob {
compiledGlobs := make([]glob.Glob, 0, len(paths))
for _, pattern := range paths {
g, err := glob.Compile(pattern)
if err != nil {
// If pattern fails to compile, skip it
continue
}
compiledGlobs = append(compiledGlobs, g)
}
return compiledGlobs
}

View File

@@ -33,14 +33,18 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return errors.Wrap(err, "auth.getTokens") return errors.Wrap(err, "auth.getTokens")
} }
if aT != nil {
err = aT.Revoke(jwt.DBTransaction(tx)) err = aT.Revoke(jwt.DBTransaction(tx))
if err != nil { if err != nil {
return errors.Wrap(err, "aT.Revoke") return errors.Wrap(err, "aT.Revoke")
} }
}
if rT != nil {
err = rT.Revoke(jwt.DBTransaction(tx)) err = rT.Revoke(jwt.DBTransaction(tx))
if err != nil { if err != nil {
return errors.Wrap(err, "rT.Revoke") return errors.Wrap(err, "rT.Revoke")
} }
}
cookies.DeleteCookie(w, "access", "/") cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/") cookies.DeleteCookie(w, "refresh", "/")
return nil return nil

View File

@@ -2,10 +2,12 @@ package hwsauth
import ( import (
"context" "context"
"git.haelnorr.com/h/golib/hws"
"net/http" "net/http"
"slices"
"time" "time"
"git.haelnorr.com/h/golib/hws"
"github.com/gobwas/glob"
"github.com/pkg/errors"
) )
// Authenticate returns the main authentication middleware. // Authenticate returns the main authentication middleware.
@@ -14,14 +16,22 @@ import (
// //
// Example: // Example:
// //
// server.AddMiddleware(auth.Authenticate()) // server.AddMiddleware(auth.Authenticate(nil))
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware { //
return auth.server.NewMiddleware(auth.authenticate()) // If extraCheck is provided, it will run just before the user is added to the context,
// and the return will determine if the user will be added, or the request passed on
// without the user.
func (auth *Authenticator[T, TX]) Authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
} }
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc { func (auth *Authenticator[T, TX]) authenticate(
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
) hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) { return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if slices.Contains(auth.ignoredPaths, r.URL.Path) { if globTest(r.URL.Path, auth.ignoredPaths) {
return r, nil return r, nil
} }
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
@@ -30,25 +40,70 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
// Start the transaction // Start the transaction
tx, err := auth.beginTx(ctx) tx, err := auth.beginTx(ctx)
if err != nil { if err != nil {
return nil, &hws.HWSError{Message: "Unable to start transaction", StatusCode: http.StatusServiceUnavailable, Error: err} return nil, &hws.HWSError{
Message: "Unable to start transaction",
StatusCode: http.StatusServiceUnavailable,
Error: errors.Wrap(err, "auth.beginTx"),
} }
}
defer func() {
_ = tx.Rollback()
}()
// Type assert to TX - safe because user's beginTx should return their TX type // Type assert to TX - safe because user's beginTx should return their TX type
txTyped, ok := tx.(TX) txTyped, ok := tx.(TX)
if !ok { if !ok {
return nil, &hws.HWSError{Message: "Transaction type mismatch", StatusCode: http.StatusInternalServerError, Error: err} return nil, &hws.HWSError{
Message: "Transaction type mismatch",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "TX type not ok"),
}
} }
model, err := auth.getAuthenticatedUser(txTyped, w, r) model, err := auth.getAuthenticatedUser(txTyped, w, r)
if err != nil { if err != nil {
tx.Rollback() rberr := tx.Rollback()
if rberr != nil {
return nil, &hws.HWSError{
Message: "Failed rolling back after error",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Rollback"),
}
}
auth.logger.Debug(). auth.logger.Debug().
Str("remote_addr", r.RemoteAddr). Str("remote_addr", r.RemoteAddr).
Err(err). Err(err).
Msg("Failed to authenticate user") Msg("Failed to authenticate user")
return r, nil return r, nil
} }
tx.Commit() var check bool
if extraCheck != nil {
var err *hws.HWSError
check, err = extraCheck(ctx, model.model, txTyped, w, r)
if err != nil {
return nil, err
}
}
err = tx.Commit()
if err != nil {
return nil, &hws.HWSError{
Message: "Failed to commit transaction",
StatusCode: http.StatusInternalServerError,
Error: errors.Wrap(err, "tx.Commit"),
}
}
authContext := setAuthenticatedModel(r.Context(), model) authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext) newReq := r.WithContext(authContext)
if extraCheck == nil || check {
return newReq, nil return newReq, nil
} }
return r, nil
}
}
func globTest(testPath string, globs []glob.Glob) bool {
for _, g := range globs {
if g.Match(testPath) {
return true
}
}
return false
} }

View File

@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
// } // }
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error) type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
type contextKey string
func (c contextKey) String() string {
return "hwsauth context key" + string(c)
}
var authenticatedModelContextKey = contextKey("authenticated-model")
// Return a new context with the user added in // Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context { func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m) return context.WithValue(ctx, authenticatedModelContextKey, m)
} }
// Retrieve a user from the given context. Returns nil if not set // Retrieve a user from the given context. Returns nil if not set
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
model = authenticatedModel[T]{} model = authenticatedModel[T]{}
} }
}() }()
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T]) model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
if !cok { if !cok {
return authenticatedModel[T]{}, false return authenticatedModel[T]{}, false
} }

View File

@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := getAuthorizedModel[T](r.Context()) _, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model, ok := getAuthorizedModel[T](r.Context()) model, ok := getAuthorizedModel[T](r.Context())
if !ok { if !ok {
err := auth.server.ThrowError(w, r, hws.HWSError{ auth.server.ThrowError(w, r, hws.HWSError{
Error: errors.New("Login required"), Error: errors.New("Login required"),
Message: "Please login to view this page", Message: "Please login to view this page",
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
RenderErrorPage: true, RenderErrorPage: true,
}) })
if err != nil {
auth.server.ThrowFatal(w, err)
}
return return
} }
isFresh := time.Now().Before(time.Unix(model.fresh, 0)) isFresh := time.Now().Before(time.Unix(model.fresh, 0))

View File

@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
rememberMe := map[string]bool{ rememberMe := map[string]bool{
"session": false, "session": false,
"exp": true, "exp": true,
}[aT.TTL] }[rT.TTL]
// issue new tokens for the user // issue new tokens for the user
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL) err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
if err != nil { if err != nil {
@@ -55,14 +55,21 @@ func (auth *Authenticator[T, TX]) getTokens(
) (*jwt.AccessToken, *jwt.RefreshToken, error) { ) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies // get the existing tokens from the cookies
atStr, rtStr := jwt.GetTokenCookies(r) atStr, rtStr := jwt.GetTokenCookies(r)
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr) var aT *jwt.AccessToken
var rT *jwt.RefreshToken
var err error
if atStr != "" {
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess") return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
} }
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr) }
if rtStr != "" {
rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh") return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
} }
}
return aT, rT, nil return aT, rT, nil
} }
@@ -72,13 +79,17 @@ func revokeTokenPair(
aT *jwt.AccessToken, aT *jwt.AccessToken,
rT *jwt.RefreshToken, rT *jwt.RefreshToken,
) error { ) error {
if aT != nil {
err := aT.Revoke(tx) err := aT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "aT.Revoke") return errors.Wrap(err, "aT.Revoke")
} }
err = rT.Revoke(tx) }
if rT != nil {
err := rT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "rT.Revoke") return errors.Wrap(err, "rT.Revoke")
} }
}
return nil return nil
} }

View File

@@ -7,7 +7,7 @@ import (
type API struct { type API struct {
*Config *Config
token string // ENV TMDB_TOKEN: API token for TMDB (required) token string `ezconf:"TMDB_TOKEN,description:API token for TMDB,required"`
} }
func NewAPIConnection() (*API, error) { func NewAPIConnection() (*API, error) {

View File

@@ -1,36 +1,9 @@
package tmdb package tmdb
import "runtime" import "git.haelnorr.com/h/golib/ezconf"
// EZConfIntegration provides integration with ezconf for automatic configuration // NewEZConfIntegration creates a new EZConf integration
type EZConfIntegration struct{} func NewEZConfIntegration() *ezconf.Integration {
return ezconf.NewIntegration("tmdb", "TMDB", &Config{},
// PackagePath returns the path to the tmdb package for source parsing func() (any, error) { return NewAPIConnection() })
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the NewAPIConnection function for ezconf
// Note: tmdb uses NewAPIConnection instead of ConfigFromEnv
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
return func() (interface{}, error) {
return NewAPIConnection()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return "tmdb"
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return "TMDB"
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{}
} }

View File

@@ -6,3 +6,5 @@ require (
git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/env v0.9.1
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
) )
require git.haelnorr.com/h/golib/ezconf v0.2.1

View File

@@ -1,4 +1,6 @@
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY= git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg= git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=