Compare commits
12 Commits
notify/v0.
...
hwsauth/v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 05be28d7f3 | |||
| 8f7c87cef2 | |||
| 525b3b1396 | |||
| 563908bbb4 | |||
| 95a17597cf | |||
| cd29f11296 | |||
| 7ed40c7afe | |||
| 596a4c0529 | |||
| ed3bc4afb0 | |||
| 2c9de70018 | |||
| 965721bd89 | |||
| 5781aa523c |
173
AGENTS.md
Normal file
173
AGENTS.md
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
# AGENTS.md - Coding Agent Guidelines for golib
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
This is a Go library repository containing multiple independent packages:
|
||||||
|
- **cookies**: HTTP cookie utilities
|
||||||
|
- **env**: Environment variable helpers
|
||||||
|
- **ezconf**: Configuration loader with ENV parsing
|
||||||
|
- **hlog**: Logging with zerolog
|
||||||
|
- **hws**: HTTP web server
|
||||||
|
- **hwsauth**: Authentication middleware for hws
|
||||||
|
- **jwt**: JWT token generation and validation
|
||||||
|
- **tmdb**: The Movie Database API client
|
||||||
|
|
||||||
|
Each package has its own `go.mod` and can be used independently.
|
||||||
|
|
||||||
|
## Interactive Questions
|
||||||
|
All questions in plan mode should use the opencode interactive question prompter for user interaction.
|
||||||
|
|
||||||
|
## Build, Test, and Lint Commands
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
```bash
|
||||||
|
# Test all packages from repo root
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
# Test a specific package
|
||||||
|
cd <package> && go test
|
||||||
|
|
||||||
|
# Run a single test function
|
||||||
|
cd <package> && go test -run TestFunctionName
|
||||||
|
|
||||||
|
# Run tests with verbose output
|
||||||
|
cd <package> && go test -v
|
||||||
|
|
||||||
|
# Run tests matching a pattern
|
||||||
|
cd <package> && go test -run "TestName.*"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Building
|
||||||
|
```bash
|
||||||
|
# Each package is a library - no build needed
|
||||||
|
# Verify code compiles:
|
||||||
|
go build ./...
|
||||||
|
|
||||||
|
# Or for specific package:
|
||||||
|
cd <package> && go build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Linting
|
||||||
|
```bash
|
||||||
|
# Use standard go tools
|
||||||
|
go vet ./...
|
||||||
|
go fmt ./...
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
gofmt -l .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Style Guidelines
|
||||||
|
|
||||||
|
### Package Structure
|
||||||
|
- Each package must have its own `go.mod` with module path: `git.haelnorr.com/h/golib/<package>`
|
||||||
|
- Go version should be current (1.23.4+)
|
||||||
|
- Each package should have a `doc.go` file with package documentation
|
||||||
|
|
||||||
|
### Imports
|
||||||
|
- Use standard library imports first
|
||||||
|
- Then third-party imports
|
||||||
|
- Then local imports from this repo (e.g., `git.haelnorr.com/h/golib/hlog`)
|
||||||
|
- Group imports with blank lines between groups
|
||||||
|
- Example:
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Formatting
|
||||||
|
- Use `gofmt` standard formatting
|
||||||
|
- No tabs for alignment, use spaces inside structs
|
||||||
|
- Line length: no hard limit, but prefer readability
|
||||||
|
|
||||||
|
### Types
|
||||||
|
- Use explicit types for struct fields
|
||||||
|
- Config structs must have ENV comments (see below)
|
||||||
|
- Prefer named return values for complex functions
|
||||||
|
- Use generics where appropriate (see `hwsauth.Authenticator[T Model, TX DBTransaction]`)
|
||||||
|
|
||||||
|
### Naming Conventions
|
||||||
|
- Packages: lowercase, single word (e.g., `cookies`, `ezconf`)
|
||||||
|
- Exported functions: PascalCase (e.g., `NewServer`, `ConfigFromEnv`)
|
||||||
|
- Unexported functions: camelCase (e.g., `isValidHostname`, `waitUntilReady`)
|
||||||
|
- Test functions: `Test<FunctionName>` or `Test<FunctionName>_<Case>` (underscore for sub-cases)
|
||||||
|
- Variables: camelCase, descriptive names
|
||||||
|
- Constants: PascalCase or UPPER_CASE depending on scope
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
- Use `github.com/pkg/errors` for error wrapping
|
||||||
|
- Wrap errors with context: `errors.Wrap(err, "context message")`
|
||||||
|
- Return errors, don't panic (except in truly exceptional cases)
|
||||||
|
- Validate inputs and return descriptive errors
|
||||||
|
- Example:
|
||||||
|
```go
|
||||||
|
if config == nil {
|
||||||
|
return nil, errors.New("Config cannot be nil")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Pattern
|
||||||
|
- Each package with config should have a `Config` struct
|
||||||
|
- Provide `ConfigFromEnv() (*Config, error)` function
|
||||||
|
- ENV comment format for Config struct fields:
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
|
||||||
|
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
|
||||||
|
SSL bool // ENV HWS_SSL: Enable SSL (required when using production)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Format: `// ENV ENV_NAME: Description (required <condition>) (default: <value>)`
|
||||||
|
- Include "required" only if no default
|
||||||
|
- Include "default" only if one exists
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
- Use `testing` package from standard library
|
||||||
|
- Use `github.com/stretchr/testify` for assertions (`require`, `assert`)
|
||||||
|
- Table-driven tests for multiple cases:
|
||||||
|
```go
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid case", "input", false},
|
||||||
|
{"error case", "", true},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Test files use `<package>_test` for black-box tests or `<package>` for white-box
|
||||||
|
- Helper functions should use `t.Helper()`
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- All exported functions, types, and constants must have godoc comments
|
||||||
|
- Comments should start with the name being documented
|
||||||
|
- Example: `// NewServer returns a new hws.Server with the specified configuration.`
|
||||||
|
- Keep doc.go files up to date with package overview
|
||||||
|
- Follow RULES.md for README and wiki documentation
|
||||||
|
|
||||||
|
## Version Control (from RULES.md)
|
||||||
|
- Do NOT make changes to master branch
|
||||||
|
- Checkout a branch for new features
|
||||||
|
- Version numbers use git tags - do NOT change manually
|
||||||
|
- When updating docs, append branch name to version
|
||||||
|
- Changes to golib-wiki repo should use same branch name
|
||||||
|
|
||||||
|
## Testing Requirements (from RULES.md)
|
||||||
|
- All features MUST have tests
|
||||||
|
- Update existing tests when modifying features
|
||||||
|
- New features require new tests
|
||||||
|
|
||||||
|
## Documentation Requirements (from RULES.md)
|
||||||
|
- Document via: docstrings, README.md, doc.go, wiki
|
||||||
|
- README structure: Title+version, Features (NO EMOTICONS), Installation, Quick Start, Docs links, Additional info, License, Contributing, Related projects
|
||||||
|
- Wiki location: `~/projects/golib-wiki`
|
||||||
|
- Docstrings must conform to godoc standards
|
||||||
|
|
||||||
|
## License
|
||||||
|
- All modules use MIT License
|
||||||
3
RULES.md
3
RULES.md
@@ -45,3 +45,6 @@ Do not make any changes to master. Checkout a branch to work on new features
|
|||||||
Version numbers are specified using git tags.
|
Version numbers are specified using git tags.
|
||||||
Do not change version numbers. When updating documentation, append the branch name to the version number.
|
Do not change version numbers. When updating documentation, append the branch name to the version number.
|
||||||
Changes made to the golib-wiki repo should be made under the same branch name as the changes made in this repo
|
Changes made to the golib-wiki repo should be made under the same branch name as the changes made in this repo
|
||||||
|
|
||||||
|
4. Licencing
|
||||||
|
All modules should have an MIT License
|
||||||
|
|||||||
21
cookies/LICENSE
Normal file
21
cookies/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 haelnorr
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
61
cookies/README.md
Normal file
61
cookies/README.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# cookies v1.0.0
|
||||||
|
|
||||||
|
HTTP cookie utilities for Go web applications with security best practices.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Secure cookie setting with HttpOnly flag
|
||||||
|
- Cookie deletion with proper expiration
|
||||||
|
- Pagefrom tracking for post-login redirects
|
||||||
|
- Host validation for referer-based redirects
|
||||||
|
- Full test coverage
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get git.haelnorr.com/h/golib/cookies
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set a secure cookie
|
||||||
|
cookies.SetCookie(w, "session", "/", "abc123", 3600)
|
||||||
|
|
||||||
|
// Delete a cookie
|
||||||
|
cookies.DeleteCookie(w, "old_session", "/")
|
||||||
|
|
||||||
|
// Handle pagefrom for redirects
|
||||||
|
if r.URL.Path == "/login" {
|
||||||
|
cookies.SetPageFrom(w, r, "example.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check pagefrom after login
|
||||||
|
redirectTo := cookies.CheckPageFrom(w, r)
|
||||||
|
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
See the [wiki documentation](../golib/wiki/cookies.md) for detailed usage information and examples.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Please see the main golib repository for contributing guidelines.
|
||||||
|
|
||||||
|
## Related Projects
|
||||||
|
|
||||||
|
This package is part of the golib collection of utilities for Go applications and integrates well with other golib packages.
|
||||||
405
cookies/cookies_test.go
Normal file
405
cookies/cookies_test.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package cookies
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetCookie(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cookie string
|
||||||
|
path string
|
||||||
|
value string
|
||||||
|
maxAge int
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic cookie",
|
||||||
|
cookie: "test",
|
||||||
|
path: "/",
|
||||||
|
value: "value",
|
||||||
|
maxAge: 3600,
|
||||||
|
expected: "test=value; Path=/; Max-Age=3600; HttpOnly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero max age",
|
||||||
|
cookie: "session",
|
||||||
|
path: "/api",
|
||||||
|
value: "abc123",
|
||||||
|
maxAge: 0,
|
||||||
|
expected: "session=abc123; Path=/api; HttpOnly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative max age",
|
||||||
|
cookie: "temp",
|
||||||
|
path: "/",
|
||||||
|
value: "temp",
|
||||||
|
maxAge: -1,
|
||||||
|
expected: "temp=temp; Path=/; Max-Age=0; HttpOnly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty value",
|
||||||
|
cookie: "empty",
|
||||||
|
path: "/",
|
||||||
|
value: "",
|
||||||
|
maxAge: 3600,
|
||||||
|
expected: "empty=; Path=/; Max-Age=3600; HttpOnly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters in value",
|
||||||
|
cookie: "data",
|
||||||
|
path: "/",
|
||||||
|
value: "test@123!#$%",
|
||||||
|
maxAge: 7200,
|
||||||
|
expected: "data=test@123!#$%; Path=/; Max-Age=7200; HttpOnly",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
SetCookie(w, tt.cookie, tt.path, tt.value, tt.maxAge)
|
||||||
|
|
||||||
|
headers := w.Header()["Set-Cookie"]
|
||||||
|
if len(headers) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the cookie header to check individual components
|
||||||
|
cookieHeader := headers[0]
|
||||||
|
|
||||||
|
// Check that all expected components are present
|
||||||
|
if !strings.Contains(cookieHeader, tt.cookie+"="+tt.value) {
|
||||||
|
t.Errorf("Expected cookie name/value not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "Path="+tt.path) {
|
||||||
|
t.Errorf("Expected path not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "HttpOnly") {
|
||||||
|
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if tt.maxAge != 0 {
|
||||||
|
expectedMaxAge := fmt.Sprintf("Max-Age=%d", tt.maxAge)
|
||||||
|
if tt.maxAge < 0 {
|
||||||
|
expectedMaxAge = "Max-Age=0" // Go normalizes negative Max-Age to 0
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, expectedMaxAge) {
|
||||||
|
t.Errorf("Expected Max-Age not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCookie(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cookie string
|
||||||
|
path string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic deletion",
|
||||||
|
cookie: "test",
|
||||||
|
path: "/",
|
||||||
|
expected: "test=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with specific path",
|
||||||
|
cookie: "session",
|
||||||
|
path: "/api",
|
||||||
|
expected: "session=; Path=/api; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
DeleteCookie(w, tt.cookie, tt.path)
|
||||||
|
|
||||||
|
headers := w.Header()["Set-Cookie"]
|
||||||
|
if len(headers) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieHeader := headers[0]
|
||||||
|
|
||||||
|
// Check deletion-specific components
|
||||||
|
if !strings.Contains(cookieHeader, tt.cookie+"=") {
|
||||||
|
t.Errorf("Expected cookie name not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "Path="+tt.path) {
|
||||||
|
t.Errorf("Expected path not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "Max-Age=0") {
|
||||||
|
t.Errorf("Expected Max-Age=0 not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "Expires=") {
|
||||||
|
t.Errorf("Expected Expires not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "HttpOnly") {
|
||||||
|
t.Errorf("Expected HttpOnly not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPageFrom(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cookieValue string
|
||||||
|
cookiePath string
|
||||||
|
expectedResult string
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid pagefrom cookie",
|
||||||
|
cookieValue: "/dashboard",
|
||||||
|
cookiePath: "/",
|
||||||
|
expectedResult: "/dashboard",
|
||||||
|
shouldSet: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no pagefrom cookie",
|
||||||
|
cookieValue: "",
|
||||||
|
cookiePath: "",
|
||||||
|
expectedResult: "/",
|
||||||
|
shouldSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty pagefrom cookie",
|
||||||
|
cookieValue: "",
|
||||||
|
cookiePath: "/",
|
||||||
|
expectedResult: "",
|
||||||
|
shouldSet: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pagefrom with query params",
|
||||||
|
cookieValue: "/search?q=test",
|
||||||
|
cookiePath: "/",
|
||||||
|
expectedResult: "/search?q=test",
|
||||||
|
shouldSet: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pagefrom with special path",
|
||||||
|
cookieValue: "/api/v1/users",
|
||||||
|
cookiePath: "/api",
|
||||||
|
expectedResult: "/api/v1/users",
|
||||||
|
shouldSet: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.shouldSet {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: "pagefrom",
|
||||||
|
Value: tt.cookieValue,
|
||||||
|
Path: tt.cookiePath,
|
||||||
|
}
|
||||||
|
r.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := CheckPageFrom(w, r)
|
||||||
|
|
||||||
|
if result != tt.expectedResult {
|
||||||
|
t.Errorf("CheckPageFrom() = %v, want %v", result, tt.expectedResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the cookie was deleted
|
||||||
|
if tt.shouldSet {
|
||||||
|
headers := w.Header()["Set-Cookie"]
|
||||||
|
if len(headers) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookieHeader := headers[0]
|
||||||
|
if !strings.Contains(cookieHeader, "pagefrom=") {
|
||||||
|
t.Errorf("Expected pagefrom cookie deletion not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
if !strings.Contains(cookieHeader, "Max-Age=0") {
|
||||||
|
t.Errorf("Expected Max-Age=0 for deletion not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPageFrom(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
referer string
|
||||||
|
trustedHost string
|
||||||
|
expectedSet bool
|
||||||
|
expectedValue string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid trusted host referer",
|
||||||
|
referer: "http://example.com/dashboard",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/dashboard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid trusted host with https",
|
||||||
|
referer: "https://example.com/profile",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/profile",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "untrusted host",
|
||||||
|
referer: "http://evil.com/dashboard",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty path",
|
||||||
|
referer: "http://example.com",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "login path - should not set",
|
||||||
|
referer: "http://example.com/login",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: false,
|
||||||
|
expectedValue: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "register path - should not set",
|
||||||
|
referer: "http://example.com/register",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: false,
|
||||||
|
expectedValue: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid referer URL",
|
||||||
|
referer: "not-a-url",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty referer",
|
||||||
|
referer: "",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root path",
|
||||||
|
referer: "http://example.com/",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path with query string",
|
||||||
|
referer: "http://example.com/search?q=test",
|
||||||
|
trustedHost: "example.com",
|
||||||
|
expectedSet: true,
|
||||||
|
expectedValue: "/search",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
if tt.referer != "" {
|
||||||
|
r.Header.Set("Referer", tt.referer)
|
||||||
|
}
|
||||||
|
|
||||||
|
SetPageFrom(w, r, tt.trustedHost)
|
||||||
|
|
||||||
|
headers := w.Header()["Set-Cookie"]
|
||||||
|
if tt.expectedSet {
|
||||||
|
if len(headers) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookieHeader := headers[0]
|
||||||
|
if !strings.Contains(cookieHeader, "pagefrom="+tt.expectedValue) {
|
||||||
|
t.Errorf("Expected pagefrom=%s not found in: %s", tt.expectedValue, cookieHeader)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(headers) != 0 {
|
||||||
|
t.Errorf("Expected no Set-Cookie header, got %d", len(headers))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration(t *testing.T) {
|
||||||
|
// Test the complete flow: SetPageFrom -> CheckPageFrom
|
||||||
|
t.Run("complete flow", func(t *testing.T) {
|
||||||
|
// Step 1: Set pagefrom cookie
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
r1 := &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
r1.Header.Set("Referer", "http://example.com/dashboard")
|
||||||
|
|
||||||
|
SetPageFrom(w1, r1, "example.com")
|
||||||
|
|
||||||
|
// Extract the cookie from the response
|
||||||
|
headers1 := w1.Header()["Set-Cookie"]
|
||||||
|
if len(headers1) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header, got %d", len(headers1))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the cookie was set correctly
|
||||||
|
cookieHeader := headers1[0]
|
||||||
|
if !strings.Contains(cookieHeader, "pagefrom=/dashboard") {
|
||||||
|
t.Errorf("Expected pagefrom=/dashboard not found in: %s", cookieHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Check pagefrom cookie (should delete it)
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
r2 := &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
r2.AddCookie(&http.Cookie{
|
||||||
|
Name: "pagefrom",
|
||||||
|
Value: "/dashboard",
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
|
||||||
|
result := CheckPageFrom(w2, r2)
|
||||||
|
|
||||||
|
if result != "/dashboard" {
|
||||||
|
t.Errorf("Expected result /dashboard, got %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the cookie was deleted
|
||||||
|
headers2 := w2.Header()["Set-Cookie"]
|
||||||
|
if len(headers2) != 1 {
|
||||||
|
t.Errorf("Expected 1 Set-Cookie header for deletion, got %d", len(headers2))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieHeader2 := headers2[0]
|
||||||
|
// Check for deletion indicators (Max-Age=0 with Expires in the past)
|
||||||
|
if !(strings.Contains(cookieHeader2, "Max-Age=0") && strings.Contains(cookieHeader2, "Expires=Thu, 01 Jan 1970")) {
|
||||||
|
t.Errorf("Expected cookie deletion, got: %s", cookieHeader2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
26
cookies/doc.go
Normal file
26
cookies/doc.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// Package cookies provides utilities for handling HTTP cookies in Go web applications.
|
||||||
|
// It includes functions for setting secure cookies, deleting cookies, and managing
|
||||||
|
// pagefrom tracking for post-login redirects.
|
||||||
|
//
|
||||||
|
// The package follows security best practices by setting the HttpOnly flag on all
|
||||||
|
// cookies to prevent XSS attacks. The SetCookie function allows you to specify the
|
||||||
|
// name, path, value, and max-age for cookies.
|
||||||
|
//
|
||||||
|
// The pagefrom functionality helps with user experience by remembering where a user
|
||||||
|
// was before being redirected to login/register pages, then redirecting them back
|
||||||
|
// after successful authentication.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// // Set a session cookie
|
||||||
|
// cookies.SetCookie(w, "session", "/", "abc123", 3600)
|
||||||
|
//
|
||||||
|
// // Delete a cookie
|
||||||
|
// cookies.DeleteCookie(w, "old_session", "/")
|
||||||
|
//
|
||||||
|
// // Handle pagefrom tracking
|
||||||
|
// cookies.SetPageFrom(w, r, "example.com")
|
||||||
|
// redirectTo := cookies.CheckPageFrom(w, r)
|
||||||
|
//
|
||||||
|
// All functions are designed to be safe and handle edge cases gracefully.
|
||||||
|
package cookies
|
||||||
21
env/LICENSE
vendored
Normal file
21
env/LICENSE
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 haelnorr
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
67
env/README.md
vendored
Normal file
67
env/README.md
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# env v1.0.0
|
||||||
|
|
||||||
|
Environment variable utilities for Go applications with type safety and default values.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Type-safe environment variable parsing
|
||||||
|
- Support for all basic Go types (string, int variants, uint variants, bool, time.Duration)
|
||||||
|
- Graceful fallback to default values
|
||||||
|
- Comprehensive boolean parsing with multiple truthy/falsy values
|
||||||
|
- Full test coverage
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get git.haelnorr.com/h/golib/env
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// String values
|
||||||
|
host := env.String("HOST", "localhost")
|
||||||
|
|
||||||
|
// Integer values (all sizes supported)
|
||||||
|
port := env.Int("PORT", 8080)
|
||||||
|
timeout := env.Int64("TIMEOUT_SECONDS", 30)
|
||||||
|
|
||||||
|
// Unsigned integer values
|
||||||
|
maxConnections := env.UInt("MAX_CONNECTIONS", 100)
|
||||||
|
|
||||||
|
// Boolean values (supports many formats)
|
||||||
|
debug := env.Bool("DEBUG", false)
|
||||||
|
|
||||||
|
// Duration values
|
||||||
|
requestTimeout := env.Duration("REQUEST_TIMEOUT", 30*time.Second)
|
||||||
|
|
||||||
|
fmt.Printf("Server: %s:%d\n", host, port)
|
||||||
|
fmt.Printf("Debug: %v\n", debug)
|
||||||
|
fmt.Printf("Timeout: %v\n", requestTimeout)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
See the [wiki documentation](../golib/wiki/env.md) for detailed usage information and examples.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Please see the main golib repository for contributing guidelines.
|
||||||
|
|
||||||
|
## Related Projects
|
||||||
|
|
||||||
|
This package is part of the golib collection of utilities for Go applications.
|
||||||
18
env/doc.go
vendored
Normal file
18
env/doc.go
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
// Package env provides utilities for reading environment variables with type safety
|
||||||
|
// and default values. It supports common Go types including strings, integers (all sizes),
|
||||||
|
// unsigned integers (all sizes), booleans, and time.Duration values.
|
||||||
|
//
|
||||||
|
// The package follows a simple pattern where each function takes a key name and a
|
||||||
|
// default value, returning the parsed environment variable or the default if the
|
||||||
|
// variable is not set or cannot be parsed.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// port := env.Int("PORT", 8080)
|
||||||
|
// debug := env.Bool("DEBUG", false)
|
||||||
|
// timeout := env.Duration("TIMEOUT", 30*time.Second)
|
||||||
|
//
|
||||||
|
// All functions gracefully handle missing environment variables by returning the
|
||||||
|
// provided default value. They also handle parsing errors by falling back to the
|
||||||
|
// default value.
|
||||||
|
package env
|
||||||
@@ -13,6 +13,7 @@ type Config struct {
|
|||||||
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
|
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
|
||||||
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
|
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
|
||||||
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
|
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
|
||||||
|
ShutdownDelay time.Duration // ENV HWS_SHUTDOWN_DELAY: Delay in seconds before server shutsdown when Shutdown is called (default: 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigFromEnv returns a Config struct loaded from the environment variables
|
// ConfigFromEnv returns a Config struct loaded from the environment variables
|
||||||
@@ -24,6 +25,7 @@ func ConfigFromEnv() (*Config, error) {
|
|||||||
ReadHeaderTimeout: time.Duration(env.Int("HWS_READ_HEADER_TIMEOUT", 2)) * time.Second,
|
ReadHeaderTimeout: time.Duration(env.Int("HWS_READ_HEADER_TIMEOUT", 2)) * time.Second,
|
||||||
WriteTimeout: time.Duration(env.Int("HWS_WRITE_TIMEOUT", 10)) * time.Second,
|
WriteTimeout: time.Duration(env.Int("HWS_WRITE_TIMEOUT", 10)) * time.Second,
|
||||||
IdleTimeout: time.Duration(env.Int("HWS_IDLE_TIMEOUT", 120)) * time.Second,
|
IdleTimeout: time.Duration(env.Int("HWS_IDLE_TIMEOUT", 120)) * time.Second,
|
||||||
|
ShutdownDelay: time.Duration(env.Int("HWS_SHUTDOWN_DELAY", 5)) * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
func Test_ConfigFromEnv(t *testing.T) {
|
func Test_ConfigFromEnv(t *testing.T) {
|
||||||
t.Run("Default values when no env vars set", func(t *testing.T) {
|
t.Run("Default values when no env vars set", func(t *testing.T) {
|
||||||
// Clear any existing env vars
|
// Clear any existing env vars
|
||||||
os.Unsetenv("HWS_HOST")
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
os.Unsetenv("HWS_PORT")
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
os.Unsetenv("HWS_GZIP")
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -33,8 +33,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom host", func(t *testing.T) {
|
t.Run("Custom host", func(t *testing.T) {
|
||||||
os.Setenv("HWS_HOST", "192.168.1.1")
|
_ = os.Setenv("HWS_HOST", "192.168.1.1")
|
||||||
defer os.Unsetenv("HWS_HOST")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -42,8 +44,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom port", func(t *testing.T) {
|
t.Run("Custom port", func(t *testing.T) {
|
||||||
os.Setenv("HWS_PORT", "8080")
|
_ = os.Setenv("HWS_PORT", "8080")
|
||||||
defer os.Unsetenv("HWS_PORT")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -51,8 +55,10 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GZIP enabled", func(t *testing.T) {
|
t.Run("GZIP enabled", func(t *testing.T) {
|
||||||
os.Setenv("HWS_GZIP", "true")
|
_ = os.Setenv("HWS_GZIP", "true")
|
||||||
defer os.Unsetenv("HWS_GZIP")
|
defer func() {
|
||||||
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -60,12 +66,14 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Custom timeouts", func(t *testing.T) {
|
t.Run("Custom timeouts", func(t *testing.T) {
|
||||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "5")
|
||||||
os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
_ = os.Setenv("HWS_WRITE_TIMEOUT", "30")
|
||||||
os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
_ = os.Setenv("HWS_IDLE_TIMEOUT", "300")
|
||||||
defer os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
defer func() {
|
||||||
defer os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
defer os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -75,19 +83,19 @@ func Test_ConfigFromEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("All custom values", func(t *testing.T) {
|
t.Run("All custom values", func(t *testing.T) {
|
||||||
os.Setenv("HWS_HOST", "0.0.0.0")
|
_ = os.Setenv("HWS_HOST", "0.0.0.0")
|
||||||
os.Setenv("HWS_PORT", "9000")
|
_ = os.Setenv("HWS_PORT", "9000")
|
||||||
os.Setenv("HWS_GZIP", "true")
|
_ = os.Setenv("HWS_GZIP", "true")
|
||||||
os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
_ = os.Setenv("HWS_READ_HEADER_TIMEOUT", "3")
|
||||||
os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
_ = os.Setenv("HWS_WRITE_TIMEOUT", "15")
|
||||||
os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
_ = os.Setenv("HWS_IDLE_TIMEOUT", "180")
|
||||||
defer func() {
|
defer func() {
|
||||||
os.Unsetenv("HWS_HOST")
|
_ = os.Unsetenv("HWS_HOST")
|
||||||
os.Unsetenv("HWS_PORT")
|
_ = os.Unsetenv("HWS_PORT")
|
||||||
os.Unsetenv("HWS_GZIP")
|
_ = os.Unsetenv("HWS_GZIP")
|
||||||
os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
_ = os.Unsetenv("HWS_READ_HEADER_TIMEOUT")
|
||||||
os.Unsetenv("HWS_WRITE_TIMEOUT")
|
_ = os.Unsetenv("HWS_WRITE_TIMEOUT")
|
||||||
os.Unsetenv("HWS_IDLE_TIMEOUT")
|
_ = os.Unsetenv("HWS_IDLE_TIMEOUT")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
config, err := hws.ConfigFromEnv()
|
config, err := hws.ConfigFromEnv()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error to use with Server.ThrowError
|
// HWSError wraps an error with other information for use with HWS features
|
||||||
type HWSError struct {
|
type HWSError struct {
|
||||||
StatusCode int // HTTP Status code
|
StatusCode int // HTTP Status code
|
||||||
Message string // Error message
|
Message string // Error message
|
||||||
@@ -41,7 +41,7 @@ type ErrorPage interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddErrorPage registers a handler that returns an ErrorPage
|
// AddErrorPage registers a handler that returns an ErrorPage
|
||||||
func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
func (s *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
page, err := pageFunc(HWSError{StatusCode: http.StatusInternalServerError})
|
||||||
@@ -56,7 +56,7 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
|||||||
return errors.New("Render method of the error page did not write anything to the response writer")
|
return errors.New("Render method of the error page did not write anything to the response writer")
|
||||||
}
|
}
|
||||||
|
|
||||||
server.errorPage = pageFunc
|
s.errorPage = pageFunc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +64,19 @@ func (server *Server) AddErrorPage(pageFunc ErrorPageFunc) error {
|
|||||||
// the error with the level specified by the HWSError.
|
// the error with the level specified by the HWSError.
|
||||||
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
// If HWSError.RenderErrorPage is true, the error page will be rendered to the ResponseWriter
|
||||||
// and the request chain should be terminated.
|
// and the request chain should be terminated.
|
||||||
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
func (s *Server) ThrowError(w http.ResponseWriter, r *http.Request, error HWSError) {
|
||||||
|
err := s.throwError(w, r, error)
|
||||||
|
if err != nil {
|
||||||
|
s.LogError(error)
|
||||||
|
s.LogError(HWSError{
|
||||||
|
Message: "Error occured during throwError",
|
||||||
|
Error: errors.Wrap(err, "s.throwError"),
|
||||||
|
Level: ErrorERROR,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) throwError(w http.ResponseWriter, r *http.Request, error HWSError) error {
|
||||||
if error.StatusCode <= 0 {
|
if error.StatusCode <= 0 {
|
||||||
return errors.New("HWSError.StatusCode cannot be 0.")
|
return errors.New("HWSError.StatusCode cannot be 0.")
|
||||||
}
|
}
|
||||||
@@ -77,32 +89,27 @@ func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error H
|
|||||||
if r == nil {
|
if r == nil {
|
||||||
return errors.New("Request cannot be nil")
|
return errors.New("Request cannot be nil")
|
||||||
}
|
}
|
||||||
if !server.IsReady() {
|
if !s.IsReady() {
|
||||||
return errors.New("ThrowError called before server started")
|
return errors.New("ThrowError called before server started")
|
||||||
}
|
}
|
||||||
w.WriteHeader(error.StatusCode)
|
w.WriteHeader(error.StatusCode)
|
||||||
server.LogError(error)
|
s.LogError(error)
|
||||||
if server.errorPage == nil {
|
if s.errorPage == nil {
|
||||||
server.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "No error page provided", Error: nil, Level: ErrorDEBUG})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if error.RenderErrorPage {
|
if error.RenderErrorPage {
|
||||||
server.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "Error page rendering", Error: nil, Level: ErrorDEBUG})
|
||||||
errPage, err := server.errorPage(error)
|
errPage, err := s.errorPage(error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
server.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
s.LogError(HWSError{Message: "Failed to get a valid error page", Error: err})
|
||||||
}
|
}
|
||||||
err = errPage.Render(r.Context(), w)
|
err = errPage.Render(r.Context(), w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
server.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
s.LogError(HWSError{Message: "Failed to render error page", Error: err})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
server.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
s.LogError(HWSError{Message: "Error page specified not to render", Error: nil, Level: ErrorDEBUG})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) ThrowFatal(w http.ResponseWriter, err error) {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
server.LogFatal(err)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -14,22 +14,26 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type goodPage struct{}
|
type (
|
||||||
type badPage struct{}
|
goodPage struct{}
|
||||||
|
badPage struct{}
|
||||||
|
)
|
||||||
|
|
||||||
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
func goodRender(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return goodPage{}, nil
|
return goodPage{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
func badRender1(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return badPage{}, nil
|
return badPage{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
func badRender2(error hws.HWSError) (hws.ErrorPage, error) {
|
||||||
return nil, errors.New("I'm an error")
|
return nil, errors.New("I'm an error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
func (g goodPage) Render(ctx context.Context, w io.Writer) error {
|
||||||
w.Write([]byte("Test write to ResponseWriter"))
|
_, err := w.Write([]byte("Test write to ResponseWriter"))
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
func (b badPage) Render(ctx context.Context, w io.Writer) error {
|
||||||
@@ -85,40 +89,42 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
t.Run("Server not started", func(t *testing.T) {
|
t.Run("Server not started", func(t *testing.T) {
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
buf.Reset()
|
||||||
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "Error",
|
Message: "Error",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
// ThrowError logs errors internally when validation fails
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "ThrowError called before server started")
|
||||||
})
|
})
|
||||||
|
|
||||||
startTestServer(t, server)
|
startTestServer(t, server)
|
||||||
defer server.Shutdown(t.Context())
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
request *http.Request
|
request *http.Request
|
||||||
error hws.HWSError
|
error hws.HWSError
|
||||||
valid bool
|
expectLogItem string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "No HWSError.Status code",
|
name: "No HWSError.Status code",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{},
|
error: hws.HWSError{},
|
||||||
valid: false,
|
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Negative HWSError.Status code",
|
name: "Negative HWSError.Status code",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{StatusCode: -1},
|
error: hws.HWSError{StatusCode: -1},
|
||||||
valid: false,
|
expectLogItem: "HWSError.StatusCode cannot be 0",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No HWSError.Message",
|
name: "No HWSError.Message",
|
||||||
request: nil,
|
request: nil,
|
||||||
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
error: hws.HWSError{StatusCode: http.StatusInternalServerError},
|
||||||
valid: false,
|
expectLogItem: "HWSError.Message cannot be empty",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No HWSError.Error",
|
name: "No HWSError.Error",
|
||||||
@@ -127,7 +133,7 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
},
|
},
|
||||||
valid: false,
|
expectLogItem: "HWSError.Error cannot be nil",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "No request provided",
|
name: "No request provided",
|
||||||
@@ -137,7 +143,7 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
},
|
},
|
||||||
valid: false,
|
expectLogItem: "Request cannot be nil",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Valid",
|
name: "Valid",
|
||||||
@@ -147,106 +153,92 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
},
|
},
|
||||||
valid: true,
|
expectLogItem: "An error occured",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf.Reset()
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
err := server.ThrowError(rr, tt.request, tt.error)
|
server.ThrowError(rr, tt.request, tt.error)
|
||||||
if tt.valid {
|
// ThrowError no longer returns errors; check logs instead
|
||||||
assert.NoError(t, err)
|
output := buf.String()
|
||||||
} else {
|
assert.Contains(t, output, tt.expectLogItem)
|
||||||
t.Log(err)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
t.Run("Log level set correctly", func(t *testing.T) {
|
t.Run("Log level set correctly", func(t *testing.T) {
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
Level: hws.ErrorWARN,
|
Level: hws.ErrorWARN,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
_, err := buf.ReadString([]byte(" ")[0])
|
||||||
_, err = buf.ReadString([]byte(" ")[0])
|
require.NoError(t, err)
|
||||||
loglvl, err := buf.ReadString([]byte(" ")[0])
|
loglvl, err := buf.ReadString([]byte(" ")[0])
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if loglvl != "\x1b[33mWRN\x1b[0m " {
|
assert.Equal(t, "\x1b[33mWRN\x1b[0m ", loglvl, "Log level should be WRN for ErrorWARN")
|
||||||
err = errors.New("Log level not set correctly")
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
err = server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = buf.ReadString([]byte(" ")[0])
|
_, err = buf.ReadString([]byte(" ")[0])
|
||||||
|
require.NoError(t, err)
|
||||||
loglvl, err = buf.ReadString([]byte(" ")[0])
|
loglvl, err = buf.ReadString([]byte(" ")[0])
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if loglvl != "\x1b[31mERR\x1b[0m " {
|
assert.Equal(t, "\x1b[31mERR\x1b[0m ", loglvl, "Log level should be ERR when no level specified")
|
||||||
err = errors.New("Log level not set correctly")
|
|
||||||
}
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
t.Run("Error page doesnt render if no error page set", func(t *testing.T) {
|
||||||
// Must be run before adding the error page to the test server
|
// Must be run before adding the error page to the test server
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body != "" {
|
assert.Empty(t, body, "Error page should not render when no error page is set")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
t.Run("Error page renders", func(t *testing.T) {
|
t.Run("Error page renders", func(t *testing.T) {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
// Adding the error page will carry over to all future tests and cant be undone
|
// Adding the error page will carry over to all future tests and cant be undone
|
||||||
server.AddErrorPage(goodRender)
|
err := server.AddErrorPage(goodRender)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
require.NoError(t, err)
|
||||||
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body == "" {
|
assert.NotEmpty(t, body, "Error page should render when RenderErrorPage is true")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
t.Run("Error page doesnt render if no told to render", func(t *testing.T) {
|
t.Run("Error page doesnt render if not told to render", func(t *testing.T) {
|
||||||
// Error page already added to server
|
// Error page already added to server
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err := server.ThrowError(rr, req, hws.HWSError{
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if body != "" {
|
assert.Empty(t, body, "Error page should not render when RenderErrorPage is false")
|
||||||
assert.Error(t, nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
server.Shutdown(t.Context())
|
err := server.Shutdown(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("Doesn't error if no logger added to server", func(t *testing.T) {
|
t.Run("Doesn't panic if no logger added to server", func(t *testing.T) {
|
||||||
server, err := hws.NewServer(&hws.Config{
|
server, err := hws.NewServer(&hws.Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: randomPort(),
|
Port: randomPort(),
|
||||||
@@ -261,13 +253,18 @@ func Test_ThrowError(t *testing.T) {
|
|||||||
err = server.Start(t.Context())
|
err = server.Start(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-server.Ready()
|
<-server.Ready()
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
err = server.ThrowError(rr, req, hws.HWSError{
|
// Should not panic when no logger is present
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
server.ThrowError(rr, req, hws.HWSError{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Message: "An error occured",
|
Message: "An error occured",
|
||||||
Error: errors.New("Error"),
|
Error: errors.New("Error"),
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
}, "ThrowError should not panic when no logger is present")
|
||||||
|
err = server.Shutdown(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||||
return func() (interface{}, error) {
|
return func() (any, error) {
|
||||||
return ConfigFromEnv()
|
return ConfigFromEnv()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ go 1.25.5
|
|||||||
require (
|
require (
|
||||||
git.haelnorr.com/h/golib/env v0.9.1
|
git.haelnorr.com/h/golib/env v0.9.1
|
||||||
git.haelnorr.com/h/golib/hlog v0.9.0
|
git.haelnorr.com/h/golib/hlog v0.9.0
|
||||||
|
git.haelnorr.com/h/golib/notify v0.1.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
k8s.io/apimachinery v0.35.0
|
k8s.io/apimachinery v0.35.0
|
||||||
@@ -13,6 +14,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/gobwas/glob v0.2.3
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
|||||||
@@ -2,11 +2,15 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
|
|||||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||||
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
|
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
|
||||||
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
|
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
|
||||||
|
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||||
|
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||||
|
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
|||||||
@@ -6,14 +6,15 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/gobwas/glob"
|
||||||
)
|
)
|
||||||
|
|
||||||
type logger struct {
|
type logger struct {
|
||||||
logger *hlog.Logger
|
logger *hlog.Logger
|
||||||
ignoredPaths []string
|
ignoredPaths []glob.Glob
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add tests to make sure all the fields are correctly set
|
// LogError uses the attached logger to log a HWSError
|
||||||
func (s *Server) LogError(err HWSError) {
|
func (s *Server) LogError(err HWSError) {
|
||||||
if s.logger == nil {
|
if s.logger == nil {
|
||||||
return
|
return
|
||||||
@@ -29,45 +30,34 @@ func (s *Server) LogError(err HWSError) {
|
|||||||
s.logger.logger.Warn().Err(err.Error).Msg(err.Message)
|
s.logger.logger.Warn().Err(err.Error).Msg(err.Message)
|
||||||
return
|
return
|
||||||
case ErrorERROR:
|
case ErrorERROR:
|
||||||
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
|
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||||
return
|
return
|
||||||
case ErrorFATAL:
|
case ErrorFATAL:
|
||||||
s.logger.logger.Fatal().Err(err.Error).Msg(err.Message)
|
s.logger.logger.Fatal().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||||
return
|
return
|
||||||
case ErrorPANIC:
|
case ErrorPANIC:
|
||||||
s.logger.logger.Panic().Err(err.Error).Msg(err.Message)
|
s.logger.logger.Panic().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
s.logger.logger.Error().Err(err.Error).Msg(err.Message)
|
s.logger.logger.Error().Str("stacktrace", fmt.Sprintf("%+v", err.Error)).Err(err.Error).Msg(err.Message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) LogFatal(err error) {
|
// AddLogger adds a logger to the server to use for request logging.
|
||||||
if err == nil {
|
func (s *Server) AddLogger(hlogger *hlog.Logger) error {
|
||||||
err = errors.New("LogFatal was called with a nil error")
|
|
||||||
}
|
|
||||||
if server.logger == nil {
|
|
||||||
fmt.Printf("FATAL - %s: %s", "A fatal error has occured", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.logger.logger.Fatal().Err(err).Msg("A fatal error has occured")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server.AddLogger adds a logger to the server to use for request logging.
|
|
||||||
func (server *Server) AddLogger(hlogger *hlog.Logger) error {
|
|
||||||
if hlogger == nil {
|
if hlogger == nil {
|
||||||
return errors.New("Unable to add logger, no logger provided")
|
return errors.New("unable to add logger, no logger provided")
|
||||||
}
|
}
|
||||||
server.logger = &logger{
|
s.logger = &logger{
|
||||||
logger: hlogger,
|
logger: hlogger,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
// LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||||
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||||
// Useful for ignoring requests to CSS files or favicons
|
// Useful for ignoring requests to CSS files or favicons
|
||||||
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
func (s *Server) LoggerIgnorePaths(paths ...string) error {
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
u, err := url.Parse(path)
|
u, err := url.Parse(path)
|
||||||
valid := err == nil &&
|
valid := err == nil &&
|
||||||
@@ -76,9 +66,22 @@ func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
|||||||
u.RawQuery == "" &&
|
u.RawQuery == "" &&
|
||||||
u.Fragment == ""
|
u.Fragment == ""
|
||||||
if !valid {
|
if !valid {
|
||||||
return fmt.Errorf("Invalid path: '%s'", path)
|
return fmt.Errorf("invalid path: '%s'", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
server.logger.ignoredPaths = paths
|
s.logger.ignoredPaths = prepareGlobs(paths)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareGlobs(paths []string) []glob.Glob {
|
||||||
|
compiledGlobs := make([]glob.Glob, 0, len(paths))
|
||||||
|
for _, pattern := range paths {
|
||||||
|
g, err := glob.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
// If pattern fails to compile, skip it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
compiledGlobs = append(compiledGlobs, g)
|
||||||
|
}
|
||||||
|
return compiledGlobs
|
||||||
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("http://example.com/path")
|
err := server.LoggerIgnorePaths("http://example.com/path")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Invalid path with host", func(t *testing.T) {
|
t.Run("Invalid path with host", func(t *testing.T) {
|
||||||
@@ -207,7 +207,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
err := server.LoggerIgnorePaths("//example.com/path")
|
err := server.LoggerIgnorePaths("//example.com/path")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -217,7 +217,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("/path?query=value")
|
err := server.LoggerIgnorePaths("/path?query=value")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Invalid path with fragment", func(t *testing.T) {
|
t.Run("Invalid path with fragment", func(t *testing.T) {
|
||||||
@@ -226,7 +226,7 @@ func Test_LoggerIgnorePaths(t *testing.T) {
|
|||||||
|
|
||||||
err := server.LoggerIgnorePaths("/path#fragment")
|
err := server.LoggerIgnorePaths("/path#fragment")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid path")
|
assert.Contains(t, err.Error(), "invalid path")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Valid paths", func(t *testing.T) {
|
t.Run("Valid paths", func(t *testing.T) {
|
||||||
|
|||||||
@@ -5,35 +5,37 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Middleware func(h http.Handler) http.Handler
|
type (
|
||||||
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
Middleware func(h http.Handler) http.Handler
|
||||||
|
MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||||
|
)
|
||||||
|
|
||||||
// Server.AddMiddleware registers all the middleware.
|
// AddMiddleware registers all the middleware.
|
||||||
// Middleware will be run in the order that they are provided.
|
// Middleware will be run in the order that they are provided.
|
||||||
// Can only be called once
|
// Can only be called once
|
||||||
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
func (s *Server) AddMiddleware(middleware ...Middleware) error {
|
||||||
if !server.routes {
|
if !s.routes {
|
||||||
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||||
}
|
}
|
||||||
if server.middleware {
|
if s.middleware {
|
||||||
return errors.New("Server.AddMiddleware already called")
|
return errors.New("Server.AddMiddleware already called")
|
||||||
}
|
}
|
||||||
// RUN LOGGING MIDDLEWARE FIRST
|
// RUN LOGGING MIDDLEWARE FIRST
|
||||||
server.server.Handler = logging(server.server.Handler, server.logger)
|
s.server.Handler = logging(s.server.Handler, s.logger)
|
||||||
|
|
||||||
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||||
for i := len(middleware); i > 0; i-- {
|
for i := len(middleware); i > 0; i-- {
|
||||||
server.server.Handler = middleware[i-1](server.server.Handler)
|
s.server.Handler = middleware[i-1](s.server.Handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RUN GZIP
|
// RUN GZIP
|
||||||
if server.GZIP {
|
if s.GZIP {
|
||||||
server.server.Handler = addgzip(server.server.Handler)
|
s.server.Handler = addgzip(s.server.Handler)
|
||||||
}
|
}
|
||||||
// RUN TIMER MIDDLEWARE LAST
|
// RUN TIMER MIDDLEWARE LAST
|
||||||
server.server.Handler = startTimer(server.server.Handler)
|
s.server.Handler = startTimer(s.server.Handler)
|
||||||
|
|
||||||
server.middleware = true
|
s.middleware = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -43,17 +45,19 @@ func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
|||||||
// and returns a new request and optional HWSError.
|
// and returns a new request and optional HWSError.
|
||||||
// If a HWSError is returned, server.ThrowError will be called.
|
// If a HWSError is returned, server.ThrowError will be called.
|
||||||
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
// If HWSError.RenderErrorPage is true, the request chain will be terminated and the error page rendered
|
||||||
func (server *Server) NewMiddleware(
|
func (s *Server) NewMiddleware(
|
||||||
middlewareFunc MiddlewareFunc,
|
middlewareFunc MiddlewareFunc,
|
||||||
) Middleware {
|
) Middleware {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
newReq, herr := middlewareFunc(w, r)
|
newReq, herr := middlewareFunc(w, r)
|
||||||
if herr != nil {
|
if herr != nil {
|
||||||
server.ThrowError(w, r, *herr)
|
s.ThrowError(w, r, *herr)
|
||||||
if herr.RenderErrorPage {
|
if herr.RenderErrorPage {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, newReq)
|
next.ServeHTTP(w, newReq)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package hws
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gobwas/glob"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middleware to add logs to console with details of the request
|
// Middleware to add logs to console with details of the request
|
||||||
@@ -13,7 +14,7 @@ func logging(next http.Handler, logger *logger) http.Handler {
|
|||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if slices.Contains(logger.ignoredPaths, r.URL.Path) {
|
if globTest(r.URL.Path, logger.ignoredPaths) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -36,3 +37,12 @@ func logging(next http.Handler, logger *logger) http.Handler {
|
|||||||
Msg("Served")
|
Msg("Served")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func globTest(testPath string, globs []glob.Glob) bool {
|
||||||
|
for _, g := range globs {
|
||||||
|
if g.Match(testPath) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,16 +18,24 @@ func startTimer(next http.Handler) http.Handler {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
func (c contextKey) String() string {
|
||||||
|
return "hws context key " + string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestTimerCtxKey = contextKey("request-timer")
|
||||||
|
|
||||||
// Set the start time of the request
|
// Set the start time of the request
|
||||||
func setStart(ctx context.Context, time time.Time) context.Context {
|
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||||
return context.WithValue(ctx, "hws context key request-timer", time)
|
return context.WithValue(ctx, requestTimerCtxKey, time)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the start time of the request
|
// Get the start time of the request
|
||||||
func getStartTime(ctx context.Context) (time.Time, error) {
|
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||||
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
start, ok := ctx.Value(requestTimerCtxKey).(time.Time)
|
||||||
if !ok {
|
if !ok {
|
||||||
return time.Time{}, errors.New("Failed to get start time of request")
|
return time.Time{}, errors.New("failed to get start time of request")
|
||||||
}
|
}
|
||||||
return start, nil
|
return start, nil
|
||||||
}
|
}
|
||||||
|
|||||||
316
hws/notify.go
Normal file
316
hws/notify.go
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/notify"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LevelShutdown is a special level used for the notification sent on shutdown.
|
||||||
|
// This can be used to check if the notification is a shutdown event and if it should
|
||||||
|
// be passed on to consumers or special considerations should be made.
|
||||||
|
const LevelShutdown notify.Level = "shutdown"
|
||||||
|
|
||||||
|
// Notifier manages client subscriptions and notification delivery for the HWS server.
|
||||||
|
// It wraps the notify.Notifier with additional client management features including
|
||||||
|
// dual identification (subscription ID + alternate ID) and automatic cleanup of
|
||||||
|
// inactive clients after 5 minutes.
|
||||||
|
type Notifier struct {
|
||||||
|
*notify.Notifier
|
||||||
|
clients *Clients
|
||||||
|
running bool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clients maintains thread-safe mappings between subscriber IDs, alternate IDs,
|
||||||
|
// and Client instances. It supports querying clients by either their unique
|
||||||
|
// subscription ID or their alternate ID (where multiple clients can share an alternate ID).
|
||||||
|
type Clients struct {
|
||||||
|
clientsSubMap map[notify.Target]*Client
|
||||||
|
clientsIDMap map[string][]*Client
|
||||||
|
lock *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client represents a unique subscriber to the notifications channel.
|
||||||
|
// It tracks activity via lastSeen timestamp (updated atomically) and monitors
|
||||||
|
// consecutive send failures for automatic disconnect detection.
|
||||||
|
type Client struct {
|
||||||
|
sub *notify.Subscriber
|
||||||
|
lastSeen int64 // accessed atomically
|
||||||
|
altID string
|
||||||
|
consecutiveFails int32 // accessed atomically
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) startNotifier() {
|
||||||
|
if s.notifier != nil && s.notifier.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s.notifier = &Notifier{
|
||||||
|
Notifier: notify.NewNotifier(50),
|
||||||
|
clients: &Clients{
|
||||||
|
clientsSubMap: make(map[notify.Target]*Client),
|
||||||
|
clientsIDMap: make(map[string][]*Client),
|
||||||
|
lock: new(sync.RWMutex),
|
||||||
|
},
|
||||||
|
running: true,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.notifier.clients.cleanUp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeNotifier() {
|
||||||
|
if s.notifier != nil {
|
||||||
|
if s.notifier.cancel != nil {
|
||||||
|
s.notifier.cancel()
|
||||||
|
}
|
||||||
|
s.notifier.running = false
|
||||||
|
s.notifier.Close()
|
||||||
|
}
|
||||||
|
s.notifier = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifySub sends a notification to a specific subscriber identified by the notification's Target field.
|
||||||
|
// If the subscriber doesn't exist, a warning is logged but the operation does not fail.
|
||||||
|
// This is thread-safe and can be called from multiple goroutines.
|
||||||
|
func (s *Server) NotifySub(nt notify.Notification) {
|
||||||
|
if s.notifier == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, exists := s.notifier.clients.getClient(nt.Target)
|
||||||
|
if !exists {
|
||||||
|
err := fmt.Errorf("tried to notify subscriber that doesn't exist - subID: %s", nt.Target)
|
||||||
|
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.notifier.Notify(nt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyID sends a notification to all clients associated with the given alternate ID.
|
||||||
|
// Multiple clients can share the same alternate ID (e.g., multiple sessions for one user).
|
||||||
|
// If no clients exist with that ID, a warning is logged but the operation does not fail.
|
||||||
|
// This is thread-safe and can be called from multiple goroutines.
|
||||||
|
func (s *Server) NotifyID(nt notify.Notification, altID string) {
|
||||||
|
if s.notifier == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.notifier.clients.lock.RLock()
|
||||||
|
clients, exists := s.notifier.clients.clientsIDMap[altID]
|
||||||
|
s.notifier.clients.lock.RUnlock()
|
||||||
|
if !exists {
|
||||||
|
err := fmt.Errorf("tried to notify client group that doesn't exist - altID: %s", altID)
|
||||||
|
s.LogError(HWSError{Level: ErrorWARN, Message: "Failed to notify", Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, client := range clients {
|
||||||
|
ntt := nt
|
||||||
|
ntt.Target = client.sub.ID
|
||||||
|
s.NotifySub(ntt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyAll broadcasts a notification to all connected clients.
|
||||||
|
// This is thread-safe and can be called from multiple goroutines.
|
||||||
|
func (s *Server) NotifyAll(nt notify.Notification) {
|
||||||
|
if s.notifier == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nt.Target = ""
|
||||||
|
s.notifier.NotifyAll(nt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient returns a Client that can be used to receive notifications.
|
||||||
|
// If a client exists with the provided subID, that client will be returned.
|
||||||
|
// If altID is provided, it will update the existing Client.
|
||||||
|
// If subID is an empty string, a new client will be returned.
|
||||||
|
// If both altID and subID are empty, a new Client with no altID will be returned.
|
||||||
|
// Multiple clients with the same altID are permitted.
|
||||||
|
func (s *Server) GetClient(subID, altID string) (*Client, error) {
|
||||||
|
if s.notifier == nil || !s.notifier.running {
|
||||||
|
return nil, errors.New("notifier hasn't started")
|
||||||
|
}
|
||||||
|
target := notify.Target(subID)
|
||||||
|
client, exists := s.notifier.clients.getClient(target)
|
||||||
|
if exists {
|
||||||
|
s.notifier.clients.updateAltID(client, altID)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
// An error should only be returned if there are 10 collisions of a randomly generated 16 bit byte string from rand.Rand()
|
||||||
|
// Basically never going to happen, and if it does its not my problem
|
||||||
|
sub, _ := s.notifier.Subscribe()
|
||||||
|
client = &Client{
|
||||||
|
sub: sub,
|
||||||
|
lastSeen: time.Now().Unix(),
|
||||||
|
altID: altID,
|
||||||
|
consecutiveFails: 0,
|
||||||
|
}
|
||||||
|
s.notifier.clients.addClient(client)
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) getClient(target notify.Target) (*Client, bool) {
|
||||||
|
cs.lock.RLock()
|
||||||
|
client, exists := cs.clientsSubMap[target]
|
||||||
|
cs.lock.RUnlock()
|
||||||
|
return client, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) updateAltID(client *Client, altID string) {
|
||||||
|
cs.lock.Lock()
|
||||||
|
if altID != "" && !slices.Contains(cs.clientsIDMap[altID], client) {
|
||||||
|
cs.clientsIDMap[altID] = append(cs.clientsIDMap[altID], client)
|
||||||
|
}
|
||||||
|
if client.altID != altID && client.altID != "" {
|
||||||
|
cs.deleteFromID(client, client.altID)
|
||||||
|
}
|
||||||
|
client.altID = altID
|
||||||
|
cs.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) deleteFromID(client *Client, altID string) {
|
||||||
|
cs.clientsIDMap[altID] = deleteFromSlice(cs.clientsIDMap[altID], client, func(a, b *Client) bool {
|
||||||
|
return a.sub.ID == b.sub.ID
|
||||||
|
})
|
||||||
|
if len(cs.clientsIDMap[altID]) == 0 {
|
||||||
|
delete(cs.clientsIDMap, altID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) addClient(client *Client) {
|
||||||
|
cs.lock.Lock()
|
||||||
|
cs.clientsSubMap[client.sub.ID] = client
|
||||||
|
if client.altID != "" {
|
||||||
|
cs.clientsIDMap[client.altID] = append(cs.clientsIDMap[client.altID], client)
|
||||||
|
}
|
||||||
|
cs.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) cleanUp() {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
// Collect clients to kill while holding read lock
|
||||||
|
cs.lock.RLock()
|
||||||
|
toKill := make([]*Client, 0)
|
||||||
|
for _, client := range cs.clientsSubMap {
|
||||||
|
if now-atomic.LoadInt64(&client.lastSeen) > 300 {
|
||||||
|
toKill = append(toKill, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cs.lock.RUnlock()
|
||||||
|
|
||||||
|
// Kill clients without holding lock
|
||||||
|
for _, client := range toKill {
|
||||||
|
cs.killClient(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *Clients) killClient(client *Client) {
|
||||||
|
client.sub.Unsubscribe()
|
||||||
|
|
||||||
|
cs.lock.Lock()
|
||||||
|
delete(cs.clientsSubMap, client.sub.ID)
|
||||||
|
if client.altID != "" {
|
||||||
|
cs.deleteFromID(client, client.altID)
|
||||||
|
}
|
||||||
|
cs.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen starts a goroutine that forwards notifications from the subscriber to a returned channel.
|
||||||
|
// It returns a receive-only channel for notifications and a channel to stop listening.
|
||||||
|
// The notification channel is buffered with size 10 to tolerate brief slowness.
|
||||||
|
//
|
||||||
|
// The goroutine automatically stops and closes the notification channel when:
|
||||||
|
// - The subscriber is unsubscribed
|
||||||
|
// - The stop channel is closed
|
||||||
|
// - The client fails to receive 5 consecutive notifications within 5 seconds each
|
||||||
|
//
|
||||||
|
// Client.lastSeen is updated every 30 seconds via heartbeat, or when a notification is successfully delivered.
|
||||||
|
// Consecutive send failures are tracked; after 5 failures, the client is considered disconnected and cleaned up.
|
||||||
|
func (c *Client) Listen() (<-chan notify.Notification, chan<- struct{}) {
|
||||||
|
ch := make(chan notify.Notification, 10)
|
||||||
|
stop := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
|
||||||
|
case nt, ok := <-c.sub.Listen():
|
||||||
|
if !ok {
|
||||||
|
// Subscriber channel closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to send with timeout
|
||||||
|
timeout := time.NewTimer(5 * time.Second)
|
||||||
|
select {
|
||||||
|
case ch <- nt:
|
||||||
|
// Successfully sent - update lastSeen and reset failure count
|
||||||
|
atomic.StoreInt64(&c.lastSeen, time.Now().Unix())
|
||||||
|
atomic.StoreInt32(&c.consecutiveFails, 0)
|
||||||
|
timeout.Stop()
|
||||||
|
|
||||||
|
case <-timeout.C:
|
||||||
|
// Send timeout - increment failure count
|
||||||
|
fails := atomic.AddInt32(&c.consecutiveFails, 1)
|
||||||
|
if fails >= 5 {
|
||||||
|
// Too many consecutive failures - client is stuck/disconnected
|
||||||
|
c.sub.Unsubscribe()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-stop:
|
||||||
|
timeout.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
// Heartbeat - update lastSeen to keep client alive
|
||||||
|
atomic.StoreInt64(&c.lastSeen, time.Now().Unix())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch, stop
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) ID() string {
|
||||||
|
return string(c.sub.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteFromSlice[T any](a []T, c T, eq func(T, T) bool) []T {
|
||||||
|
n := 0
|
||||||
|
for _, x := range a {
|
||||||
|
if !eq(x, c) {
|
||||||
|
a[n] = x
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a[:n]
|
||||||
|
}
|
||||||
1010
hws/notify_test.go
Normal file
1010
hws/notify_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -13,3 +13,7 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
|
|||||||
w.ResponseWriter.WriteHeader(statusCode)
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
w.statusCode = statusCode
|
w.statusCode = statusCode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *wrappedWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,13 +30,13 @@ const (
|
|||||||
MethodPATCH Method = "PATCH"
|
MethodPATCH Method = "PATCH"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server.AddRoutes registers the page handlers for the server.
|
// AddRoutes registers the page handlers for the server.
|
||||||
// At least one route must be provided.
|
// At least one route must be provided.
|
||||||
// If any route patterns (path + method) are defined multiple times, the first
|
// If any route patterns (path + method) are defined multiple times, the first
|
||||||
// instance will be added and any additional conflicts will be discarded.
|
// instance will be added and any additional conflicts will be discarded.
|
||||||
func (server *Server) AddRoutes(routes ...Route) error {
|
func (s *Server) AddRoutes(routes ...Route) error {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
return errors.New("No routes provided")
|
return errors.New("no routes provided")
|
||||||
}
|
}
|
||||||
patterns := []string{}
|
patterns := []string{}
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -47,10 +47,10 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
|||||||
}
|
}
|
||||||
for _, method := range route.Methods {
|
for _, method := range route.Methods {
|
||||||
if !validMethod(method) {
|
if !validMethod(method) {
|
||||||
return fmt.Errorf("Invalid method %s for path %s", method, route.Path)
|
return fmt.Errorf("invalid method %s for path %s", method, route.Path)
|
||||||
}
|
}
|
||||||
if route.Handler == nil {
|
if route.Handler == nil {
|
||||||
return fmt.Errorf("No handler provided for %s %s", method, route.Path)
|
return fmt.Errorf("no handler provided for %s %s", method, route.Path)
|
||||||
}
|
}
|
||||||
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
pattern := fmt.Sprintf("%s %s", method, route.Path)
|
||||||
if slices.Contains(patterns, pattern) {
|
if slices.Contains(patterns, pattern) {
|
||||||
@@ -61,8 +61,8 @@ func (server *Server) AddRoutes(routes ...Route) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server.server.Handler = mux
|
s.server.Handler = mux
|
||||||
server.routes = true
|
s.routes = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
server := createTestServer(t, &buf)
|
server := createTestServer(t, &buf)
|
||||||
err := server.AddRoutes()
|
err := server.AddRoutes()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "No routes provided")
|
assert.Contains(t, err.Error(), "no routes provided")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Single valid route", func(t *testing.T) {
|
t.Run("Single valid route", func(t *testing.T) {
|
||||||
@@ -58,7 +58,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid method")
|
assert.Contains(t, err.Error(), "invalid method")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("No handler provided", func(t *testing.T) {
|
t.Run("No handler provided", func(t *testing.T) {
|
||||||
@@ -69,7 +69,7 @@ func Test_AddRoutes(t *testing.T) {
|
|||||||
Handler: nil,
|
Handler: nil,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "No handler provided")
|
assert.Contains(t, err.Error(), "no handler provided")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
t.Run("All HTTP methods are valid", func(t *testing.T) {
|
||||||
@@ -203,7 +203,7 @@ func Test_AddRoutes_MultipleMethods(t *testing.T) {
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Invalid method")
|
assert.Contains(t, err.Error(), "invalid method")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
t.Run("Empty Methods slice falls back to Method field", func(t *testing.T) {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/notify"
|
||||||
"k8s.io/apimachinery/pkg/util/validation"
|
"k8s.io/apimachinery/pkg/util/validation"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -20,17 +21,19 @@ type Server struct {
|
|||||||
middleware bool
|
middleware bool
|
||||||
errorPage ErrorPageFunc
|
errorPage ErrorPageFunc
|
||||||
ready chan struct{}
|
ready chan struct{}
|
||||||
|
notifier *Notifier
|
||||||
|
shutdowndelay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ready returns a channel that is closed when the server is started
|
// Ready returns a channel that is closed when the server is started
|
||||||
func (server *Server) Ready() <-chan struct{} {
|
func (s *Server) Ready() <-chan struct{} {
|
||||||
return server.ready
|
return s.ready
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsReady checks if the server is running
|
// IsReady checks if the server is running
|
||||||
func (server *Server) IsReady() bool {
|
func (s *Server) IsReady() bool {
|
||||||
select {
|
select {
|
||||||
case <-server.ready:
|
case <-s.ready:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
@@ -38,13 +41,13 @@ func (server *Server) IsReady() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Addr returns the server's network address
|
// Addr returns the server's network address
|
||||||
func (server *Server) Addr() string {
|
func (s *Server) Addr() string {
|
||||||
return server.server.Addr
|
return s.server.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns the server's HTTP handler for testing purposes
|
// Handler returns the server's HTTP handler for testing purposes
|
||||||
func (server *Server) Handler() http.Handler {
|
func (s *Server) Handler() http.Handler {
|
||||||
return server.server.Handler
|
return s.server.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a new hws.Server with the specified configuration.
|
// NewServer returns a new hws.Server with the specified configuration.
|
||||||
@@ -72,7 +75,7 @@ func NewServer(config *Config) (*Server, error) {
|
|||||||
|
|
||||||
valid := isValidHostname(config.Host)
|
valid := isValidHostname(config.Host)
|
||||||
if !valid {
|
if !valid {
|
||||||
return nil, fmt.Errorf("Hostname '%s' is not valid", config.Host)
|
return nil, fmt.Errorf("hostname '%s' is not valid", config.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
@@ -87,56 +90,69 @@ func NewServer(config *Config) (*Server, error) {
|
|||||||
routes: false,
|
routes: false,
|
||||||
GZIP: config.GZIP,
|
GZIP: config.GZIP,
|
||||||
ready: make(chan struct{}),
|
ready: make(chan struct{}),
|
||||||
|
shutdowndelay: config.ShutdownDelay,
|
||||||
}
|
}
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) Start(ctx context.Context) error {
|
func (s *Server) Start(ctx context.Context) error {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return errors.New("Context cannot be nil")
|
return errors.New("Context cannot be nil")
|
||||||
}
|
}
|
||||||
if !server.routes {
|
if !s.routes {
|
||||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||||
}
|
}
|
||||||
if !server.middleware {
|
if !s.middleware {
|
||||||
err := server.AddMiddleware()
|
err := s.AddMiddleware()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "server.AddMiddleware")
|
return errors.Wrap(err, "server.AddMiddleware")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.startNotifier()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if server.logger == nil {
|
if s.logger == nil {
|
||||||
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
fmt.Printf("Listening for requests on %s", s.server.Addr)
|
||||||
} else {
|
} else {
|
||||||
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
s.logger.logger.Info().Str("address", s.server.Addr).Msg("Listening for requests")
|
||||||
}
|
}
|
||||||
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
if server.logger == nil {
|
if s.logger == nil {
|
||||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||||
} else {
|
} else {
|
||||||
server.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
s.LogError(HWSError{Error: err, Message: "Server encountered a fatal error"})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server.waitUntilReady(ctx)
|
s.waitUntilReady(ctx)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) Shutdown(ctx context.Context) error {
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
if !server.IsReady() {
|
if s.logger != nil {
|
||||||
|
s.logger.logger.Debug().Dur("shutdown_delay", s.shutdowndelay).Msg("HWS Server shutting down")
|
||||||
|
}
|
||||||
|
s.NotifyAll(notify.Notification{
|
||||||
|
Title: "Shutting down",
|
||||||
|
Message: fmt.Sprintf("Server is shutting down in %v", s.shutdowndelay),
|
||||||
|
Level: LevelShutdown,
|
||||||
|
})
|
||||||
|
<-time.NewTimer(s.shutdowndelay).C
|
||||||
|
if !s.IsReady() {
|
||||||
return errors.New("Server isn't running")
|
return errors.New("Server isn't running")
|
||||||
}
|
}
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return errors.New("Context cannot be nil")
|
return errors.New("Context cannot be nil")
|
||||||
}
|
}
|
||||||
err := server.server.Shutdown(ctx)
|
err := s.server.Shutdown(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
return errors.Wrap(err, "Failed to shutdown the server gracefully")
|
||||||
}
|
}
|
||||||
server.ready = make(chan struct{})
|
s.closeNotifier()
|
||||||
|
s.ready = make(chan struct{})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +170,7 @@ func isValidHostname(host string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) waitUntilReady(ctx context.Context) error {
|
func (s *Server) waitUntilReady(ctx context.Context) error {
|
||||||
ticker := time.NewTicker(50 * time.Millisecond)
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -166,14 +182,14 @@ func (server *Server) waitUntilReady(ctx context.Context) error {
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
resp, err := http.Get("http://" + server.server.Addr + "/healthz")
|
resp, err := http.Get("http://" + s.server.Addr + "/healthz")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue // not accepting yet
|
continue // not accepting yet
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
if resp.StatusCode == http.StatusOK {
|
||||||
closeOnce.Do(func() { close(server.ready) })
|
closeOnce.Do(func() { close(s.ready) })
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func createTestServer(t *testing.T, w io.Writer) *hws.Server {
|
|||||||
server, err := hws.NewServer(&hws.Config{
|
server, err := hws.NewServer(&hws.Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: randomPort(),
|
Port: randomPort(),
|
||||||
|
ShutdownDelay: 0, // No delay for tests
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/hlog"
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/golib/jwt"
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/gobwas/glob"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ type Authenticator[T Model, TX DBTransaction] struct {
|
|||||||
tokenGenerator *jwt.TokenGenerator
|
tokenGenerator *jwt.TokenGenerator
|
||||||
load LoadFunc[T, TX]
|
load LoadFunc[T, TX]
|
||||||
beginTx BeginTX
|
beginTx BeginTX
|
||||||
ignoredPaths []string
|
ignoredPaths []glob.Glob
|
||||||
logger *hlog.Logger
|
logger *hlog.Logger
|
||||||
server *hws.Server
|
server *hws.Server
|
||||||
errorPage hws.ErrorPageFunc
|
errorPage hws.ErrorPageFunc
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ func (e EZConfIntegration) PackagePath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
// ConfigFunc returns the ConfigFromEnv function for ezconf
|
||||||
func (e EZConfIntegration) ConfigFunc() func() (interface{}, error) {
|
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
|
||||||
return func() (interface{}, error) {
|
return func() (any, error) {
|
||||||
return ConfigFromEnv()
|
return ConfigFromEnv()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,16 +6,19 @@ require (
|
|||||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||||
git.haelnorr.com/h/golib/env v0.9.1
|
git.haelnorr.com/h/golib/env v0.9.1
|
||||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||||
git.haelnorr.com/h/golib/hws v0.3.0
|
git.haelnorr.com/h/golib/hws v0.5.0
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1
|
git.haelnorr.com/h/golib/jwt v0.10.1
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require git.haelnorr.com/h/golib/notify v0.1.0 // indirect
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/gobwas/glob v0.2.3
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjo
|
|||||||
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||||
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
|
||||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||||
git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo=
|
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||||
git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
|
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||||
|
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
@@ -15,6 +17,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||||
|
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ func (tm TestModel) GetID() int {
|
|||||||
return tm.ID
|
return tm.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestTransaction struct {
|
type TestTransaction struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -137,8 +136,10 @@ func TestCurrentModel(t *testing.T) {
|
|||||||
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
||||||
// Clear environment variables
|
// Clear environment variables
|
||||||
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
||||||
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
_ = os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||||
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
defer func() {
|
||||||
|
_ = os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||||
|
}()
|
||||||
|
|
||||||
_, err := ConfigFromEnv()
|
_, err := ConfigFromEnv()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -327,7 +328,9 @@ func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
|||||||
|
|
||||||
db, _, err := createMockDB()
|
db, _, err := createMockDB()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer db.Close()
|
defer func() {
|
||||||
|
_ = db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
auth, err := NewAuthenticator(
|
auth, err := NewAuthenticator(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -409,7 +412,9 @@ func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
|||||||
|
|
||||||
db, _, err := createMockDB()
|
db, _, err := createMockDB()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer db.Close()
|
defer func() {
|
||||||
|
_ = db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
auth, err := NewAuthenticator(
|
auth, err := NewAuthenticator(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -454,7 +459,9 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
|||||||
|
|
||||||
db, _, err := createMockDB()
|
db, _, err := createMockDB()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer db.Close()
|
defer func() {
|
||||||
|
_ = db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
auth, err := NewAuthenticator(
|
auth, err := NewAuthenticator(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -476,6 +483,7 @@ func TestLogin_BasicFunctionality(t *testing.T) {
|
|||||||
// This test mainly checks that the function doesn't panic and has right call signature
|
// This test mainly checks that the function doesn't panic and has right call signature
|
||||||
// The actual JWT functionality is tested in jwt package itself
|
// The actual JWT functionality is tested in jwt package itself
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
auth.Login(w, r, user, rememberMe)
|
err := auth.Login(w, r, user, rememberMe)
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package hwsauth
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/gobwas/glob"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IgnorePaths excludes specified paths from authentication middleware.
|
// IgnorePaths excludes specified paths from authentication middleware.
|
||||||
@@ -22,9 +24,22 @@ func (auth *Authenticator[T, TX]) IgnorePaths(paths ...string) error {
|
|||||||
u.RawQuery == "" &&
|
u.RawQuery == "" &&
|
||||||
u.Fragment == ""
|
u.Fragment == ""
|
||||||
if !valid {
|
if !valid {
|
||||||
return fmt.Errorf("Invalid path: '%s'", path)
|
return fmt.Errorf("invalid path: '%s'", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auth.ignoredPaths = paths
|
auth.ignoredPaths = prepareGlobs(paths)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareGlobs(paths []string) []glob.Glob {
|
||||||
|
compiledGlobs := make([]glob.Glob, 0, len(paths))
|
||||||
|
for _, pattern := range paths {
|
||||||
|
g, err := glob.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
// If pattern fails to compile, skip it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
compiledGlobs = append(compiledGlobs, g)
|
||||||
|
}
|
||||||
|
return compiledGlobs
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,14 +33,18 @@ func (auth *Authenticator[T, TX]) Logout(tx TX, w http.ResponseWriter, r *http.R
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "auth.getTokens")
|
return errors.Wrap(err, "auth.getTokens")
|
||||||
}
|
}
|
||||||
|
if aT != nil {
|
||||||
err = aT.Revoke(jwt.DBTransaction(tx))
|
err = aT.Revoke(jwt.DBTransaction(tx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "aT.Revoke")
|
return errors.Wrap(err, "aT.Revoke")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if rT != nil {
|
||||||
err = rT.Revoke(jwt.DBTransaction(tx))
|
err = rT.Revoke(jwt.DBTransaction(tx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "rT.Revoke")
|
return errors.Wrap(err, "rT.Revoke")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package hwsauth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/gobwas/glob"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authenticate returns the main authentication middleware.
|
// Authenticate returns the main authentication middleware.
|
||||||
@@ -14,14 +16,22 @@ import (
|
|||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
// server.AddMiddleware(auth.Authenticate())
|
// server.AddMiddleware(auth.Authenticate(nil))
|
||||||
func (auth *Authenticator[T, TX]) Authenticate() hws.Middleware {
|
//
|
||||||
return auth.server.NewMiddleware(auth.authenticate())
|
// If extraCheck is provided, it will run just before the user is added to the context,
|
||||||
|
// and the return will determine if the user will be added, or the request passed on
|
||||||
|
// without the user.
|
||||||
|
func (auth *Authenticator[T, TX]) Authenticate(
|
||||||
|
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||||
|
) hws.Middleware {
|
||||||
|
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
func (auth *Authenticator[T, TX]) authenticate(
|
||||||
|
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
||||||
|
) hws.MiddlewareFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
||||||
if slices.Contains(auth.ignoredPaths, r.URL.Path) {
|
if globTest(r.URL.Path, auth.ignoredPaths) {
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
@@ -30,25 +40,70 @@ func (auth *Authenticator[T, TX]) authenticate() hws.MiddlewareFunc {
|
|||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := auth.beginTx(ctx)
|
tx, err := auth.beginTx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &hws.HWSError{Message: "Unable to start transaction", StatusCode: http.StatusServiceUnavailable, Error: err}
|
return nil, &hws.HWSError{
|
||||||
|
Message: "Unable to start transaction",
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Error: errors.Wrap(err, "auth.beginTx"),
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}()
|
||||||
// Type assert to TX - safe because user's beginTx should return their TX type
|
// Type assert to TX - safe because user's beginTx should return their TX type
|
||||||
txTyped, ok := tx.(TX)
|
txTyped, ok := tx.(TX)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &hws.HWSError{Message: "Transaction type mismatch", StatusCode: http.StatusInternalServerError, Error: err}
|
return nil, &hws.HWSError{
|
||||||
|
Message: "Transaction type mismatch",
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Error: errors.Wrap(err, "TX type not ok"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
model, err := auth.getAuthenticatedUser(txTyped, w, r)
|
model, err := auth.getAuthenticatedUser(txTyped, w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
rberr := tx.Rollback()
|
||||||
|
if rberr != nil {
|
||||||
|
return nil, &hws.HWSError{
|
||||||
|
Message: "Failed rolling back after error",
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Error: errors.Wrap(err, "tx.Rollback"),
|
||||||
|
}
|
||||||
|
}
|
||||||
auth.logger.Debug().
|
auth.logger.Debug().
|
||||||
Str("remote_addr", r.RemoteAddr).
|
Str("remote_addr", r.RemoteAddr).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to authenticate user")
|
Msg("Failed to authenticate user")
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
tx.Commit()
|
var check bool
|
||||||
|
if extraCheck != nil {
|
||||||
|
var err *hws.HWSError
|
||||||
|
check, err = extraCheck(ctx, model.model, txTyped, w, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, &hws.HWSError{
|
||||||
|
Message: "Failed to commit transaction",
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Error: errors.Wrap(err, "tx.Commit"),
|
||||||
|
}
|
||||||
|
}
|
||||||
authContext := setAuthenticatedModel(r.Context(), model)
|
authContext := setAuthenticatedModel(r.Context(), model)
|
||||||
newReq := r.WithContext(authContext)
|
newReq := r.WithContext(authContext)
|
||||||
|
if extraCheck == nil || check {
|
||||||
return newReq, nil
|
return newReq, nil
|
||||||
}
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func globTest(testPath string, globs []glob.Glob) bool {
|
||||||
|
for _, g := range globs {
|
||||||
|
if g.Match(testPath) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,9 +39,17 @@ type ContextLoader[T Model] func(ctx context.Context) T
|
|||||||
// }
|
// }
|
||||||
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
|
type LoadFunc[T Model, TX DBTransaction] func(ctx context.Context, tx TX, id int) (T, error)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
func (c contextKey) String() string {
|
||||||
|
return "hwsauth context key" + string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var authenticatedModelContextKey = contextKey("authenticated-model")
|
||||||
|
|
||||||
// Return a new context with the user added in
|
// Return a new context with the user added in
|
||||||
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
|
func setAuthenticatedModel[T Model](ctx context.Context, m authenticatedModel[T]) context.Context {
|
||||||
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
return context.WithValue(ctx, authenticatedModelContextKey, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve a user from the given context. Returns nil if not set
|
// Retrieve a user from the given context. Returns nil if not set
|
||||||
@@ -53,7 +61,7 @@ func getAuthorizedModel[T Model](ctx context.Context) (model authenticatedModel[
|
|||||||
model = authenticatedModel[T]{}
|
model = authenticatedModel[T]{}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
model, cok := ctx.Value("hwsauth context key authenticated-model").(authenticatedModel[T])
|
model, cok := ctx.Value(authenticatedModelContextKey).(authenticatedModel[T])
|
||||||
if !cok {
|
if !cok {
|
||||||
return authenticatedModel[T]{}, false
|
return authenticatedModel[T]{}, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,15 +19,12 @@ func (auth *Authenticator[T, TX]) LoginReq(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, ok := getAuthorizedModel[T](r.Context())
|
_, ok := getAuthorizedModel[T](r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
auth.server.ThrowError(w, r, hws.HWSError{
|
||||||
Error: errors.New("Login required"),
|
Error: errors.New("Login required"),
|
||||||
Message: "Please login to view this page",
|
Message: "Please login to view this page",
|
||||||
StatusCode: http.StatusUnauthorized,
|
StatusCode: http.StatusUnauthorized,
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
auth.server.ThrowFatal(w, err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
@@ -66,15 +63,12 @@ func (auth *Authenticator[T, TX]) FreshReq(next http.Handler) http.Handler {
|
|||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
model, ok := getAuthorizedModel[T](r.Context())
|
model, ok := getAuthorizedModel[T](r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
err := auth.server.ThrowError(w, r, hws.HWSError{
|
auth.server.ThrowError(w, r, hws.HWSError{
|
||||||
Error: errors.New("Login required"),
|
Error: errors.New("Login required"),
|
||||||
Message: "Please login to view this page",
|
Message: "Please login to view this page",
|
||||||
StatusCode: http.StatusUnauthorized,
|
StatusCode: http.StatusUnauthorized,
|
||||||
RenderErrorPage: true,
|
RenderErrorPage: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
auth.server.ThrowFatal(w, err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func (auth *Authenticator[T, TX]) RefreshAuthTokens(tx TX, w http.ResponseWriter
|
|||||||
rememberMe := map[string]bool{
|
rememberMe := map[string]bool{
|
||||||
"session": false,
|
"session": false,
|
||||||
"exp": true,
|
"exp": true,
|
||||||
}[aT.TTL]
|
}[rT.TTL]
|
||||||
// issue new tokens for the user
|
// issue new tokens for the user
|
||||||
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -55,14 +55,21 @@ func (auth *Authenticator[T, TX]) getTokens(
|
|||||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||||
// get the existing tokens from the cookies
|
// get the existing tokens from the cookies
|
||||||
atStr, rtStr := jwt.GetTokenCookies(r)
|
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||||
aT, err := auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
var aT *jwt.AccessToken
|
||||||
|
var rT *jwt.RefreshToken
|
||||||
|
var err error
|
||||||
|
if atStr != "" {
|
||||||
|
aT, err = auth.tokenGenerator.ValidateAccess(jwt.DBTransaction(tx), atStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||||
}
|
}
|
||||||
rT, err := auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
}
|
||||||
|
if rtStr != "" {
|
||||||
|
rT, err = auth.tokenGenerator.ValidateRefresh(jwt.DBTransaction(tx), rtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return aT, rT, nil
|
return aT, rT, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,13 +79,17 @@ func revokeTokenPair(
|
|||||||
aT *jwt.AccessToken,
|
aT *jwt.AccessToken,
|
||||||
rT *jwt.RefreshToken,
|
rT *jwt.RefreshToken,
|
||||||
) error {
|
) error {
|
||||||
|
if aT != nil {
|
||||||
err := aT.Revoke(tx)
|
err := aT.Revoke(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "aT.Revoke")
|
return errors.Wrap(err, "aT.Revoke")
|
||||||
}
|
}
|
||||||
err = rT.Revoke(tx)
|
}
|
||||||
|
if rT != nil {
|
||||||
|
err := rT.Revoke(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "rT.Revoke")
|
return errors.Wrap(err, "rT.Revoke")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user