Compare commits
14 Commits
hws/v0.4.1
...
hwsauth/v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 380e366891 | |||
| e8ffec6b7e | |||
| 1745458a95 | |||
| f3d6a01105 | |||
| 9179736c90 | |||
| 05be28d7f3 | |||
| 8f7c87cef2 | |||
| 525b3b1396 | |||
| 563908bbb4 | |||
| 95a17597cf | |||
| cd29f11296 | |||
| 7ed40c7afe | |||
| 596a4c0529 | |||
| ed3bc4afb0 |
173
AGENTS.md
Normal file
173
AGENTS.md
Normal file
@@ -0,0 +1,173 @@
|
||||
# AGENTS.md - Coding Agent Guidelines for golib
|
||||
|
||||
## Project Overview
|
||||
This is a Go library repository containing multiple independent packages:
|
||||
- **cookies**: HTTP cookie utilities
|
||||
- **env**: Environment variable helpers
|
||||
- **ezconf**: Configuration loader with ENV parsing
|
||||
- **hlog**: Logging with zerolog
|
||||
- **hws**: HTTP web server
|
||||
- **hwsauth**: Authentication middleware for hws
|
||||
- **jwt**: JWT token generation and validation
|
||||
- **tmdb**: The Movie Database API client
|
||||
|
||||
Each package has its own `go.mod` and can be used independently.
|
||||
|
||||
## Interactive Questions
|
||||
All questions in plan mode should use the opencode interactive question prompter for user interaction.
|
||||
|
||||
## Build, Test, and Lint Commands
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
# Test all packages from repo root
|
||||
go test ./...
|
||||
|
||||
# Test a specific package
|
||||
cd <package> && go test
|
||||
|
||||
# Run a single test function
|
||||
cd <package> && go test -run TestFunctionName
|
||||
|
||||
# Run tests with verbose output
|
||||
cd <package> && go test -v
|
||||
|
||||
# Run tests matching a pattern
|
||||
cd <package> && go test -run "TestName.*"
|
||||
```
|
||||
|
||||
### Building
|
||||
```bash
|
||||
# Each package is a library - no build needed
|
||||
# Verify code compiles:
|
||||
go build ./...
|
||||
|
||||
# Or for specific package:
|
||||
cd <package> && go build
|
||||
```
|
||||
|
||||
### Linting
|
||||
```bash
|
||||
# Use standard go tools
|
||||
go vet ./...
|
||||
go fmt ./...
|
||||
|
||||
# Check formatting without changing files
|
||||
gofmt -l .
|
||||
```
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
### Package Structure
|
||||
- Each package must have its own `go.mod` with module path: `git.haelnorr.com/h/golib/<package>`
|
||||
- Go version should be current (1.23.4+)
|
||||
- Each package should have a `doc.go` file with package documentation
|
||||
|
||||
### Imports
|
||||
- Use standard library imports first
|
||||
- Then third-party imports
|
||||
- Then local imports from this repo (e.g., `git.haelnorr.com/h/golib/hlog`)
|
||||
- Group imports with blank lines between groups
|
||||
- Example:
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
)
|
||||
```
|
||||
|
||||
### Formatting
|
||||
- Use `gofmt` standard formatting
|
||||
- No tabs for alignment, use spaces inside structs
|
||||
- Line length: no hard limit, but prefer readability
|
||||
|
||||
### Types
|
||||
- Use explicit types for struct fields
|
||||
- Config structs must have ENV comments (see below)
|
||||
- Prefer named return values for complex functions
|
||||
- Use generics where appropriate (see `hwsauth.Authenticator[T Model, TX DBTransaction]`)
|
||||
|
||||
### Naming Conventions
|
||||
- Packages: lowercase, single word (e.g., `cookies`, `ezconf`)
|
||||
- Exported functions: PascalCase (e.g., `NewServer`, `ConfigFromEnv`)
|
||||
- Unexported functions: camelCase (e.g., `isValidHostname`, `waitUntilReady`)
|
||||
- Test functions: `Test<FunctionName>` or `Test<FunctionName>_<Case>` (underscore for sub-cases)
|
||||
- Variables: camelCase, descriptive names
|
||||
- Constants: PascalCase or UPPER_CASE depending on scope
|
||||
|
||||
### Error Handling
|
||||
- Use `github.com/pkg/errors` for error wrapping
|
||||
- Wrap errors with context: `errors.Wrap(err, "context message")`
|
||||
- Return errors, don't panic (except in truly exceptional cases)
|
||||
- Validate inputs and return descriptive errors
|
||||
- Example:
|
||||
```go
|
||||
if config == nil {
|
||||
return nil, errors.New("Config cannot be nil")
|
||||
}
|
||||
```
|
||||
|
||||
### Configuration Pattern
|
||||
- Each package with config should have a `Config` struct
|
||||
- Provide `ConfigFromEnv() (*Config, error)` function
|
||||
- ENV comment format for Config struct fields:
|
||||
```go
|
||||
type Config struct {
|
||||
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
|
||||
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
|
||||
SSL bool // ENV HWS_SSL: Enable SSL (required when using production)
|
||||
}
|
||||
```
|
||||
- Format: `// ENV ENV_NAME: Description (required <condition>) (default: <value>)`
|
||||
- Include "required" only if no default
|
||||
- Include "default" only if one exists
|
||||
|
||||
### Testing
|
||||
- Use `testing` package from standard library
|
||||
- Use `github.com/stretchr/testify` for assertions (`require`, `assert`)
|
||||
- Table-driven tests for multiple cases:
|
||||
```go
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid case", "input", false},
|
||||
{"error case", "", true},
|
||||
}
|
||||
```
|
||||
- Test files use `<package>_test` for black-box tests or `<package>` for white-box
|
||||
- Helper functions should use `t.Helper()`
|
||||
|
||||
### Documentation
|
||||
- All exported functions, types, and constants must have godoc comments
|
||||
- Comments should start with the name being documented
|
||||
- Example: `// NewServer returns a new hws.Server with the specified configuration.`
|
||||
- Keep doc.go files up to date with package overview
|
||||
- Follow RULES.md for README and wiki documentation
|
||||
|
||||
## Version Control (from RULES.md)
|
||||
- Do NOT make changes to master branch
|
||||
- Checkout a branch for new features
|
||||
- Version numbers use git tags - do NOT change manually
|
||||
- When updating docs, append branch name to version
|
||||
- Changes to golib-wiki repo should use same branch name
|
||||
|
||||
## Testing Requirements (from RULES.md)
|
||||
- All features MUST have tests
|
||||
- Update existing tests when modifying features
|
||||
- New features require new tests
|
||||
|
||||
## Documentation Requirements (from RULES.md)
|
||||
- Document via: docstrings, README.md, doc.go, wiki
|
||||
- README structure: Title+version, Features (NO EMOTICONS), Installation, Quick Start, Docs links, Additional info, License, Contributing, Related projects
|
||||
- Wiki location: `~/projects/golib-wiki`
|
||||
- Docstrings must conform to godoc standards
|
||||
|
||||
## License
|
||||
- All modules use MIT License
|
||||
3
RULES.md
3
RULES.md
@@ -45,3 +45,6 @@ Do not make any changes to master. Checkout a branch to work on new features
|
||||
Version numbers are specified using git tags.
|
||||
Do not change version numbers. When updating documentation, append the branch name to the version number.
|
||||
Changes made to the golib-wiki repo should be made under the same branch name as the changes made in this repo
|
||||
|
||||
4. Licencing
|
||||
All modules should have an MIT License
|
||||
|
||||
21
cookies/LICENSE
Normal file
21
cookies/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 haelnorr
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
61
cookies/README.md
Normal file
61
cookies/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# cookies v1.0.0
|
||||
|
||||
HTTP cookie utilities for Go web applications with security best practices.
|
||||
|
||||
## Features
|
||||
|
||||
- Secure cookie setting with HttpOnly flag
|
||||
- Cookie deletion with proper expiration
|
||||
- Pagefrom tracking for post-login redirects
|
||||
- Host validation for referer-based redirects
|
||||
- Full test coverage
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get git.haelnorr.com/h/golib/cookies
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
)
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// Set a secure cookie
|
||||
cookies.SetCookie(w, "session", "/", "abc123", 3600)
|
||||
|
||||
// Delete a cookie
|
||||
cookies.DeleteCookie(w, "old_session", "/")
|
||||
|
||||
// Handle pagefrom for redirects
|
||||
if r.URL.Path == "/login" {
|
||||
cookies.SetPageFrom(w, r, "example.com")
|
||||
}
|
||||
|
||||
// Check pagefrom after login
|
||||
redirectTo := cookies.CheckPageFrom(w, r)
|
||||
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
}
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See the [wiki documentation](../golib/wiki/cookies.md) for detailed usage information and examples.
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
## Contributing
|
||||
|
||||
Please see the main golib repository for contributing guidelines.
|
||||
|
||||
## Related Projects
|
||||
|
||||
This package is part of the golib collection of utilities for Go applications and integrates well with other golib packages.
|
||||
405
cookies/cookies_test.go
Normal file
405
cookies/cookies_test.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cookie string
|
||||
path string
|
||||
value string
|
||||
maxAge int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic cookie",
|
||||
cookie: "test",
|
||||
path: "/",
|
||||
value: "value",
|
||||
maxAge: 3600,
|
||||
expected: "test=value; Path=/; Max-Age=3600; HttpOnly",
|
||||
},
|
||||
{
|
||||
name: "zero max age",
|
||||
cookie: "session",
|
||||
path: "/api",
|
||||
value: "abc123",
|
||||
maxAge: 0,
|
||||
expected: "session=abc123; Path=/api; HttpOnly",
|
||||
},
|
||||
{
|
||||
name: "negative max age",
|
||||
cookie: "temp",
|
||||
path: "/",
|
||||
value: "temp",
|
||||
maxAge: -1,
|
||||
expected: "temp=temp; Path=/; Max-Age=0; HttpOnly",
|
||||
},
|
||||
{
|
||||
name: "empty value",
|
||||
cookie: "empty",
|
||||
path: "/",
|
||||
value: "",
|
||||
maxAge: 3600,
|
||||
expected: "empty=; Path=/; Max-Age=3600; HttpOnly",
|
||||
},
|
||||
{
|
||||
name: "special characters in value",
|
||||
cookie: "data",
|
||||
path: "/",
|
||||
value: "test@123!#$%",
|
||||
maxAge: 7200,
|
||||
expected: "data=test@123!#$%; Path=/; Max-Age=7200; HttpOnly",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
SetCookie(w, tt.cookie, tt.path, tt.value, tt.maxAge)
|
||||
|
||||
headers := w.Header()["Set-Cookie"]
|
||||
if len(headers) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the cookie header to check individual components
|
||||
cookieHeader := headers[0]
|
||||
|
||||
// Check that all expected components are present
|
||||
if !strings.Contains(cookieHeader, tt.cookie+"="+tt.value) {
|
||||
t.Errorf("Expected cookie name/value not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "Path="+tt.path) {
|
||||
t.Errorf("Expected path not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "HttpOnly") {
|
||||
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
|
||||
}
|
||||
if tt.maxAge != 0 {
|
||||
expectedMaxAge := fmt.Sprintf("Max-Age=%d", tt.maxAge)
|
||||
if tt.maxAge < 0 {
|
||||
expectedMaxAge = "Max-Age=0" // Go normalizes negative Max-Age to 0
|
||||
}
|
||||
if !strings.Contains(cookieHeader, expectedMaxAge) {
|
||||
t.Errorf("Expected Max-Age not found in: %s", cookieHeader)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cookie string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic deletion",
|
||||
cookie: "test",
|
||||
path: "/",
|
||||
expected: "test=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
|
||||
},
|
||||
{
|
||||
name: "delete with specific path",
|
||||
cookie: "session",
|
||||
path: "/api",
|
||||
expected: "session=; Path=/api; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
DeleteCookie(w, tt.cookie, tt.path)
|
||||
|
||||
headers := w.Header()["Set-Cookie"]
|
||||
if len(headers) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||
return
|
||||
}
|
||||
|
||||
cookieHeader := headers[0]
|
||||
|
||||
// Check deletion-specific components
|
||||
if !strings.Contains(cookieHeader, tt.cookie+"=") {
|
||||
t.Errorf("Expected cookie name not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "Path="+tt.path) {
|
||||
t.Errorf("Expected path not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "Max-Age=0") {
|
||||
t.Errorf("Expected Max-Age=0 not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "Expires=") {
|
||||
t.Errorf("Expected Expires not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "HttpOnly") {
|
||||
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPageFrom(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cookieValue string
|
||||
cookiePath string
|
||||
expectedResult string
|
||||
shouldSet bool
|
||||
}{
|
||||
{
|
||||
name: "valid pagefrom cookie",
|
||||
cookieValue: "/dashboard",
|
||||
cookiePath: "/",
|
||||
expectedResult: "/dashboard",
|
||||
shouldSet: true,
|
||||
},
|
||||
{
|
||||
name: "no pagefrom cookie",
|
||||
cookieValue: "",
|
||||
cookiePath: "",
|
||||
expectedResult: "/",
|
||||
shouldSet: false,
|
||||
},
|
||||
{
|
||||
name: "empty pagefrom cookie",
|
||||
cookieValue: "",
|
||||
cookiePath: "/",
|
||||
expectedResult: "",
|
||||
shouldSet: true,
|
||||
},
|
||||
{
|
||||
name: "pagefrom with query params",
|
||||
cookieValue: "/search?q=test",
|
||||
cookiePath: "/",
|
||||
expectedResult: "/search?q=test",
|
||||
shouldSet: true,
|
||||
},
|
||||
{
|
||||
name: "pagefrom with special path",
|
||||
cookieValue: "/api/v1/users",
|
||||
cookiePath: "/api",
|
||||
expectedResult: "/api/v1/users",
|
||||
shouldSet: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
if tt.shouldSet {
|
||||
cookie := &http.Cookie{
|
||||
Name: "pagefrom",
|
||||
Value: tt.cookieValue,
|
||||
Path: tt.cookiePath,
|
||||
}
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
|
||||
result := CheckPageFrom(w, r)
|
||||
|
||||
if result != tt.expectedResult {
|
||||
t.Errorf("CheckPageFrom() = %v, want %v", result, tt.expectedResult)
|
||||
}
|
||||
|
||||
// Verify that the cookie was deleted
|
||||
if tt.shouldSet {
|
||||
headers := w.Header()["Set-Cookie"]
|
||||
if len(headers) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers))
|
||||
return
|
||||
}
|
||||
cookieHeader := headers[0]
|
||||
if !strings.Contains(cookieHeader, "pagefrom=") {
|
||||
t.Errorf("Expected pagefrom cookie deletion not found in: %s", cookieHeader)
|
||||
}
|
||||
if !strings.Contains(cookieHeader, "Max-Age=0") {
|
||||
t.Errorf("Expected Max-Age=0 for deletion not found in: %s", cookieHeader)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPageFrom(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
referer string
|
||||
trustedHost string
|
||||
expectedSet bool
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "valid trusted host referer",
|
||||
referer: "http://example.com/dashboard",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/dashboard",
|
||||
},
|
||||
{
|
||||
name: "valid trusted host with https",
|
||||
referer: "https://example.com/profile",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/profile",
|
||||
},
|
||||
{
|
||||
name: "untrusted host",
|
||||
referer: "http://evil.com/dashboard",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
referer: "http://example.com",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/",
|
||||
},
|
||||
{
|
||||
name: "login path - should not set",
|
||||
referer: "http://example.com/login",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: false,
|
||||
expectedValue: "",
|
||||
},
|
||||
{
|
||||
name: "register path - should not set",
|
||||
referer: "http://example.com/register",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: false,
|
||||
expectedValue: "",
|
||||
},
|
||||
{
|
||||
name: "invalid referer URL",
|
||||
referer: "not-a-url",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/",
|
||||
},
|
||||
{
|
||||
name: "empty referer",
|
||||
referer: "",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/",
|
||||
},
|
||||
{
|
||||
name: "root path",
|
||||
referer: "http://example.com/",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/",
|
||||
},
|
||||
{
|
||||
name: "path with query string",
|
||||
referer: "http://example.com/search?q=test",
|
||||
trustedHost: "example.com",
|
||||
expectedSet: true,
|
||||
expectedValue: "/search",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
if tt.referer != "" {
|
||||
r.Header.Set("Referer", tt.referer)
|
||||
}
|
||||
|
||||
SetPageFrom(w, r, tt.trustedHost)
|
||||
|
||||
headers := w.Header()["Set-Cookie"]
|
||||
if tt.expectedSet {
|
||||
if len(headers) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||
return
|
||||
}
|
||||
cookieHeader := headers[0]
|
||||
if !strings.Contains(cookieHeader, "pagefrom="+tt.expectedValue) {
|
||||
t.Errorf("Expected pagefrom=%s not found in: %s", tt.expectedValue, cookieHeader)
|
||||
}
|
||||
} else {
|
||||
if len(headers) != 0 {
|
||||
t.Errorf("Expected no Set-Cookie header, got %d", len(headers))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
// Test the complete flow: SetPageFrom -> CheckPageFrom
|
||||
t.Run("complete flow", func(t *testing.T) {
|
||||
// Step 1: Set pagefrom cookie
|
||||
w1 := httptest.NewRecorder()
|
||||
r1 := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
r1.Header.Set("Referer", "http://example.com/dashboard")
|
||||
|
||||
SetPageFrom(w1, r1, "example.com")
|
||||
|
||||
// Extract the cookie from the response
|
||||
headers1 := w1.Header()["Set-Cookie"]
|
||||
if len(headers1) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers1))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the cookie was set correctly
|
||||
cookieHeader := headers1[0]
|
||||
if !strings.Contains(cookieHeader, "pagefrom=/dashboard") {
|
||||
t.Errorf("Expected pagefrom=/dashboard not found in: %s", cookieHeader)
|
||||
}
|
||||
|
||||
// Step 2: Check pagefrom cookie (should delete it)
|
||||
w2 := httptest.NewRecorder()
|
||||
r2 := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
r2.AddCookie(&http.Cookie{
|
||||
Name: "pagefrom",
|
||||
Value: "/dashboard",
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
result := CheckPageFrom(w2, r2)
|
||||
|
||||
if result != "/dashboard" {
|
||||
t.Errorf("Expected result /dashboard, got %s", result)
|
||||
}
|
||||
|
||||
// Verify the cookie was deleted
|
||||
headers2 := w2.Header()["Set-Cookie"]
|
||||
if len(headers2) != 1 {
|
||||
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers2))
|
||||
return
|
||||
}
|
||||
|
||||
cookieHeader2 := headers2[0]
|
||||
// Check for deletion indicators (Max-Age=0 with Expires in the past)
|
||||
if !(strings.Contains(cookieHeader2, "Max-Age=0") && strings.Contains(cookieHeader2, "Expires=Thu, 01 Jan 1970")) {
|
||||
t.Errorf("Expected cookie deletion, got: %s", cookieHeader2)
|
||||
}
|
||||
})
|
||||
}
|
||||
26
cookies/doc.go
Normal file
26
cookies/doc.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Package cookies provides utilities for handling HTTP cookies in Go web applications.
|
||||
// It includes functions for setting secure cookies, deleting cookies, and managing
|
||||
// pagefrom tracking for post-login redirects.
|
||||
//
|
||||
// The package follows security best practices by setting the HttpOnly flag on all
|
||||
// cookies to prevent XSS attacks. The SetCookie function allows you to specify the
|
||||
// name, path, value, and max-age for cookies.
|
||||
//
|
||||
// The pagefrom functionality helps with user experience by remembering where a user
|
||||
// was before being redirected to login/register pages, then redirecting them back
|
||||
// after successful authentication.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// // Set a session cookie
|
||||
// cookies.SetCookie(w, "session", "/", "abc123", 3600)
|
||||
//
|
||||
// // Delete a cookie
|
||||
// cookies.DeleteCookie(w, "old_session", "/")
|
||||
//
|
||||
// // Handle pagefrom tracking
|
||||
// cookies.SetPageFrom(w, r, "example.com")
|
||||
// redirectTo := cookies.CheckPageFrom(w, r)
|
||||
//
|
||||
// All functions are designed to be safe and handle edge cases gracefully.
|
||||
package cookies
|
||||
21
env/LICENSE
vendored
Normal file
21
env/LICENSE
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 haelnorr
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
67
env/README.md
vendored
Normal file
67
env/README.md
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# env v1.0.0
|
||||
|
||||
Environment variable utilities for Go applications with type safety and default values.
|
||||
|
||||
## Features
|
||||
|
||||
- Type-safe environment variable parsing
|
||||
- Support for all basic Go types (string, int variants, uint variants, bool, time.Duration)
|
||||
- Graceful fallback to default values
|
||||
- Comprehensive boolean parsing with multiple truthy/falsy values
|
||||
- Full test coverage
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get git.haelnorr.com/h/golib/env
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// String values
|
||||
host := env.String("HOST", "localhost")
|
||||
|
||||
// Integer values (all sizes supported)
|
||||
port := env.Int("PORT", 8080)
|
||||
timeout := env.Int64("TIMEOUT_SECONDS", 30)
|
||||
|
||||
// Unsigned integer values
|
||||
maxConnections := env.UInt("MAX_CONNECTIONS", 100)
|
||||
|
||||
// Boolean values (supports many formats)
|
||||
debug := env.Bool("DEBUG", false)
|
||||
|
||||
// Duration values
|
||||
requestTimeout := env.Duration("REQUEST_TIMEOUT", 30*time.Second)
|
||||
|
||||
fmt.Printf("Server: %s:%d\n", host, port)
|
||||
fmt.Printf("Debug: %v\n", debug)
|
||||
fmt.Printf("Timeout: %v\n", requestTimeout)
|
||||
}
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See the [wiki documentation](../golib/wiki/env.md) for detailed usage information and examples.
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
|
||||
## Contributing
|
||||
|
||||
Please see the main golib repository for contributing guidelines.
|
||||
|
||||
## Related Projects
|
||||
|
||||
This package is part of the golib collection of utilities for Go applications.
|
||||
18
env/doc.go
vendored
Normal file
18
env/doc.go
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package env provides utilities for reading environment variables with type safety
|
||||
// and default values. It supports common Go types including strings, integers (all sizes),
|
||||
// unsigned integers (all sizes), booleans, and time.Duration values.
|
||||
//
|
||||
// The package follows a simple pattern where each function takes a key name and a
|
||||
// default value, returning the parsed environment variable or the default if the
|
||||
// variable is not set or cannot be parsed.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// port := env.Int("PORT", 8080)
|
||||
// debug := env.Bool("DEBUG", false)
|
||||
// timeout := env.Duration("TIMEOUT", 30*time.Second)
|
||||
//
|
||||
// All functions gracefully handle missing environment variables by returning the
|
||||
// provided default value. They also handle parsing errors by falling back to the
|
||||
// default value.
|
||||
package env
|
||||
@@ -3,7 +3,7 @@
|
||||
//
|
||||
// ezconf allows you to:
|
||||
// - Load configurations from multiple packages using their ConfigFromEnv functions
|
||||
// - Parse package source code to extract environment variable documentation
|
||||
// - Parse config struct tags to extract environment variable documentation
|
||||
// - Generate and update .env files with all required environment variables
|
||||
// - Print environment variable lists with descriptions and current values
|
||||
// - Track additional custom environment variables
|
||||
@@ -40,16 +40,16 @@
|
||||
// // Use configuration...
|
||||
// }
|
||||
//
|
||||
// Alternatively, you can manually register packages:
|
||||
// Alternatively, you can manually register config structs:
|
||||
//
|
||||
// loader := ezconf.New()
|
||||
//
|
||||
// // Add package paths to parse for ENV comments
|
||||
// loader.AddPackagePath("/path/to/golib/hlog")
|
||||
// // Add config struct for tag parsing
|
||||
// loader.AddConfigStruct(&mypackage.Config{}, "MyPackage")
|
||||
//
|
||||
// // Add configuration loaders
|
||||
// loader.AddConfigFunc("hlog", func() (interface{}, error) {
|
||||
// return hlog.ConfigFromEnv()
|
||||
// loader.AddConfigFunc("mypackage", func() (interface{}, error) {
|
||||
// return mypackage.ConfigFromEnv()
|
||||
// })
|
||||
//
|
||||
// loader.Load()
|
||||
@@ -94,27 +94,34 @@
|
||||
// Default: "postgres://localhost/mydb",
|
||||
// })
|
||||
//
|
||||
// # ENV Comment Format
|
||||
// # Struct Tag Format
|
||||
//
|
||||
// ezconf parses struct field comments in the following format:
|
||||
// ezconf uses struct tags to define environment variable metadata:
|
||||
//
|
||||
// type Config struct {
|
||||
// // ENV LOG_LEVEL: Log level for the application (default: info)
|
||||
// LogLevel string
|
||||
//
|
||||
// // ENV DATABASE_URL: Database connection string (required)
|
||||
// DatabaseURL string
|
||||
// LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
|
||||
// DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
|
||||
// LogDir string `ezconf:"LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file"`
|
||||
// }
|
||||
//
|
||||
// The format is:
|
||||
// - ENV ENV_VAR_NAME: Description (optional modifiers)
|
||||
// - (required) or (required if condition) - marks variable as required
|
||||
// - (default: value) - specifies default value
|
||||
// Tag components (comma-separated):
|
||||
// - First value: environment variable name (required)
|
||||
// - description:...: Description of the variable
|
||||
// - default:...: Default value
|
||||
// - required: Marks the variable as required
|
||||
// - required:condition: Marks as required with a condition description
|
||||
//
|
||||
// # Integration
|
||||
//
|
||||
// Packages can implement the Integration interface to provide automatic
|
||||
// registration with ezconf. The interface requires:
|
||||
// - Name() string: Registration key for the config
|
||||
// - ConfigPointer() any: Pointer to config struct for tag parsing
|
||||
// - ConfigFunc() func() (any, error): Function to load config from env
|
||||
// - GroupName() string: Display name for grouping env vars
|
||||
//
|
||||
// ezconf integrates with:
|
||||
// - All golib packages that follow the ConfigFromEnv pattern
|
||||
// - Any custom configuration structs with ENV comments
|
||||
// - Any custom configuration structs with ezconf struct tags
|
||||
// - Standard .env file format
|
||||
package ezconf
|
||||
|
||||
@@ -16,14 +16,19 @@ type EnvVar struct {
|
||||
Group string // Group name for organizing variables (e.g., "Database", "Logging")
|
||||
}
|
||||
|
||||
// configStruct holds a config struct pointer and its group name for parsing
|
||||
type configStruct struct {
|
||||
configPtr any
|
||||
groupName string
|
||||
}
|
||||
|
||||
// ConfigLoader manages configuration loading from multiple sources
|
||||
type ConfigLoader struct {
|
||||
configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions
|
||||
packagePaths []string // Paths to packages to parse for ENV comments
|
||||
groupNames map[string]string // Map of package paths to group names
|
||||
extraEnvVars []EnvVar // Additional environment variables to track
|
||||
envVars []EnvVar // All extracted environment variables
|
||||
configs map[string]any // Loaded configurations
|
||||
configFuncs map[string]ConfigFunc // Map of config names to ConfigFromEnv functions
|
||||
configStructs []configStruct // Config struct pointers for tag parsing
|
||||
extraEnvVars []EnvVar // Additional environment variables to track
|
||||
envVars []EnvVar // All extracted environment variables
|
||||
configs map[string]any // Loaded configurations
|
||||
}
|
||||
|
||||
// ConfigFunc is a function that loads configuration from environment variables
|
||||
@@ -32,12 +37,11 @@ type ConfigFunc func() (any, error)
|
||||
// New creates a new ConfigLoader
|
||||
func New() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
configFuncs: make(map[string]ConfigFunc),
|
||||
packagePaths: make([]string, 0),
|
||||
groupNames: make(map[string]string),
|
||||
extraEnvVars: make([]EnvVar, 0),
|
||||
envVars: make([]EnvVar, 0),
|
||||
configs: make(map[string]any),
|
||||
configFuncs: make(map[string]ConfigFunc),
|
||||
configStructs: make([]configStruct, 0),
|
||||
extraEnvVars: make([]EnvVar, 0),
|
||||
envVars: make([]EnvVar, 0),
|
||||
configs: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,16 +58,20 @@ func (cl *ConfigLoader) AddConfigFunc(name string, fn ConfigFunc) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPackagePath adds a package directory path to parse for ENV comments
|
||||
func (cl *ConfigLoader) AddPackagePath(path string) error {
|
||||
if path == "" {
|
||||
return errors.New("package path cannot be empty")
|
||||
// AddConfigStruct adds a config struct pointer for parsing ezconf tags.
|
||||
// The configPtr must be a pointer to a struct with ezconf struct tags.
|
||||
// The groupName is used for organizing environment variables in output.
|
||||
func (cl *ConfigLoader) AddConfigStruct(configPtr any, groupName string) error {
|
||||
if configPtr == nil {
|
||||
return errors.New("config pointer cannot be nil")
|
||||
}
|
||||
// Check if path exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return errors.Errorf("package path does not exist: %s", path)
|
||||
if groupName == "" {
|
||||
groupName = "Other"
|
||||
}
|
||||
cl.packagePaths = append(cl.packagePaths, path)
|
||||
cl.configStructs = append(cl.configStructs, configStruct{
|
||||
configPtr: configPtr,
|
||||
groupName: groupName,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -72,27 +80,22 @@ func (cl *ConfigLoader) AddEnvVar(envVar EnvVar) {
|
||||
cl.extraEnvVars = append(cl.extraEnvVars, envVar)
|
||||
}
|
||||
|
||||
// ParseEnvVars extracts environment variables from packages and extra vars
|
||||
// This can be called without having actual environment variables set
|
||||
// ParseEnvVars extracts environment variables from config struct tags and extra vars.
|
||||
// This can be called without having actual environment variables set.
|
||||
func (cl *ConfigLoader) ParseEnvVars() error {
|
||||
// Clear existing env vars to prevent duplicates
|
||||
cl.envVars = make([]EnvVar, 0)
|
||||
|
||||
// Parse packages for ENV comments
|
||||
for _, pkgPath := range cl.packagePaths {
|
||||
envVars, err := ParseConfigPackage(pkgPath)
|
||||
// Parse config structs for ezconf tags
|
||||
for _, cs := range cl.configStructs {
|
||||
envVars, err := ParseConfigStruct(cs.configPtr)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to parse package: %s", pkgPath)
|
||||
}
|
||||
|
||||
// Set group name for these variables from stored mapping
|
||||
groupName := cl.groupNames[pkgPath]
|
||||
if groupName == "" {
|
||||
groupName = "Other"
|
||||
return errors.Wrap(err, "failed to parse config struct")
|
||||
}
|
||||
|
||||
// Set group name for these variables
|
||||
for i := range envVars {
|
||||
envVars[i].Group = groupName
|
||||
envVars[i].Group = cs.groupName
|
||||
}
|
||||
|
||||
cl.envVars = append(cl.envVars, envVars...)
|
||||
@@ -109,8 +112,8 @@ func (cl *ConfigLoader) ParseEnvVars() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfigs executes the config functions to load actual configurations
|
||||
// This should be called after environment variables are properly set
|
||||
// LoadConfigs executes the config functions to load actual configurations.
|
||||
// This should be called after environment variables are properly set.
|
||||
func (cl *ConfigLoader) LoadConfigs() error {
|
||||
// Load configurations
|
||||
for name, fn := range cl.configFuncs {
|
||||
|
||||
@@ -2,11 +2,17 @@ package ezconf
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testConfig is a Config struct used by multiple tests
|
||||
type testConfig struct {
|
||||
LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
|
||||
LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
|
||||
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
loader := New()
|
||||
if loader == nil {
|
||||
@@ -16,8 +22,8 @@ func TestNew(t *testing.T) {
|
||||
if loader.configFuncs == nil {
|
||||
t.Error("configFuncs map is nil")
|
||||
}
|
||||
if loader.packagePaths == nil {
|
||||
t.Error("packagePaths slice is nil")
|
||||
if loader.configStructs == nil {
|
||||
t.Error("configStructs slice is nil")
|
||||
}
|
||||
if loader.extraEnvVars == nil {
|
||||
t.Error("extraEnvVars slice is nil")
|
||||
@@ -66,35 +72,39 @@ func TestAddConfigFunc_EmptyName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPackagePath(t *testing.T) {
|
||||
func TestAddConfigStruct(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
// Use current directory as test path
|
||||
err := loader.AddPackagePath(".")
|
||||
err := loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
if err != nil {
|
||||
t.Errorf("AddPackagePath failed: %v", err)
|
||||
t.Errorf("AddConfigStruct failed: %v", err)
|
||||
}
|
||||
|
||||
if len(loader.packagePaths) != 1 {
|
||||
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths))
|
||||
if len(loader.configStructs) != 1 {
|
||||
t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPackagePath_InvalidPath(t *testing.T) {
|
||||
func TestAddConfigStruct_NilPointer(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
err := loader.AddPackagePath("/nonexistent/path")
|
||||
err := loader.AddConfigStruct(nil, "Test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent path")
|
||||
t.Error("expected error for nil pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPackagePath_EmptyPath(t *testing.T) {
|
||||
func TestAddConfigStruct_EmptyGroupName(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
err := loader.AddPackagePath("")
|
||||
if err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
err := loader.AddConfigStruct(&testConfig{}, "")
|
||||
if err != nil {
|
||||
t.Errorf("AddConfigStruct failed: %v", err)
|
||||
}
|
||||
|
||||
// Should default to "Other"
|
||||
if loader.configStructs[0].groupName != "Other" {
|
||||
t.Errorf("expected group name 'Other', got %s", loader.configStructs[0].groupName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,8 +141,8 @@ func TestLoad(t *testing.T) {
|
||||
return testCfg, nil
|
||||
})
|
||||
|
||||
// Add current package path
|
||||
loader.AddPackagePath(".")
|
||||
// Add config struct for tag parsing
|
||||
loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
|
||||
// Add an extra env var
|
||||
loader.AddEnvVar(EnvVar{
|
||||
@@ -249,8 +259,8 @@ func TestParseEnvVars(t *testing.T) {
|
||||
return "test config", nil
|
||||
})
|
||||
|
||||
// Add current package path
|
||||
loader.AddPackagePath(".")
|
||||
// Add config struct for tag parsing
|
||||
loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
|
||||
// Add an extra env var
|
||||
loader.AddEnvVar(EnvVar{
|
||||
@@ -353,8 +363,8 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
|
||||
return testCfg, nil
|
||||
})
|
||||
|
||||
// Add current package path
|
||||
loader.AddPackagePath(".")
|
||||
// Add config struct for tag parsing
|
||||
loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
|
||||
// Add an extra env var
|
||||
loader.AddEnvVar(EnvVar{
|
||||
@@ -398,63 +408,68 @@ func TestParseEnvVars_Then_LoadConfigs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Integration(t *testing.T) {
|
||||
// Integration test with real hlog package
|
||||
hlogPath := filepath.Join("..", "hlog")
|
||||
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
|
||||
t.Skip("hlog package not found, skipping integration test")
|
||||
}
|
||||
|
||||
func TestParseEnvVars_GroupName(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
// Add hlog package
|
||||
if err := loader.AddPackagePath(hlogPath); err != nil {
|
||||
t.Fatalf("failed to add hlog package: %v", err)
|
||||
}
|
||||
loader.AddConfigStruct(&testConfig{}, "MyGroup")
|
||||
|
||||
// Load without config function (just parse)
|
||||
if err := loader.Load(); err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
err := loader.ParseEnvVars()
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEnvVars failed: %v", err)
|
||||
}
|
||||
|
||||
envVars := loader.GetEnvVars()
|
||||
if len(envVars) == 0 {
|
||||
t.Error("expected env vars from hlog package")
|
||||
}
|
||||
|
||||
t.Logf("Found %d environment variables from hlog", len(envVars))
|
||||
for _, ev := range envVars {
|
||||
t.Logf(" %s: %s (default: %s, required: %t)", ev.Name, ev.Description, ev.Default, ev.Required)
|
||||
if ev.Group != "MyGroup" {
|
||||
t.Errorf("expected group 'MyGroup', got '%s' for var %s", ev.Group, ev.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
|
||||
// Test the new separated ParseEnvVars functionality
|
||||
hlogPath := filepath.Join("..", "hlog")
|
||||
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
|
||||
t.Skip("hlog package not found, skipping integration test")
|
||||
}
|
||||
|
||||
func TestParseEnvVars_CurrentValues(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
// Add hlog package
|
||||
if err := loader.AddPackagePath(hlogPath); err != nil {
|
||||
t.Fatalf("failed to add hlog package: %v", err)
|
||||
loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
|
||||
// Set an env var
|
||||
t.Setenv("LOG_LEVEL", "debug")
|
||||
|
||||
err := loader.ParseEnvVars()
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEnvVars failed: %v", err)
|
||||
}
|
||||
|
||||
// Parse env vars without loading configs (this should work even if required env vars are missing)
|
||||
envVars := loader.GetEnvVars()
|
||||
for _, ev := range envVars {
|
||||
if ev.Name == "LOG_LEVEL" {
|
||||
if ev.CurrentValue != "debug" {
|
||||
t.Errorf("expected CurrentValue 'debug', got '%s'", ev.CurrentValue)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Error("LOG_LEVEL not found in env vars")
|
||||
}
|
||||
|
||||
func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
// Add config struct for tag parsing
|
||||
loader.AddConfigStruct(&testConfig{}, "Test")
|
||||
|
||||
// Parse env vars
|
||||
if err := loader.ParseEnvVars(); err != nil {
|
||||
t.Fatalf("ParseEnvVars failed: %v", err)
|
||||
}
|
||||
|
||||
envVars := loader.GetEnvVars()
|
||||
if len(envVars) == 0 {
|
||||
t.Error("expected env vars from hlog package")
|
||||
t.Error("expected env vars from config struct")
|
||||
}
|
||||
|
||||
// Now test that we can generate an env file without calling Load()
|
||||
tempDir := t.TempDir()
|
||||
envFile := filepath.Join(tempDir, "test-generated.env")
|
||||
envFile := tempDir + "/test-generated.env"
|
||||
|
||||
err := loader.GenerateEnvFile(envFile, false)
|
||||
if err != nil {
|
||||
@@ -472,16 +487,16 @@ func TestParseEnvVars_GenerateEnvFile_Integration(t *testing.T) {
|
||||
t.Error("expected header in generated file")
|
||||
}
|
||||
|
||||
// Should contain environment variables from hlog
|
||||
foundHlogVar := false
|
||||
// Should contain environment variables from config struct
|
||||
foundVar := false
|
||||
for _, ev := range envVars {
|
||||
if strings.Contains(output, ev.Name) {
|
||||
foundHlogVar = true
|
||||
foundVar = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundHlogVar {
|
||||
t.Error("expected to find at least one hlog environment variable in generated file")
|
||||
if !foundVar {
|
||||
t.Error("expected to find at least one environment variable in generated file")
|
||||
}
|
||||
|
||||
t.Logf("Successfully generated env file with %d variables", len(envVars))
|
||||
|
||||
@@ -1,31 +1,70 @@
|
||||
package ezconf
|
||||
|
||||
// Integration is an interface that packages can implement to provide
|
||||
type Integration struct {
|
||||
Name string
|
||||
ConfigPointer any
|
||||
ConfigFunc func() (any, error)
|
||||
GroupName string
|
||||
}
|
||||
|
||||
func NewIntegration(name, groupname string, cfgptr any, cfgfunc func() (any, error)) *Integration {
|
||||
return &Integration{
|
||||
name,
|
||||
cfgptr,
|
||||
cfgfunc,
|
||||
groupname,
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationDepr is an interface that packages can implement to provide
|
||||
// easy integration with ezconf
|
||||
type Integration interface {
|
||||
type IntegrationDepr interface {
|
||||
// Name returns the name to use when registering the config
|
||||
Name() string
|
||||
|
||||
// PackagePath returns the path to the package for source parsing
|
||||
PackagePath() string
|
||||
// ConfigPointer returns a pointer to the config struct for tag parsing
|
||||
ConfigPointer() any
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function
|
||||
ConfigFunc() func() (interface{}, error)
|
||||
ConfigFunc() func() (any, error)
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
GroupName() string
|
||||
}
|
||||
|
||||
// RegisterIntegration registers a package that implements the Integration interface
|
||||
func (cl *ConfigLoader) RegisterIntegration(integration Integration) error {
|
||||
// Add package path
|
||||
pkgPath := integration.PackagePath()
|
||||
if err := cl.AddPackagePath(pkgPath); err != nil {
|
||||
// AddIntegration registers a package using an Integration object returned by another package
|
||||
func (cl *ConfigLoader) AddIntegration(integration *Integration) error {
|
||||
// Add config struct for tag parsing
|
||||
configPtr := integration.ConfigPointer
|
||||
if err := cl.AddConfigStruct(configPtr, integration.GroupName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store group name for this package
|
||||
cl.groupNames[pkgPath] = integration.GroupName()
|
||||
// Add config function
|
||||
if err := cl.AddConfigFunc(integration.Name, integration.ConfigFunc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddIntegrations registers multiple integrations at once
|
||||
func (cl *ConfigLoader) AddIntegrations(integrations ...*Integration) error {
|
||||
for _, integration := range integrations {
|
||||
if err := cl.AddIntegration(integration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterIntegration registers a package that implements the Integration interface
|
||||
func (cl *ConfigLoader) RegisterIntegration(integration IntegrationDepr) error {
|
||||
// Add config struct for tag parsing
|
||||
configPtr := integration.ConfigPointer()
|
||||
if err := cl.AddConfigStruct(configPtr, integration.GroupName()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add config function
|
||||
if err := cl.AddConfigFunc(integration.Name(), integration.ConfigFunc()); err != nil {
|
||||
@@ -36,7 +75,7 @@ func (cl *ConfigLoader) RegisterIntegration(integration Integration) error {
|
||||
}
|
||||
|
||||
// RegisterIntegrations registers multiple integrations at once
|
||||
func (cl *ConfigLoader) RegisterIntegrations(integrations ...Integration) error {
|
||||
func (cl *ConfigLoader) RegisterIntegrations(integrations ...IntegrationDepr) error {
|
||||
for _, integration := range integrations {
|
||||
if err := cl.RegisterIntegration(integration); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,24 +1,34 @@
|
||||
package ezconf
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock integration for testing
|
||||
// mockConfig is a test config struct with ezconf tags
|
||||
type mockConfig struct {
|
||||
Host string `ezconf:"MOCK_HOST,description:Host to connect to,default:localhost"`
|
||||
Port int `ezconf:"MOCK_PORT,description:Port to connect to,default:8080"`
|
||||
}
|
||||
|
||||
// mockConfig2 is a second test config struct
|
||||
type mockConfig2 struct {
|
||||
Token string `ezconf:"MOCK_TOKEN,description:API token,required"`
|
||||
}
|
||||
|
||||
// mockIntegration implements the Integration interface for testing
|
||||
type mockIntegration struct {
|
||||
name string
|
||||
packagePath string
|
||||
configFunc func() (interface{}, error)
|
||||
name string
|
||||
configPtr any
|
||||
configFunc func() (interface{}, error)
|
||||
groupName string
|
||||
}
|
||||
|
||||
func (m mockIntegration) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m mockIntegration) PackagePath() string {
|
||||
return m.packagePath
|
||||
func (m mockIntegration) ConfigPointer() any {
|
||||
return m.configPtr
|
||||
}
|
||||
|
||||
func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
@@ -26,15 +36,18 @@ func (m mockIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
}
|
||||
|
||||
func (m mockIntegration) GroupName() string {
|
||||
return "Test Group"
|
||||
if m.groupName == "" {
|
||||
return "Test Group"
|
||||
}
|
||||
return m.groupName
|
||||
}
|
||||
|
||||
func TestRegisterIntegration(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
integration := mockIntegration{
|
||||
name: "test",
|
||||
packagePath: ".",
|
||||
name: "test",
|
||||
configPtr: &mockConfig{},
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "test config", nil
|
||||
},
|
||||
@@ -45,9 +58,9 @@ func TestRegisterIntegration(t *testing.T) {
|
||||
t.Fatalf("RegisterIntegration failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify package path was added
|
||||
if len(loader.packagePaths) != 1 {
|
||||
t.Errorf("expected 1 package path, got %d", len(loader.packagePaths))
|
||||
// Verify config struct was added
|
||||
if len(loader.configStructs) != 1 {
|
||||
t.Errorf("expected 1 config struct, got %d", len(loader.configStructs))
|
||||
}
|
||||
|
||||
// Verify config func was added
|
||||
@@ -68,14 +81,46 @@ func TestRegisterIntegration(t *testing.T) {
|
||||
if cfg != "test config" {
|
||||
t.Errorf("expected 'test config', got %v", cfg)
|
||||
}
|
||||
|
||||
// Verify env vars were parsed from struct tags
|
||||
envVars := loader.GetEnvVars()
|
||||
if len(envVars) != 2 {
|
||||
t.Errorf("expected 2 env vars, got %d", len(envVars))
|
||||
}
|
||||
|
||||
foundHost := false
|
||||
foundPort := false
|
||||
for _, ev := range envVars {
|
||||
if ev.Name == "MOCK_HOST" {
|
||||
foundHost = true
|
||||
if ev.Default != "localhost" {
|
||||
t.Errorf("expected default 'localhost', got '%s'", ev.Default)
|
||||
}
|
||||
if ev.Group != "Test Group" {
|
||||
t.Errorf("expected group 'Test Group', got '%s'", ev.Group)
|
||||
}
|
||||
}
|
||||
if ev.Name == "MOCK_PORT" {
|
||||
foundPort = true
|
||||
if ev.Default != "8080" {
|
||||
t.Errorf("expected default '8080', got '%s'", ev.Default)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundHost {
|
||||
t.Error("MOCK_HOST not found in env vars")
|
||||
}
|
||||
if !foundPort {
|
||||
t.Error("MOCK_PORT not found in env vars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterIntegration_InvalidPath(t *testing.T) {
|
||||
func TestRegisterIntegration_NilConfigPointer(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
integration := mockIntegration{
|
||||
name: "test",
|
||||
packagePath: "/nonexistent/path",
|
||||
name: "test",
|
||||
configPtr: nil,
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "test config", nil
|
||||
},
|
||||
@@ -83,7 +128,7 @@ func TestRegisterIntegration_InvalidPath(t *testing.T) {
|
||||
|
||||
err := loader.RegisterIntegration(integration)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid package path")
|
||||
t.Error("expected error for nil config pointer")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,16 +136,16 @@ func TestRegisterIntegrations(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
integration1 := mockIntegration{
|
||||
name: "test1",
|
||||
packagePath: ".",
|
||||
name: "test1",
|
||||
configPtr: &mockConfig{},
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "config1", nil
|
||||
},
|
||||
}
|
||||
|
||||
integration2 := mockIntegration{
|
||||
name: "test2",
|
||||
packagePath: ".",
|
||||
name: "test2",
|
||||
configPtr: &mockConfig2{},
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "config2", nil
|
||||
},
|
||||
@@ -130,22 +175,28 @@ func TestRegisterIntegrations(t *testing.T) {
|
||||
if cfg1 != "config1" || cfg2 != "config2" {
|
||||
t.Error("config values mismatch")
|
||||
}
|
||||
|
||||
// Should have env vars from both structs
|
||||
envVars := loader.GetEnvVars()
|
||||
if len(envVars) != 3 {
|
||||
t.Errorf("expected 3 env vars (2 from mockConfig + 1 from mockConfig2), got %d", len(envVars))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterIntegrations_PartialFailure(t *testing.T) {
|
||||
loader := New()
|
||||
|
||||
integration1 := mockIntegration{
|
||||
name: "test1",
|
||||
packagePath: ".",
|
||||
name: "test1",
|
||||
configPtr: &mockConfig{},
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "config1", nil
|
||||
},
|
||||
}
|
||||
|
||||
integration2 := mockIntegration{
|
||||
name: "test2",
|
||||
packagePath: "/nonexistent",
|
||||
name: "test2",
|
||||
configPtr: nil, // This should cause failure
|
||||
configFunc: func() (interface{}, error) {
|
||||
return "config2", nil
|
||||
},
|
||||
@@ -159,54 +210,5 @@ func TestRegisterIntegrations_PartialFailure(t *testing.T) {
|
||||
|
||||
func TestIntegration_Interface(t *testing.T) {
|
||||
// Verify that mockIntegration implements Integration interface
|
||||
var _ Integration = (*mockIntegration)(nil)
|
||||
}
|
||||
|
||||
func TestRegisterIntegration_RealPackage(t *testing.T) {
|
||||
// Integration test with real hlog package if available
|
||||
hlogPath := filepath.Join("..", "hlog")
|
||||
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
|
||||
t.Skip("hlog package not found, skipping integration test")
|
||||
}
|
||||
|
||||
loader := New()
|
||||
|
||||
// Create a simple integration for testing
|
||||
integration := mockIntegration{
|
||||
name: "hlog",
|
||||
packagePath: hlogPath,
|
||||
configFunc: func() (interface{}, error) {
|
||||
// Return a mock config instead of calling real ConfigFromEnv
|
||||
return struct{ LogLevel string }{LogLevel: "info"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
err := loader.RegisterIntegration(integration)
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterIntegration with real package failed: %v", err)
|
||||
}
|
||||
|
||||
if err := loader.Load(); err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have parsed env vars from hlog
|
||||
envVars := loader.GetEnvVars()
|
||||
if len(envVars) == 0 {
|
||||
t.Error("expected env vars from hlog package")
|
||||
}
|
||||
|
||||
// Check for known hlog variables
|
||||
foundLogLevel := false
|
||||
for _, ev := range envVars {
|
||||
if ev.Name == "LOG_LEVEL" {
|
||||
foundLogLevel = true
|
||||
t.Logf("Found LOG_LEVEL: %s", ev.Description)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundLogLevel {
|
||||
t.Error("expected to find LOG_LEVEL from hlog")
|
||||
}
|
||||
var _ IntegrationDepr = (*mockIntegration)(nil)
|
||||
}
|
||||
|
||||
178
ezconf/parser.go
178
ezconf/parser.go
@@ -1,146 +1,102 @@
|
||||
package ezconf
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ParseConfigFile parses a Go source file and extracts ENV comments from struct fields
|
||||
func ParseConfigFile(filename string) ([]EnvVar, error) {
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read file")
|
||||
// ParseConfigStruct extracts environment variable metadata from a config
|
||||
// struct's ezconf struct tags using reflection.
|
||||
//
|
||||
// The configPtr parameter must be a pointer to a struct. Each field with an
|
||||
// ezconf tag will be parsed to extract environment variable information.
|
||||
//
|
||||
// Tag format: `ezconf:"VAR_NAME,description:Description text,default:value,required"`
|
||||
//
|
||||
// Components:
|
||||
// - First value: environment variable name (required)
|
||||
// - description:...: Description of the variable
|
||||
// - default:...: Default value
|
||||
// - required: Marks the variable as required (optionally required:condition)
|
||||
func ParseConfigStruct(configPtr any) ([]EnvVar, error) {
|
||||
if configPtr == nil {
|
||||
return nil, errors.New("config pointer cannot be nil")
|
||||
}
|
||||
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, filename, content, parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse file")
|
||||
v := reflect.ValueOf(configPtr)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return nil, errors.New("config must be a pointer to a struct")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, errors.New("config must be a pointer to a struct")
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
envVars := make([]EnvVar, 0)
|
||||
|
||||
// Walk through the AST
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
// Look for struct type declarations
|
||||
typeSpec, ok := n.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
return true
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
tag := field.Tag.Get("ezconf")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
return true
|
||||
envVar, err := parseEzconfTag(tag)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to parse ezconf tag on field %s", field.Name)
|
||||
}
|
||||
|
||||
// Iterate through struct fields
|
||||
for _, field := range structType.Fields.List {
|
||||
var comment string
|
||||
|
||||
// Try to get from doc comment (comment before field)
|
||||
if field.Doc != nil && len(field.Doc.List) > 0 {
|
||||
comment = field.Doc.List[0].Text
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
comment = strings.TrimSpace(comment)
|
||||
}
|
||||
|
||||
// Try to get from inline comment (comment after field)
|
||||
if comment == "" && field.Comment != nil && len(field.Comment.List) > 0 {
|
||||
comment = field.Comment.List[0].Text
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
comment = strings.TrimSpace(comment)
|
||||
}
|
||||
|
||||
// Parse ENV comment
|
||||
if strings.HasPrefix(comment, "ENV ") {
|
||||
envVar, err := parseEnvComment(comment)
|
||||
if err == nil {
|
||||
envVars = append(envVars, *envVar)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
envVars = append(envVars, *envVar)
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
// ParseConfigPackage parses all Go files in a package directory and extracts ENV comments
|
||||
func ParseConfigPackage(packagePath string) ([]EnvVar, error) {
|
||||
// Find all .go files in the package
|
||||
files, err := filepath.Glob(filepath.Join(packagePath, "*.go"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to glob package files")
|
||||
// parseEzconfTag parses an ezconf struct tag value to extract environment
|
||||
// variable information.
|
||||
//
|
||||
// Expected format: "VAR_NAME,description:Description text,default:value,required"
|
||||
func parseEzconfTag(tag string) (*EnvVar, error) {
|
||||
if tag == "" {
|
||||
return nil, errors.New("tag cannot be empty")
|
||||
}
|
||||
|
||||
allEnvVars := make([]EnvVar, 0)
|
||||
|
||||
for _, file := range files {
|
||||
// Skip test files
|
||||
if strings.HasSuffix(file, "_test.go") {
|
||||
continue
|
||||
}
|
||||
|
||||
envVars, err := ParseConfigFile(file)
|
||||
if err != nil {
|
||||
// Log error but continue with other files
|
||||
continue
|
||||
}
|
||||
|
||||
allEnvVars = append(allEnvVars, envVars...)
|
||||
}
|
||||
|
||||
return allEnvVars, nil
|
||||
}
|
||||
|
||||
// parseEnvComment parses a field comment to extract environment variable information.
|
||||
// Expected format: ENV ENV_NAME: Description (required <condition>) (default: <value>)
|
||||
func parseEnvComment(comment string) (*EnvVar, error) {
|
||||
// Check if comment starts with ENV
|
||||
if !strings.HasPrefix(comment, "ENV ") {
|
||||
return nil, errors.New("comment does not start with 'ENV '")
|
||||
}
|
||||
|
||||
// Remove "ENV " prefix
|
||||
comment = strings.TrimPrefix(comment, "ENV ")
|
||||
|
||||
// Extract env var name (everything before the first colon)
|
||||
colonIdx := strings.Index(comment, ":")
|
||||
if colonIdx == -1 {
|
||||
return nil, errors.New("missing colon separator")
|
||||
parts := strings.Split(tag, ",")
|
||||
if len(parts) == 0 {
|
||||
return nil, errors.New("tag cannot be empty")
|
||||
}
|
||||
|
||||
envVar := &EnvVar{
|
||||
Name: strings.TrimSpace(comment[:colonIdx]),
|
||||
Name: strings.TrimSpace(parts[0]),
|
||||
}
|
||||
|
||||
// Extract description and optional parts
|
||||
remainder := strings.TrimSpace(comment[colonIdx+1:])
|
||||
|
||||
// Check for (required ...) pattern
|
||||
requiredPattern := regexp.MustCompile(`\(required[^)]*\)`)
|
||||
if requiredPattern.MatchString(remainder) {
|
||||
envVar.Required = true
|
||||
remainder = requiredPattern.ReplaceAllString(remainder, "")
|
||||
if envVar.Name == "" {
|
||||
return nil, errors.New("environment variable name cannot be empty")
|
||||
}
|
||||
|
||||
// Check for (default: ...) pattern
|
||||
defaultPattern := regexp.MustCompile(`\(default:\s*([^)]*)\)`)
|
||||
if matches := defaultPattern.FindStringSubmatch(remainder); len(matches) > 1 {
|
||||
envVar.Default = strings.TrimSpace(matches[1])
|
||||
remainder = defaultPattern.ReplaceAllString(remainder, "")
|
||||
}
|
||||
for _, part := range parts[1:] {
|
||||
part = strings.TrimSpace(part)
|
||||
|
||||
// What remains is the description
|
||||
envVar.Description = strings.TrimSpace(remainder)
|
||||
switch {
|
||||
case strings.HasPrefix(part, "description:"):
|
||||
envVar.Description = strings.TrimSpace(strings.TrimPrefix(part, "description:"))
|
||||
case strings.HasPrefix(part, "default:"):
|
||||
envVar.Default = strings.TrimSpace(strings.TrimPrefix(part, "default:"))
|
||||
case part == "required":
|
||||
envVar.Required = true
|
||||
case strings.HasPrefix(part, "required:"):
|
||||
envVar.Required = true
|
||||
// Store the condition in the description if it adds context
|
||||
condition := strings.TrimSpace(strings.TrimPrefix(part, "required:"))
|
||||
if condition != "" && envVar.Description != "" {
|
||||
envVar.Description = envVar.Description + " (required " + condition + ")"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return envVar, nil
|
||||
}
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
package ezconf
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseEnvComment(t *testing.T) {
|
||||
func TestParseEzconfTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
comment string
|
||||
tag string
|
||||
wantEnvVar *EnvVar
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple env variable",
|
||||
comment: "ENV LOG_LEVEL: Log level for the application",
|
||||
name: "simple env variable",
|
||||
tag: "LOG_LEVEL,description:Log level for the application",
|
||||
wantEnvVar: &EnvVar{
|
||||
Name: "LOG_LEVEL",
|
||||
Description: "Log level for the application",
|
||||
@@ -25,8 +23,8 @@ func TestParseEnvComment(t *testing.T) {
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "env variable with default",
|
||||
comment: "ENV LOG_LEVEL: Log level for the application (default: info)",
|
||||
name: "env variable with default",
|
||||
tag: "LOG_LEVEL,description:Log level for the application,default:info",
|
||||
wantEnvVar: &EnvVar{
|
||||
Name: "LOG_LEVEL",
|
||||
Description: "Log level for the application",
|
||||
@@ -36,8 +34,8 @@ func TestParseEnvComment(t *testing.T) {
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "required env variable",
|
||||
comment: "ENV DATABASE_URL: Database connection string (required)",
|
||||
name: "required env variable",
|
||||
tag: "DATABASE_URL,description:Database connection string,required",
|
||||
wantEnvVar: &EnvVar{
|
||||
Name: "DATABASE_URL",
|
||||
Description: "Database connection string",
|
||||
@@ -47,25 +45,36 @@ func TestParseEnvComment(t *testing.T) {
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "required with condition and default",
|
||||
comment: "ENV LOG_DIR: Directory for log files (required when LOG_OUTPUT is file) (default: /var/log)",
|
||||
name: "required with condition and default",
|
||||
tag: "LOG_DIR,description:Directory for log files,required:when LOG_OUTPUT is file,default:/var/log",
|
||||
wantEnvVar: &EnvVar{
|
||||
Name: "LOG_DIR",
|
||||
Description: "Directory for log files",
|
||||
Description: "Directory for log files (required when LOG_OUTPUT is file)",
|
||||
Required: true,
|
||||
Default: "/var/log",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing colon",
|
||||
comment: "ENV LOG_LEVEL Log level",
|
||||
name: "name only",
|
||||
tag: "SIMPLE_VAR",
|
||||
wantEnvVar: &EnvVar{
|
||||
Name: "SIMPLE_VAR",
|
||||
Description: "",
|
||||
Required: false,
|
||||
Default: "",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty tag",
|
||||
tag: "",
|
||||
wantEnvVar: nil,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "not an ENV comment",
|
||||
comment: "This is a regular comment",
|
||||
name: "empty name",
|
||||
tag: ",description:some desc",
|
||||
wantEnvVar: nil,
|
||||
expectError: true,
|
||||
},
|
||||
@@ -73,7 +82,7 @@ func TestParseEnvComment(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
envVar, err := parseEnvComment(tt.comment)
|
||||
envVar, err := parseEzconfTag(tt.tag)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
@@ -103,32 +112,17 @@ func TestParseEnvComment(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigFile(t *testing.T) {
|
||||
// Create a temporary test file
|
||||
tempDir := t.TempDir()
|
||||
testFile := filepath.Join(tempDir, "config.go")
|
||||
|
||||
content := `package testpkg
|
||||
|
||||
type Config struct {
|
||||
// ENV LOG_LEVEL: Log level for the application (default: info)
|
||||
LogLevel string
|
||||
|
||||
// ENV LOG_OUTPUT: Output destination (default: console)
|
||||
LogOutput string
|
||||
|
||||
// ENV DATABASE_URL: Database connection string (required)
|
||||
DatabaseURL string
|
||||
}
|
||||
`
|
||||
|
||||
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
func TestParseConfigStruct(t *testing.T) {
|
||||
type TestConfig struct {
|
||||
LogLevel string `ezconf:"LOG_LEVEL,description:Log level for the application,default:info"`
|
||||
LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination,default:console"`
|
||||
DatabaseURL string `ezconf:"DATABASE_URL,description:Database connection string,required"`
|
||||
NoTag string
|
||||
}
|
||||
|
||||
envVars, err := ParseConfigFile(testFile)
|
||||
envVars, err := ParseConfigStruct(&TestConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseConfigFile failed: %v", err)
|
||||
t.Fatalf("ParseConfigStruct failed: %v", err)
|
||||
}
|
||||
|
||||
if len(envVars) != 3 {
|
||||
@@ -152,51 +146,70 @@ type Config struct {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigPackage(t *testing.T) {
|
||||
// Test with actual hlog package
|
||||
hlogPath := filepath.Join("..", "hlog")
|
||||
if _, err := os.Stat(hlogPath); os.IsNotExist(err) {
|
||||
t.Skip("hlog package not found, skipping integration test")
|
||||
}
|
||||
|
||||
envVars, err := ParseConfigPackage(hlogPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseConfigPackage failed: %v", err)
|
||||
}
|
||||
|
||||
if len(envVars) == 0 {
|
||||
t.Error("expected at least one env var from hlog package")
|
||||
}
|
||||
|
||||
// Check for known hlog variables
|
||||
foundLogLevel := false
|
||||
for _, envVar := range envVars {
|
||||
if envVar.Name == "LOG_LEVEL" {
|
||||
foundLogLevel = true
|
||||
t.Logf("Found LOG_LEVEL: %s", envVar.Description)
|
||||
}
|
||||
}
|
||||
|
||||
if !foundLogLevel {
|
||||
t.Error("expected to find LOG_LEVEL in hlog package")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigFile_InvalidFile(t *testing.T) {
|
||||
_, err := ParseConfigFile("/nonexistent/file.go")
|
||||
func TestParseConfigStruct_NilPointer(t *testing.T) {
|
||||
_, err := ParseConfigStruct(nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
t.Error("expected error for nil pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigPackage_InvalidPath(t *testing.T) {
|
||||
envVars, err := ParseConfigPackage("/nonexistent/package")
|
||||
func TestParseConfigStruct_NotPointer(t *testing.T) {
|
||||
type TestConfig struct {
|
||||
Foo string `ezconf:"FOO,description:test"`
|
||||
}
|
||||
_, err := ParseConfigStruct(TestConfig{})
|
||||
if err == nil {
|
||||
t.Error("expected error for non-pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStruct_NotStruct(t *testing.T) {
|
||||
str := "not a struct"
|
||||
_, err := ParseConfigStruct(&str)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-struct pointer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStruct_NoTags(t *testing.T) {
|
||||
type EmptyConfig struct {
|
||||
Foo string
|
||||
Bar int
|
||||
}
|
||||
|
||||
envVars, err := ParseConfigStruct(&EmptyConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseConfigPackage should not error on invalid path: %v", err)
|
||||
t.Fatalf("ParseConfigStruct failed: %v", err)
|
||||
}
|
||||
|
||||
// Should return empty slice for invalid path
|
||||
if len(envVars) != 0 {
|
||||
t.Errorf("expected 0 env vars for invalid path, got %d", len(envVars))
|
||||
t.Errorf("expected 0 env vars for struct with no tags, got %d", len(envVars))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStruct_UnexportedFields(t *testing.T) {
|
||||
type TestConfig struct {
|
||||
exported string `ezconf:"EXPORTED,description:An exported field"`
|
||||
unexported string `ezconf:"UNEXPORTED,description:An unexported field"`
|
||||
}
|
||||
|
||||
envVars, err := ParseConfigStruct(&TestConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseConfigStruct failed: %v", err)
|
||||
}
|
||||
|
||||
if len(envVars) != 2 {
|
||||
t.Errorf("expected 2 env vars (both exported and unexported), got %d", len(envVars))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStruct_InvalidTag(t *testing.T) {
|
||||
type TestConfig struct {
|
||||
Bad string `ezconf:",description:missing name"`
|
||||
}
|
||||
|
||||
_, err := ParseConfigStruct(&TestConfig{})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid tag")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
// It can be populated from environment variables using ConfigFromEnv
|
||||
// or created programmatically.
|
||||
type Config struct {
|
||||
LogLevel Level // ENV LOG_LEVEL: Log level for the logger - trace, debug, info, warn, error, fatal, panic (default: info)
|
||||
LogOutput string // ENV LOG_OUTPUT: Output destination for logs - console, file, or both (default: console)
|
||||
LogDir string // ENV LOG_DIR: Directory path for log files (required when LOG_OUTPUT is "file" or "both")
|
||||
LogFileName string // ENV LOG_FILE_NAME: Name of the log file (required when LOG_OUTPUT is "file" or "both")
|
||||
LogAppend bool // ENV LOG_APPEND: Append to existing log file or overwrite (default: true)
|
||||
LogLevel Level `ezconf:"LOG_LEVEL,description:Log level for the logger - trace debug info warn error fatal panic,default:info"`
|
||||
LogOutput string `ezconf:"LOG_OUTPUT,description:Output destination for logs - console file or both,default:console"`
|
||||
LogDir string `ezconf:"LOG_DIR,description:Directory path for log files,required:when LOG_OUTPUT is file or both"`
|
||||
LogFileName string `ezconf:"LOG_FILE_NAME,description:Name of the log file,required:when LOG_OUTPUT is file or both"`
|
||||
LogAppend bool `ezconf:"LOG_APPEND,description:Append to existing log file or overwrite,default:true"`
|
||||
}
|
||||
|
||||
// ConfigFromEnv loads logger configuration from environment variables.
|
||||
|
||||
@@ -1,35 +1,9 @@
|
||||
package hlog
|
||||
|
||||
import "runtime"
|
||||
import "git.haelnorr.com/h/golib/ezconf"
|
||||
|
||||
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||
type EZConfIntegration struct{}
|
||||
|
||||
// PackagePath returns the path to the hlog package for source parsing
|
||||
func (e EZConfIntegration) PackagePath() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
// Return directory of this file
|
||||
return filename[:len(filename)-len("/ezconf.go")]
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name to use when registering with ezconf
|
||||
func (e EZConfIntegration) Name() string {
|
||||
return "hlog"
|
||||
}
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
func (e EZConfIntegration) GroupName() string {
|
||||
return "HLog"
|
||||
}
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{}
|
||||
// NewEZConfIntegration creates a new EZConf integration
|
||||
func NewEZConfIntegration() *ezconf.Integration {
|
||||
return ezconf.NewIntegration("hlog", "HLog",
|
||||
&Config{}, func() (any, error) { return ConfigFromEnv() })
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ require (
|
||||
github.com/rs/zerolog v1.34.0
|
||||
)
|
||||
|
||||
require git.haelnorr.com/h/golib/ezconf v0.2.1
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
|
||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
|
||||
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
|
||||
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false)
|
||||
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
|
||||
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
|
||||
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
|
||||
ShutdownDelay time.Duration // ENV HWS_SHUTDOWN_DELAY: Delay in seconds before server shutsdown when Shutdown is called (default: 5)
|
||||
Host string `ezconf:"HWS_HOST,description:Host to listen on,default:127.0.0.1"`
|
||||
Port uint64 `ezconf:"HWS_PORT,description:Port to listen on,default:3000"`
|
||||
GZIP bool `ezconf:"HWS_GZIP,description:Flag for GZIP compression on requests,default:false"`
|
||||
ReadHeaderTimeout time.Duration `ezconf:"HWS_READ_HEADER_TIMEOUT,description:Timeout for reading request headers in seconds,default:2"`
|
||||
WriteTimeout time.Duration `ezconf:"HWS_WRITE_TIMEOUT,description:Timeout for writing requests in seconds,default:10"`
|
||||
IdleTimeout time.Duration `ezconf:"HWS_IDLE_TIMEOUT,description:Timeout for idle connections in seconds,default:120"`
|
||||
ShutdownDelay time.Duration `ezconf:"HWS_SHUTDOWN_DELAY,description:Delay in seconds before server shuts down when Shutdown is called,default:5"`
|
||||
}
|
||||
|
||||
// ConfigFromEnv returns a Config struct loaded from the environment variables
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func Test_ConfigFromEnv(t *testing.T) {
|
||||
t.Run("Default values when no env vars set", func(t *testing.T) {
|
||||
// Clear any existing env vars
|
||||
os.Unsetenv("HWS_HOST")
|
||||
os.Unsetenv("HWS_PORT")
|
||||
os.Unsetenv("HWS_GZIP")
|
||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom host", func(t *testing.T) {
|
||||
os.Setenv("HWS_HOST", "192.168.1.1")
|
||||
defer os.Unsetenv("HWS_HOST")
|
||||
_ = os.Setenv("HWS_HOST", "192.168.1.1")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom port", func(t *testing.T) {
|
||||
os.Setenv("HWS_PORT", "8080")
|
||||
defer os.Unsetenv("HWS_PORT")
|
||||
_ = os.Setenv("HWS_PORT", "8080")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("GZIP enabled", func(t *testing.T) {
|
||||
os.Setenv("HWS_GZIP", "true")
|
||||
defer os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Setenv("HWS_GZIP", "true")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Custom timeouts", func(t *testing.T) {
|
||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||
os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||
os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||
_ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||
_ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
require.NoError(t, err)
|
||||
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("All custom values", func(t *testing.T) {
|
||||
os.Setenv("HWS_HOST", "0.0.0.0")
|
||||
os.Setenv("HWS_PORT", "9000")
|
||||
os.Setenv("HWS_GZIP", "true")
|
||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||
os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||
os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||
_ = os.Setenv("HWS_HOST", "0.0.0.0")
|
||||
_ = os.Setenv("HWS_PORT", "9000")
|
||||
_ = os.Setenv("HWS_GZIP", "true")
|
||||
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||
_ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||
_ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||
defer func() {
|
||||
os.Unsetenv("HWS_HOST")
|
||||
os.Unsetenv("HWS_PORT")
|
||||
os.Unsetenv("HWS_GZIP")
|
||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_HOST")
|
||||
_ = os.Unsetenv("HWS_PORT")
|
||||
_ = os.Unsetenv("HWS_GZIP")
|
||||
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||
}()
|
||||
|
||||
config, err := hws.ConfigFromEnv()
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Error to use with Server.ThrowError
|
||||
// HWSError wraps an error with other information for use with HWS features
|
||||
type HWSError struct {
|
||||
StatusCode int // HTTP Status code
|
||||
Message string // Error message
|
||||
@@ -41,7 +41,7 @@ type ErrorPage interface {
|
||||
}
|
||||
|
||||
// AddErrorPage registers a handler that returns an ErrorPage
|
||||
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
||||
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
return errors.New("Render method of the error page did not write anything to the response writer")
|
||||
}
|
||||
|
||||
server.errorPage = pageFunc
|
||||
s.errorPage = pageFunc
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||
// the error with the level specified by the HWSError.
|
||||
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
||||
// and the request chain should be terminated.
|
||||
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||
func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
|
||||
err := s.throwError(w, r, error)
|
||||
if err != nil {
|
||||
s.LogError(error)
|
||||
s.LogError(HWSError{
|
||||
Message: "Error occured during throwError",
|
||||
Error: errors.Wrap(err, "s.throwError"),
|
||||
Level: ErrorERROR,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||
if error.StatusCode <= 0 {
|
||||
return errors.New("HWSError.StatusCode cannot be 0.")
|
||||
}
|
||||
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
|
||||
if r == nil {
|
||||
return errors.New("Request cannot be nil")
|
||||
}
|
||||
if !server.IsReady() {
|
||||
if !s.IsReady() {
|
||||
return errors.New("ThrowError called before server started")
|
||||
}
|
||||
w.WriteHeader(error.StatusCode)
|
||||
server.LogError(error)
|
||||
if server.errorPage == nil {
|
||||
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||
s.LogError(error)
|
||||
if s.errorPage == nil {
|
||||
s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||
return nil
|
||||
}
|
||||
if error.RenderErrorPage {
|
||||
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||
errPage, err := server.errorPage(error)
|
||||
s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||
errPage, err := s.errorPage(error)
|
||||
if err != nil {
|
||||
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||
s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||
}
|
||||
err = errPage.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||
s.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||
}
|
||||
} else {
|
||||
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||
s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
server.LogFatal(err)
|
||||
}
|
||||
|
||||
@@ -14,22 +14,26 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type goodPage struct{}
|
||||
type badPage struct{}
|
||||
type (
|
||||
goodPage struct{}
|
||||
badPage struct{}
|
||||
)
|
||||
|
||||
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return goodPage{}, nil
|
||||
}
|
||||
|
||||
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return badPage{}, nil
|
||||
}
|
||||
|
||||
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return nil, errors.New("I'm an error")
|
||||
}
|
||||
|
||||
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
||||
w.Write([]byte("Test write to ResponseWriter"))
|
||||
return nil
|
||||
_, err := w.Write([]byte("Test write to ResponseWriter"))
|
||||
return err
|
||||
}
|
||||
|
||||
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
||||
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
t.Run("Server not started", func(t *testing.T) {
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
buf.Reset()
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Error",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
// ThrowError logs errors internally when validation fails
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "ThrowError called before server started")
|
||||
})
|
||||
|
||||
startTestServer(t, server)
|
||||
defer server.Shutdown(t.Context())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
error hws.HWSError
|
||||
valid bool
|
||||
name string
|
||||
request *http.Request
|
||||
error hws.HWSError
|
||||
expectLogItem string
|
||||
}{
|
||||
{
|
||||
name: "No HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{},
|
||||
valid: false,
|
||||
name: "No HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{},
|
||||
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||
},
|
||||
{
|
||||
name: "Negative HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: -1},
|
||||
valid: false,
|
||||
name: "Negative HWSError.Status code",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: -1},
|
||||
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||
},
|
||||
{
|
||||
name: "No HWSError.Message",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||
valid: false,
|
||||
name: "No HWSError.Message",
|
||||
request: nil,
|
||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||
expectLogItem: "HWSError.Message cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "No HWSError.Error",
|
||||
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
},
|
||||
valid: false,
|
||||
expectLogItem: "HWSError.Error cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "No request provided",
|
||||
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
},
|
||||
valid: false,
|
||||
expectLogItem: "Request cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "Valid",
|
||||
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
},
|
||||
valid: true,
|
||||
expectLogItem: "An error occured",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf.Reset()
|
||||
rr := httptest.NewRecorder()
|
||||
err := server.ThrowError(rr, tt.request, tt.error)
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
t.Log(err)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
server.ThrowError(rr, tt.request, tt.error)
|
||||
// ThrowError no longer returns errors; check logs instead
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, tt.expectLogItem)
|
||||
})
|
||||
}
|
||||
t.Run("Log level set correctly", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
Level: hws.ErrorWARN,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = buf.ReadString([]byte(" ")[0])
|
||||
_, err := buf.ReadString([]byte(" ")[0])
|
||||
require.NoError(t, err)
|
||||
loglvl, err := buf.ReadString([]byte(" ")[0])
|
||||
assert.NoError(t, err)
|
||||
if loglvl != "\x1b[33mWRN\x1b[0m " {
|
||||
err = errors.New("Log level not set correctly")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
|
||||
|
||||
buf.Reset()
|
||||
err = server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = buf.ReadString([]byte(" ")[0])
|
||||
require.NoError(t, err)
|
||||
loglvl, err = buf.ReadString([]byte(" ")[0])
|
||||
assert.NoError(t, err)
|
||||
if loglvl != "\x1b[31mERR\x1b[0m " {
|
||||
err = errors.New("Log level not set correctly")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
|
||||
})
|
||||
|
||||
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
||||
// Must be run before adding the error page to the test server
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body != "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.Empty(t, body, "Error page should not render when no error page is set")
|
||||
})
|
||||
t.Run("Error page renders", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// Adding the error page will carry over to all future tests and cant be undone
|
||||
server.AddErrorPage(goodRender)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
err := server.AddErrorPage(goodRender)
|
||||
require.NoError(t, err)
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body == "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
|
||||
})
|
||||
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
|
||||
t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
|
||||
// Error page already added to server
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err := server.ThrowError(rr, req, hws.HWSError{
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
body := rr.Body.String()
|
||||
if body != "" {
|
||||
assert.Error(t, nil)
|
||||
}
|
||||
assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
|
||||
})
|
||||
server.Shutdown(t.Context())
|
||||
err := server.Shutdown(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
|
||||
t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
|
||||
server, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
|
||||
err = server.Start(t.Context())
|
||||
require.NoError(t, err)
|
||||
<-server.Ready()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
err = server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should not panic when no logger is present
|
||||
assert.NotPanics(t, func() {
|
||||
server.ThrowError(rr, req, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured",
|
||||
Error: errors.New("Error"),
|
||||
})
|
||||
}, "ThrowError should not panic when no logger is present")
|
||||
err = server.Shutdown(t.Context())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,35 +1,9 @@
|
||||
package hws
|
||||
|
||||
import "runtime"
|
||||
import "git.haelnorr.com/h/golib/ezconf"
|
||||
|
||||
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||
type EZConfIntegration struct{}
|
||||
|
||||
// PackagePath returns the path to the hws package for source parsing
|
||||
func (e EZConfIntegration) PackagePath() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
// Return directory of this file
|
||||
return filename[:len(filename)-len("/ezconf.go")]
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name to use when registering with ezconf
|
||||
func (e EZConfIntegration) Name() string {
|
||||
return "hws"
|
||||
}
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
func (e EZConfIntegration) GroupName() string {
|
||||
return "HWS"
|
||||
}
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{}
|
||||
// NewEZConfIntegration creates a new EZConf integration
|
||||
func NewEZConfIntegration() *ezconf.Integration {
|
||||
return ezconf.NewIntegration("hws", "HWS",
|
||||
&Config{}, func() (any, error) { return ConfigFromEnv() })
|
||||
}
|
||||
|
||||
11
hws/go.mod
11
hws/go.mod
@@ -4,21 +4,24 @@ go 1.25.5
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
git.haelnorr.com/h/golib/hlog v0.9.0
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0
|
||||
git.haelnorr.com/h/golib/notify v0.1.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
k8s.io/apimachinery v0.35.0
|
||||
)
|
||||
|
||||
require git.haelnorr.com/h/golib/ezconf v0.2.1
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/gobwas/glob v0.2.3
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/klog/v2 v2.130.1 // indirect
|
||||
k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 // indirect
|
||||
|
||||
17
hws/go.sum
17
hws/go.sum
@@ -1,7 +1,9 @@
|
||||
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
|
||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
|
||||
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
@@ -9,12 +11,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -26,8 +32,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -17,6 +17,10 @@ import (
|
||||
func Test_GZIP_Compression(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
dbg, _ := hlog.LogLevel("debug")
|
||||
logcfg := &hlog.Config{
|
||||
LogLevel: dbg,
|
||||
}
|
||||
t.Run("GZIP enabled compresses response", func(t *testing.T) {
|
||||
server, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
@@ -25,7 +29,7 @@ func Test_GZIP_Compression(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
|
||||
logger, err := hlog.NewLogger(logcfg, &buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddLogger(logger)
|
||||
@@ -80,7 +84,7 @@ func Test_GZIP_Compression(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
|
||||
logger, err := hlog.NewLogger(logcfg, &buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddLogger(logger)
|
||||
@@ -131,7 +135,7 @@ func Test_GZIP_Compression(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), &buf, nil, "")
|
||||
logger, err := hlog.NewLogger(logcfg, &buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddLogger(logger)
|
||||
@@ -179,20 +183,20 @@ func Test_GzipResponseWriter(t *testing.T) {
|
||||
t.Run("Can write through gzip writer", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&buf)
|
||||
|
||||
|
||||
testData := []byte("Test data to compress")
|
||||
n, err := gzWriter.Write(testData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(testData), n)
|
||||
|
||||
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Decompress and verify
|
||||
gzReader, err := gzip.NewReader(&buf)
|
||||
require.NoError(t, err)
|
||||
defer gzReader.Close()
|
||||
|
||||
|
||||
decompressed, err := io.ReadAll(gzReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, decompressed)
|
||||
@@ -215,9 +219,9 @@ func Test_GzipResponseWriter(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
|
||||
wrapped.ServeHTTP(rr, req)
|
||||
|
||||
|
||||
// Note: This is a simplified test
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,14 +6,15 @@ import (
|
||||
"net/url"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/gobwas/glob"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
logger *hlog.Logger
|
||||
ignoredPaths []string
|
||||
ignoredPaths []glob.Glob
|
||||
}
|
||||
|
||||
// TODO: add tests to make sure all the fields are correctly set
|
||||
// LogError uses the attached logger to log a HWSError
|
||||
func (s *Server) LogError(err HWSError) {
|
||||
if s.logger == nil {
|
||||
return
|
||||
@@ -29,45 +30,34 @@ func (s *Server) LogError(err HWSError) {
|
||||
s.logger.logger.Warn().Err(err.Error).Msg(err.Message)
|
||||
return
|
||||
case ErrorERROR:
|
||||
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
|
||||
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||
return
|
||||
case ErrorFATAL:
|
||||
s.logger.logger.Fatal().Err(err.Error).Msg(err.Message)
|
||||
s.logger.logger.Fatal().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||
return
|
||||
case ErrorPANIC:
|
||||
s.logger.logger.Panic().Err(err.Error).Msg(err.Message)
|
||||
s.logger.logger.Panic().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||
return
|
||||
default:
|
||||
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
|
||||
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) LogFatal(err error) {
|
||||
if err == nil {
|
||||
err = errors.New("LogFatal was called with a nil error")
|
||||
}
|
||||
if server.logger == nil {
|
||||
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
|
||||
return
|
||||
}
|
||||
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
|
||||
}
|
||||
|
||||
// Server.AddLogger adds a logger to the server to use for request logging.
|
||||
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||
// AddLogger adds a logger to the server to use for request logging.
|
||||
func (s *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||
if hlogger == nil {
|
||||
return errors.New("Unable to add logger, no logger provided")
|
||||
return errors.New("unable to add logger, no logger provided")
|
||||
}
|
||||
server.logger = &logger{
|
||||
s.logger = &logger{
|
||||
logger: hlogger,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||
// Useful for ignoring requests to CSS files or favicons
|
||||
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
func (s *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
for _, path := range paths {
|
||||
u, err := url.Parse(path)
|
||||
valid := err == nil &&
|
||||
@@ -76,9 +66,22 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
return fmt.Errorf("invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
server.logger.ignoredPaths = paths
|
||||
s.logger.ignoredPaths = prepareGlobs(paths)
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareGlobs(paths []string) []glob.Glob {
|
||||
compiledGlobs := make([]glob.Glob, 0, len(paths))
|
||||
for _, pattern := range paths {
|
||||
g, err := glob.Compile(pattern)
|
||||
if err != nil {
|
||||
// If pattern fails to compile, skip it
|
||||
continue
|
||||
}
|
||||
compiledGlobs = append(compiledGlobs, g)
|
||||
}
|
||||
return compiledGlobs
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ func Test_AddLogger(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LogError_AllLevels(t *testing.T) {
|
||||
dbg, _ := hlog.LogLevel("debug")
|
||||
logcfg := &hlog.Config{
|
||||
LogLevel: dbg,
|
||||
}
|
||||
t.Run("DEBUG level", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
// Create server with logger explicitly set to Debug level
|
||||
@@ -34,7 +38,7 @@ func Test_LogError_AllLevels(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := hlog.NewLogger(hlog.LogLevel("debug"), &buf, nil, "")
|
||||
logger, err := hlog.NewLogger(logcfg, &buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddLogger(logger)
|
||||
@@ -169,7 +173,7 @@ func Test_LogFatal(t *testing.T) {
|
||||
// Note: We cannot actually test Fatal() as it calls os.Exit()
|
||||
// Testing this would require subprocess testing which is overly complex
|
||||
// These tests document the expected behavior and verify the function signatures exist
|
||||
|
||||
|
||||
t.Run("LogFatal with nil logger prints to stdout", func(t *testing.T) {
|
||||
_, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
@@ -197,7 +201,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("http://example.com/path")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Invalid path with host", func(t *testing.T) {
|
||||
@@ -207,7 +211,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
err := server.LoggerIgnorePaths("//example.com/path")
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -217,7 +221,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("/path?query=value")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Invalid path with fragment", func(t *testing.T) {
|
||||
@@ -226,7 +230,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
||||
|
||||
err := server.LoggerIgnorePaths("/path#fragment")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid path")
|
||||
assert.Contains(t, err.Error(), "invalid path")
|
||||
})
|
||||
|
||||
t.Run("Valid paths", func(t *testing.T) {
|
||||
|
||||
@@ -5,35 +5,37 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Middleware func(h http.Handler) http.Handler
|
||||
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||
type (
|
||||
Middleware func(h http.Handler) http.Handler
|
||||
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||
)
|
||||
|
||||
// Server.AddMiddleware registers all the middleware.
|
||||
// AddMiddleware registers all the middleware.
|
||||
// Middleware will be run in the order that they are provided.
|
||||
// Can only be called once
|
||||
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
if !server.routes {
|
||||
func (s *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
if !s.routes {
|
||||
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||
}
|
||||
if server.middleware {
|
||||
if s.middleware {
|
||||
return errors.New("Server.AddMiddleware already called")
|
||||
}
|
||||
// RUN LOGGING MIDDLEWARE FIRST
|
||||
server.server.Handler = logging(server.server.Handler, server.logger)
|
||||
s.server.Handler = logging(s.server.Handler, s.logger)
|
||||
|
||||
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||
for i := len(middleware); i > 0; i-- {
|
||||
server.server.Handler = middleware[i-1](server.server.Handler)
|
||||
s.server.Handler = middleware[i-1](s.server.Handler)
|
||||
}
|
||||
|
||||
// RUN GZIP
|
||||
if server.GZIP {
|
||||
server.server.Handler = addgzip(server.server.Handler)
|
||||
if s.GZIP {
|
||||
s.server.Handler = addgzip(s.server.Handler)
|
||||
}
|
||||
// RUN TIMER MIDDLEWARE LAST
|
||||
server.server.Handler = startTimer(server.server.Handler)
|
||||
s.server.Handler = startTimer(s.server.Handler)
|
||||
|
||||
server.middleware = true
|
||||
s.middleware = true
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -43,14 +45,14 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
// and returns a new request and optional HWSError.
|
||||
// If a HWSError is returned, server.ThrowError will be called.
|
||||
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
||||
func (server *Server) NewMiddleware(
|
||||
func (s *Server) NewMiddleware(
|
||||
middlewareFunc MiddlewareFunc,
|
||||
) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
newReq, herr := middlewareFunc(w, r)
|
||||
if herr != nil {
|
||||
server.ThrowError(w, r, *herr)
|
||||
s.ThrowError(w, r, *herr)
|
||||
if herr.RenderErrorPage {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@ package hws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
)
|
||||
|
||||
// Middleware to add logs to console with details of the request
|
||||
@@ -13,7 +14,7 @@ func logging(next http.Handler, logger *logger) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if slices.Contains(logger.ignoredPaths, r.URL.Path) {
|
||||
if globTest(r.URL.Path, logger.ignoredPaths) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@@ -36,3 +37,12 @@ func logging(next http.Handler, logger *logger) http.Handler {
|
||||
Msg("Served")
|
||||
})
|
||||
}
|
||||
|
||||
func globTest(testPath string, globs []glob.Glob) bool {
|
||||
for _, g := range globs {
|
||||
if g.Match(testPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
|
||||
)
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "hws context key " + string(c)
|
||||
}
|
||||
|
||||
var requestTimerCtxKey = contextKey("request-timer")
|
||||
|
||||
// Set the start time of the request
|
||||
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||
return context.WithValue(ctx, "hws context key request-timer", time)
|
||||
return context.WithValue(ctx, requestTimerCtxKey, time)
|
||||
}
|
||||
|
||||
// Get the start time of the request
|
||||
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
||||
start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
|
||||
if !ok {
|
||||
return time.Time{}, errors.New("Failed to get start time of request")
|
||||
return time.Time{}, errors.New("failed to get start time of request")
|
||||
}
|
||||
return start, nil
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func (s *Server) NotifySub(nt notify.Notification) {
|
||||
}
|
||||
_, exists := s.notifier.clients.getClient(nt.Target)
|
||||
if !exists {
|
||||
err := fmt.Errorf("Tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||
err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||
return
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func (s *Server) NotifyID(nt notify.Notification, altID string) {
|
||||
clients, exists := s.notifier.clients.clientsIDMap[altID]
|
||||
s.notifier.clients.lock.RUnlock()
|
||||
if !exists {
|
||||
err := fmt.Errorf("Tried to notify client group that doesn't exist - altID: %s", altID)
|
||||
err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
|
||||
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,8 +15,9 @@ func newTestServerWithNotifier(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
|
||||
cfg := &Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: 0,
|
||||
Host: "127.0.0.1",
|
||||
Port: 0,
|
||||
ShutdownDelay: 0, // No delay for tests
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg)
|
||||
@@ -359,7 +360,7 @@ func Test_ActiveClientStaysAlive(t *testing.T) {
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
for range 3 {
|
||||
<-ticker.C
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
@@ -460,7 +461,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Send 10 notifications quickly (buffer is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Burst message",
|
||||
@@ -468,7 +469,7 @@ func Test_SlowConsumerTolerance(t *testing.T) {
|
||||
}
|
||||
|
||||
// Client should receive all 10
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Received
|
||||
@@ -487,7 +488,7 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Fill buffer completely (buffer is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Fill buffer",
|
||||
@@ -500,15 +501,15 @@ func Test_SingleTimeoutRecovery(t *testing.T) {
|
||||
Message: "Timeout message",
|
||||
})
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(6 * time.Second)
|
||||
// Wait for timeout (5s timeout + small buffer)
|
||||
time.Sleep(5100 * time.Millisecond)
|
||||
|
||||
// Check failure count (should be 1)
|
||||
fails := atomic.LoadInt32(&client.consecutiveFails)
|
||||
require.Equal(t, int32(1), fails, "Should have 1 timeout")
|
||||
|
||||
// Now read all buffered messages
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
<-notifications
|
||||
}
|
||||
|
||||
@@ -538,7 +539,7 @@ func Test_ConsecutiveFailureDisconnect(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Fill buffer and never read to cause 5 consecutive timeouts
|
||||
for i := 0; i < 20; i++ {
|
||||
for range 20 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Timeout message",
|
||||
@@ -684,7 +685,7 @@ func Test_ConcurrentSubscriptions(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
clients := make([]*Client, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
@@ -716,7 +717,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
||||
messageCount := 50
|
||||
|
||||
// Send from multiple goroutines
|
||||
for i := 0; i < messageCount; i++ {
|
||||
for i := range messageCount {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
@@ -733,7 +734,7 @@ func Test_ConcurrentNotifications(t *testing.T) {
|
||||
// This is expected behavior - we're testing thread safety, not guaranteed delivery
|
||||
// Just verify we receive at least some messages without panicking or deadlocking
|
||||
received := 0
|
||||
timeout := time.After(2 * time.Second)
|
||||
timeout := time.After(500 * time.Millisecond)
|
||||
for received < messageCount {
|
||||
select {
|
||||
case <-notifications:
|
||||
@@ -751,7 +752,7 @@ func Test_ConcurrentCleanup(t *testing.T) {
|
||||
server := newTestServerWithNotifier(t)
|
||||
|
||||
// Create some clients
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
client, _ := server.GetClient("", "")
|
||||
// Set some to be old
|
||||
if i%2 == 0 {
|
||||
@@ -790,39 +791,34 @@ func Test_NoRaceConditions(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Create a few clients and read from them
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 5 {
|
||||
wg.Go(func() {
|
||||
client, _ := server.GetClient("", "")
|
||||
notifications, stop := client.Listen()
|
||||
defer close(stop)
|
||||
|
||||
// Actively read messages
|
||||
timeout := time.After(2 * time.Second)
|
||||
timeout := time.After(200 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Keep reading
|
||||
// Keep reading
|
||||
case <-timeout:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Send a few notifications
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 20; j++ {
|
||||
wg.Go(func() {
|
||||
for range 10 {
|
||||
server.NotifyAll(notify.Notification{
|
||||
Message: "Stress test",
|
||||
})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -948,7 +944,7 @@ func Test_ListenSignature(t *testing.T) {
|
||||
require.NotNil(t, stop)
|
||||
|
||||
// notifications should be receive-only
|
||||
_, ok := interface{}(notifications).(<-chan notify.Notification)
|
||||
_, ok := any(notifications).(<-chan notify.Notification)
|
||||
require.True(t, ok, "notifications should be receive-only channel")
|
||||
|
||||
// stop should be closeable
|
||||
@@ -964,7 +960,7 @@ func Test_BufferSize(t *testing.T) {
|
||||
defer close(stop)
|
||||
|
||||
// Send 10 messages without reading (buffer size is 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
server.NotifySub(notify.Notification{
|
||||
Target: client.sub.ID,
|
||||
Message: "Buffered",
|
||||
@@ -975,7 +971,7 @@ func Test_BufferSize(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Read all 10
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := range 10 {
|
||||
select {
|
||||
case <-notifications:
|
||||
// Success
|
||||
|
||||
@@ -30,13 +30,13 @@ const (
|
||||
MethodPATCH Method = "PATCH"
|
||||
)
|
||||
|
||||
// Server.AddRoutes registers the page handlers for the server.
|
||||
// AddRoutes registers the page handlers for the server.
|
||||
// At least one route must be provided.
|
||||
// If any route patterns (path + method) are defined multiple times, the first
|
||||
// instance will be added and any additional conflicts will be discarded.
|
||||
func (server *Server) AddRoutes(routes ...Route) error {
|
||||
func (s *Server) AddRoutes(routes ...Route) error {
|
||||
if len(routes) == 0 {
|
||||
return errors.New("No routes provided")
|
||||
return errors.New("no routes provided")
|
||||
}
|
||||
patterns := []string{}
|
||||
mux := http.NewServeMux()
|
||||
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
||||
}
|
||||
for _, method := range route.Methods {
|
||||
if !validMethod(method) {
|
||||
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
|
||||
return fmt.Errorf("invalid method %s for path %s", method, route.Path)
|
||||
}
|
||||
if route.Handler == nil {
|
||||
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
|
||||
return fmt.Errorf("no handler provided for %s %s", method, route.Path)
|
||||
}
|
||||
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
||||
if slices.Contains(patterns, pattern) {
|
||||
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
||||
}
|
||||
}
|
||||
|
||||
server.server.Handler = mux
|
||||
server.routes = true
|
||||
s.server.Handler = mux
|
||||
s.routes = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
server := createTestServer(t, &buf)
|
||||
err := server.AddRoutes()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No routes provided")
|
||||
assert.Contains(t, err.Error(), "no routes provided")
|
||||
})
|
||||
|
||||
t.Run("Single valid route", func(t *testing.T) {
|
||||
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
Handler: handler,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid method")
|
||||
assert.Contains(t, err.Error(), "invalid method")
|
||||
})
|
||||
|
||||
t.Run("No handler provided", func(t *testing.T) {
|
||||
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
|
||||
Handler: nil,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No handler provided")
|
||||
assert.Contains(t, err.Error(), "no handler provided")
|
||||
})
|
||||
|
||||
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
||||
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
|
||||
Handler: handler,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid method")
|
||||
assert.Contains(t, err.Error(), "invalid method")
|
||||
})
|
||||
|
||||
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
||||
|
||||
@@ -26,14 +26,14 @@ type Server struct {
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the server is started
|
||||
func (server *Server) Ready() <-chan struct{} {
|
||||
return server.ready
|
||||
func (s *Server) Ready() <-chan struct{} {
|
||||
return s.ready
|
||||
}
|
||||
|
||||
// IsReady checks if the server is running
|
||||
func (server *Server) IsReady() bool {
|
||||
func (s *Server) IsReady() bool {
|
||||
select {
|
||||
case <-server.ready:
|
||||
case <-s.ready:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -41,13 +41,13 @@ func (server *Server) IsReady() bool {
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
func (server *Server) Addr() string {
|
||||
return server.server.Addr
|
||||
func (s *Server) Addr() string {
|
||||
return s.server.Addr
|
||||
}
|
||||
|
||||
// Handler returns the server's HTTP handler for testing purposes
|
||||
func (server *Server) Handler() http.Handler {
|
||||
return server.server.Handler
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return s.server.Handler
|
||||
}
|
||||
|
||||
// NewServer returns a new hws.Server with the specified configuration.
|
||||
@@ -75,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
|
||||
|
||||
valid := isValidHostname(config.Host)
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
|
||||
return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
@@ -95,62 +95,64 @@ func NewServer(config *Config) (*Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (server *Server) Start(ctx context.Context) error {
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
if !server.routes {
|
||||
if !s.routes {
|
||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||
}
|
||||
if !server.middleware {
|
||||
err := server.AddMiddleware()
|
||||
if !s.middleware {
|
||||
err := s.AddMiddleware()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "server.AddMiddleware")
|
||||
}
|
||||
}
|
||||
|
||||
server.startNotifier()
|
||||
s.startNotifier()
|
||||
|
||||
go func() {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
||||
if s.logger == nil {
|
||||
fmt.Printf("Listening for requests on %s", s.server.Addr)
|
||||
} else {
|
||||
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
||||
s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
|
||||
}
|
||||
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
if server.logger == nil {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
if s.logger == nil {
|
||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||
} else {
|
||||
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||
s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
server.waitUntilReady(ctx)
|
||||
s.waitUntilReady(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) Shutdown(ctx context.Context) error {
|
||||
server.logger.logger.Debug().Dur("shutdown_delay", server.shutdowndelay).Msg("HWS Server shutting down")
|
||||
server.NotifyAll(notify.Notification{
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if s.logger != nil {
|
||||
s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
|
||||
}
|
||||
s.NotifyAll(notify.Notification{
|
||||
Title: "Shutting down",
|
||||
Message: fmt.Sprintf("Server is shutting down in %v", server.shutdowndelay),
|
||||
Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
|
||||
Level: LevelShutdown,
|
||||
})
|
||||
<-time.NewTimer(server.shutdowndelay).C
|
||||
if !server.IsReady() {
|
||||
<-time.NewTimer(s.shutdowndelay).C
|
||||
if !s.IsReady() {
|
||||
return errors.New("Server isn't running")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("Context cannot be nil")
|
||||
}
|
||||
err := server.server.Shutdown(ctx)
|
||||
err := s.server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
||||
}
|
||||
server.closeNotifier()
|
||||
server.ready = make(chan struct{})
|
||||
s.closeNotifier()
|
||||
s.ready = make(chan struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -168,7 +170,7 @@ func isValidHostname(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (server *Server) waitUntilReady(ctx context.Context) error {
|
||||
func (s *Server) waitUntilReady(ctx context.Context) error {
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -180,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
|
||||
return ctx.Err()
|
||||
|
||||
case <-ticker.C:
|
||||
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
||||
resp, err := http.Get("http://" + s.server.Addr + "/healthz")
|
||||
if err != nil {
|
||||
continue // not accepting yet
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
closeOnce.Do(func() { close(server.ready) })
|
||||
closeOnce.Do(func() { close(s.ready) })
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,12 +26,17 @@ func randomPort() uint64 {
|
||||
|
||||
func createTestServer(t *testing.T, w io.Writer) *hws.Server {
|
||||
server, err := hws.NewServer(&hws.Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
Host: "127.0.0.1",
|
||||
Port: randomPort(),
|
||||
ShutdownDelay: 0, // No delay for tests
|
||||
})
|
||||
require.NoError(t, err)
|
||||
dbg, _ := hlog.LogLevel("debug")
|
||||
logcfg := &hlog.Config{
|
||||
LogLevel: dbg,
|
||||
}
|
||||
|
||||
logger, err := hlog.NewLogger(hlog.LogLevel("Debug"), w, nil, "")
|
||||
logger, err := hlog.NewLogger(logcfg, w)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddLogger(logger)
|
||||
@@ -227,5 +232,4 @@ func Test_NewServer(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -16,7 +17,7 @@ type Authenticator[T Model, TX DBTransaction] struct {
|
||||
tokenGenerator *jwt.TokenGenerator
|
||||
load LoadFunc[T, TX]
|
||||
beginTx BeginTX
|
||||
ignoredPaths []string
|
||||
ignoredPaths []glob.Glob
|
||||
logger *hlog.Logger
|
||||
server *hws.Server
|
||||
errorPage hws.ErrorPageFunc
|
||||
|
||||
@@ -9,16 +9,16 @@ import (
|
||||
// Config holds the configuration settings for the authenticator.
|
||||
// All time-based settings are in minutes.
|
||||
type Config struct {
|
||||
SSL bool // ENV HWSAUTH_SSL: Enable SSL secure cookies (default: false)
|
||||
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address for SSL (required if SSL is true)
|
||||
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing JWT tokens (required)
|
||||
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5)
|
||||
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440)
|
||||
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Token fresh time in minutes (default: 5)
|
||||
LandingPage string // ENV HWSAUTH_LANDING_PAGE: Redirect destination for authenticated users (default: "/profile")
|
||||
DatabaseType string // ENV HWSAUTH_DATABASE_TYPE: Database type (postgres, mysql, sqlite, mariadb) (default: "postgres")
|
||||
DatabaseVersion string // ENV HWSAUTH_DATABASE_VERSION: Database version string (default: "15")
|
||||
JWTTableName string // ENV HWSAUTH_JWT_TABLE_NAME: Custom JWT blacklist table name (default: "jwtblacklist")
|
||||
SSL bool `ezconf:"HWSAUTH_SSL,description:Enable SSL secure cookies,default:false"`
|
||||
TrustedHost string `ezconf:"HWSAUTH_TRUSTED_HOST,description:Full server address for SSL,required:if SSL is true"`
|
||||
SecretKey string `ezconf:"HWSAUTH_SECRET_KEY,description:Secret key for signing JWT tokens,required"`
|
||||
AccessTokenExpiry int64 `ezconf:"HWSAUTH_ACCESS_TOKEN_EXPIRY,description:Access token expiry in minutes,default:5"`
|
||||
RefreshTokenExpiry int64 `ezconf:"HWSAUTH_REFRESH_TOKEN_EXPIRY,description:Refresh token expiry in minutes,default:1440"`
|
||||
TokenFreshTime int64 `ezconf:"HWSAUTH_TOKEN_FRESH_TIME,description:Token fresh time in minutes,default:5"`
|
||||
LandingPage string `ezconf:"HWSAUTH_LANDING_PAGE,description:Redirect destination for authenticated users,default:/profile"`
|
||||
DatabaseType string `ezconf:"HWSAUTH_DATABASE_TYPE,description:Database type (postgres mysql sqlite mariadb),default:postgres"`
|
||||
DatabaseVersion string `ezconf:"HWSAUTH_DATABASE_VERSION,description:Database version string,default:15"`
|
||||
JWTTableName string `ezconf:"HWSAUTH_JWT_TABLE_NAME,description:Custom JWT blacklist table name,default:jwtblacklist"`
|
||||
}
|
||||
|
||||
// ConfigFromEnv loads configuration from environment variables.
|
||||
|
||||
@@ -1,35 +1,9 @@
|
||||
package hwsauth
|
||||
|
||||
import "runtime"
|
||||
|
||||
// EZConfIntegration provides integration with ezconf for automatic configuration
|
||||
type EZConfIntegration struct{}
|
||||
|
||||
// PackagePath returns the path to the hwsauth package for source parsing
|
||||
func (e EZConfIntegration) PackagePath() string {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
// Return directory of this file
|
||||
return filename[:len(filename)-len("/ezconf.go")]
|
||||
}
|
||||
|
||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
return ConfigFromEnv()
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name to use when registering with ezconf
|
||||
func (e EZConfIntegration) Name() string {
|
||||
return "hwsauth"
|
||||
}
|
||||
|
||||
// GroupName returns the display name for grouping environment variables
|
||||
func (e EZConfIntegration) GroupName() string {
|
||||
return "HWSAuth"
|
||||
}
|
||||
import "git.haelnorr.com/h/golib/ezconf"
|
||||
|
||||
// NewEZConfIntegration creates a new EZConf integration helper
|
||||
func NewEZConfIntegration() EZConfIntegration {
|
||||
return EZConfIntegration{}
|
||||
func NewEZConfIntegration() *ezconf.Integration {
|
||||
return ezconf.NewIntegration("hwsauth", "HWSAuth", &Config{},
|
||||
func() (any, error) { return ConfigFromEnv() })
|
||||
}
|
||||
|
||||
@@ -5,24 +5,28 @@ go 1.25.5
|
||||
require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||
git.haelnorr.com/h/golib/hws v0.3.0
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0
|
||||
git.haelnorr.com/h/golib/hws v0.5.0
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/gobwas/glob v0.2.3
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/apimachinery v0.35.0 // indirect
|
||||
k8s.io/klog/v2 v2.130.1 // indirect
|
||||
|
||||
@@ -2,12 +2,16 @@ git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDO
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
|
||||
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
|
||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
|
||||
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1 h1:axMyKtgO9Zk6E8CrYrLpMzifvpjz73yxCQq0lOtuhck=
|
||||
git.haelnorr.com/h/golib/ezconf v0.2.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0 h1:tCT8HWs51Nbin58sCTLcq5re6CZqo5/IHCzk3G+S3vQ=
|
||||
git.haelnorr.com/h/golib/hlog v0.11.0/go.mod h1:HjhXS5G3A0BwOZq7nu2qpNBtvOFiCa1GbAuBRxAkYqs=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
@@ -15,6 +19,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
@@ -40,8 +46,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
|
||||
return tm.ID
|
||||
}
|
||||
|
||||
type TestTransaction struct {
|
||||
}
|
||||
type TestTransaction struct{}
|
||||
|
||||
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
|
||||
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
||||
// Clear environment variables
|
||||
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
||||
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer func() {
|
||||
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
}()
|
||||
|
||||
_, err := ConfigFromEnv()
|
||||
assert.Error(t, err)
|
||||
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
|
||||
db, _, err := createMockDB()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
// This test mainly checks that the function doesn't panic and has right call signature
|
||||
// The actual JWT functionality is tested in jwt package itself
|
||||
assert.NotPanics(t, func() {
|
||||
auth.Login(w, r, user, rememberMe)
|
||||
err := auth.Login(w, r, user, rememberMe)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package hwsauth
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
)
|
||||
|
||||
// IgnorePaths excludes specified paths from authentication middleware.
|
||||
@@ -22,9 +24,22 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
return fmt.Errorf("invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
auth.ignoredPaths = paths
|
||||
auth.ignoredPaths = prepareGlobs(paths)
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareGlobs(paths []string) []glob.Glob {
|
||||
compiledGlobs := make([]glob.Glob, 0, len(paths))
|
||||
for _, pattern := range paths {
|
||||
g, err := glob.Compile(pattern)
|
||||
if err != nil {
|
||||
// If pattern fails to compile, skip it
|
||||
continue
|
||||
}
|
||||
compiledGlobs = append(compiledGlobs, g)
|
||||
}
|
||||
return compiledGlobs
|
||||
}
|
||||
|
||||
@@ -33,13 +33,17 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auth.getTokens")
|
||||
}
|
||||
err = aT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
if aT != nil {
|
||||
err = aT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
}
|
||||
err = rT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
if rT != nil {
|
||||
err = rT.Revoke(jwt.DBTransaction(tx))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
}
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
|
||||
@@ -2,10 +2,12 @@ package hwsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Authenticate returns the main authentication middleware.
|
||||
@@ -14,14 +16,22 @@ import (
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// server.AddMiddleware(auth.Authenticate())
|
||||
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware {
|
||||
return auth.server.NewMiddleware(auth.authenticate())
|
||||
// server.AddMiddleware(auth.Authenticate(nil))
|
||||
//
|
||||
// If extraCheck is provided, it will run just before the user is added to the context,
|
||||
// and the return will determine if the user will be added, or the request passed on
|
||||
// without the user.
|
||||
func (auth *Authenticator[T, TX]) Authenticate(
|
||||
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||
) hws.Middleware {
|
||||
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
|
||||
}
|
||||
|
||||
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
func (auth *Authenticator[T, TX]) authenticate(
|
||||
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||
) hws.MiddlewareFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
||||
if slices.Contains(auth.ignoredPaths, r.URL.Path) {
|
||||
if globTest(r.URL.Path, auth.ignoredPaths) {
|
||||
return r, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
@@ -30,25 +40,70 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
||||
// Start the transaction
|
||||
tx, err := auth.beginTx(ctx)
|
||||
if err != nil {
|
||||
return nil, &hws.HWSError{Message: "Unable to start transaction", StatusCode: http.StatusServiceUnavailable, Error: err}
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Unable to start transaction",
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Error: errors.Wrap(err, "auth.beginTx"),
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
// Type assert to TX - safe because user's beginTx should return their TX type
|
||||
txTyped, ok := tx.(TX)
|
||||
if !ok {
|
||||
return nil, &hws.HWSError{Message: "Transaction type mismatch", StatusCode: http.StatusInternalServerError, Error: err}
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Transaction type mismatch",
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: errors.Wrap(err, "TX type not ok"),
|
||||
}
|
||||
}
|
||||
model, err := auth.getAuthenticatedUser(txTyped, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
rberr := tx.Rollback()
|
||||
if rberr != nil {
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Failed rolling back after error",
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: errors.Wrap(err, "tx.Rollback"),
|
||||
}
|
||||
}
|
||||
auth.logger.Debug().
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Err(err).
|
||||
Msg("Failed to authenticate user")
|
||||
return r, nil
|
||||
}
|
||||
tx.Commit()
|
||||
var check bool
|
||||
if extraCheck != nil {
|
||||
var err *hws.HWSError
|
||||
check, err = extraCheck(ctx, model.model, txTyped, w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, &hws.HWSError{
|
||||
Message: "Failed to commit transaction",
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: errors.Wrap(err, "tx.Commit"),
|
||||
}
|
||||
}
|
||||
authContext := setAuthenticatedModel(r.Context(), model)
|
||||
newReq := r.WithContext(authContext)
|
||||
return newReq, nil
|
||||
if extraCheck == nil || check {
|
||||
return newReq, nil
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
|
||||
func globTest(testPath string, globs []glob.Glob) bool {
|
||||
for _, g := range globs {
|
||||
if g.Match(testPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
|
||||
// }
|
||||
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "hwsauth context key" + string(c)
|
||||
}
|
||||
|
||||
var authenticatedModelContextKey = contextKey("authenticated-model")
|
||||
|
||||
// Return a new context with the user added in
|
||||
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
|
||||
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
||||
return context.WithValue(ctx, authenticatedModelContextKey, m)
|
||||
}
|
||||
|
||||
// Retrieve a user from the given context. Returns nil if not set
|
||||
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
|
||||
model = authenticatedModel[T]{}
|
||||
}
|
||||
}()
|
||||
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
|
||||
model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
|
||||
if !cok {
|
||||
return authenticatedModel[T]{}, false
|
||||
}
|
||||
|
||||
@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model, ok := getAuthorizedModel[T](r.Context())
|
||||
if !ok {
|
||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
||||
auth.server.ThrowError(w, r, hws.HWSError{
|
||||
Error: errors.New("Login required"),
|
||||
Message: "Please login to view this page",
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
auth.server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||
|
||||
@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
|
||||
rememberMe := map[string]bool{
|
||||
"session": false,
|
||||
"exp": true,
|
||||
}[aT.TTL]
|
||||
}[rT.TTL]
|
||||
// issue new tokens for the user
|
||||
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
||||
if err != nil {
|
||||
@@ -55,13 +55,20 @@ func (auth *Authenticator[T, TX]) getTokens(
|
||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||
// get the existing tokens from the cookies
|
||||
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||
var aT *jwt.AccessToken
|
||||
var rT *jwt.RefreshToken
|
||||
var err error
|
||||
if atStr != "" {
|
||||
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||
}
|
||||
}
|
||||
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||
if rtStr != "" {
|
||||
rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||
}
|
||||
}
|
||||
return aT, rT, nil
|
||||
}
|
||||
@@ -72,13 +79,17 @@ func revokeTokenPair(
|
||||
aT *jwt.AccessToken,
|
||||
rT *jwt.RefreshToken,
|
||||
) error {
|
||||
err := aT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
if aT != nil {
|
||||
err := aT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
}
|
||||
err = rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
if rT != nil {
|
||||
err := rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user