Compare commits
3 Commits
hws/v0.1.0
...
env/v0.9.1
| Author | SHA1 | Date | |
|---|---|---|---|
| f3312f7aef | |||
| 61d519399f | |||
| b13b783d7e |
91
env/boolean_test.go
vendored
Normal file
91
env/boolean_test.go
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue bool
|
||||
expected bool
|
||||
shouldSet bool
|
||||
}{
|
||||
// Truthy values
|
||||
{"true lowercase", "TEST_BOOL", "true", false, true, true},
|
||||
{"true uppercase", "TEST_BOOL", "TRUE", false, true, true},
|
||||
{"true mixed case", "TEST_BOOL", "TrUe", false, true, true},
|
||||
{"t", "TEST_BOOL", "t", false, true, true},
|
||||
{"T", "TEST_BOOL", "T", false, true, true},
|
||||
{"yes", "TEST_BOOL", "yes", false, true, true},
|
||||
{"YES", "TEST_BOOL", "YES", false, true, true},
|
||||
{"y", "TEST_BOOL", "y", false, true, true},
|
||||
{"Y", "TEST_BOOL", "Y", false, true, true},
|
||||
{"on", "TEST_BOOL", "on", false, true, true},
|
||||
{"ON", "TEST_BOOL", "ON", false, true, true},
|
||||
{"1", "TEST_BOOL", "1", false, true, true},
|
||||
{"enable", "TEST_BOOL", "enable", false, true, true},
|
||||
{"ENABLE", "TEST_BOOL", "ENABLE", false, true, true},
|
||||
{"enabled", "TEST_BOOL", "enabled", false, true, true},
|
||||
{"ENABLED", "TEST_BOOL", "ENABLED", false, true, true},
|
||||
{"active", "TEST_BOOL", "active", false, true, true},
|
||||
{"ACTIVE", "TEST_BOOL", "ACTIVE", false, true, true},
|
||||
{"affirmative", "TEST_BOOL", "affirmative", false, true, true},
|
||||
{"AFFIRMATIVE", "TEST_BOOL", "AFFIRMATIVE", false, true, true},
|
||||
|
||||
// Falsy values
|
||||
{"false lowercase", "TEST_BOOL", "false", true, false, true},
|
||||
{"false uppercase", "TEST_BOOL", "FALSE", true, false, true},
|
||||
{"false mixed case", "TEST_BOOL", "FaLsE", true, false, true},
|
||||
{"f", "TEST_BOOL", "f", true, false, true},
|
||||
{"F", "TEST_BOOL", "F", true, false, true},
|
||||
{"no", "TEST_BOOL", "no", true, false, true},
|
||||
{"NO", "TEST_BOOL", "NO", true, false, true},
|
||||
{"n", "TEST_BOOL", "n", true, false, true},
|
||||
{"N", "TEST_BOOL", "N", true, false, true},
|
||||
{"off", "TEST_BOOL", "off", true, false, true},
|
||||
{"OFF", "TEST_BOOL", "OFF", true, false, true},
|
||||
{"0", "TEST_BOOL", "0", true, false, true},
|
||||
{"disable", "TEST_BOOL", "disable", true, false, true},
|
||||
{"DISABLE", "TEST_BOOL", "DISABLE", true, false, true},
|
||||
{"disabled", "TEST_BOOL", "disabled", true, false, true},
|
||||
{"DISABLED", "TEST_BOOL", "DISABLED", true, false, true},
|
||||
{"inactive", "TEST_BOOL", "inactive", true, false, true},
|
||||
{"INACTIVE", "TEST_BOOL", "INACTIVE", true, false, true},
|
||||
{"negative", "TEST_BOOL", "negative", true, false, true},
|
||||
{"NEGATIVE", "TEST_BOOL", "NEGATIVE", true, false, true},
|
||||
|
||||
// Whitespace handling
|
||||
{"true with spaces", "TEST_BOOL", " true ", false, true, true},
|
||||
{"false with spaces", "TEST_BOOL", " false ", true, false, true},
|
||||
|
||||
// Default values
|
||||
{"not set default true", "TEST_BOOL_NOTSET", "", true, true, false},
|
||||
{"not set default false", "TEST_BOOL_NOTSET", "", false, false, false},
|
||||
|
||||
// Invalid values should return default
|
||||
{"invalid value default true", "TEST_BOOL", "invalid", true, true, true},
|
||||
{"invalid value default false", "TEST_BOOL", "invalid", false, false, true},
|
||||
{"empty string default true", "TEST_BOOL", "", true, true, true},
|
||||
{"empty string default false", "TEST_BOOL", "", false, false, true},
|
||||
{"random text default true", "TEST_BOOL", "maybe", true, true, true},
|
||||
{"random text default false", "TEST_BOOL", "maybe", false, false, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Bool(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Bool() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
42
env/duration_test.go
vendored
Normal file
42
env/duration_test.go
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue time.Duration
|
||||
expected time.Duration
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive duration", "TEST_DURATION", "100", 0, 100 * time.Nanosecond, true},
|
||||
{"valid zero", "TEST_DURATION", "0", 10 * time.Second, 0, true},
|
||||
{"large value", "TEST_DURATION", "1000000000", 0, 1 * time.Second, true},
|
||||
{"valid negative duration", "TEST_DURATION", "-100", 0, -100 * time.Nanosecond, true},
|
||||
{"not set", "TEST_DURATION_NOTSET", "", 5 * time.Minute, 5 * time.Minute, false},
|
||||
{"invalid value", "TEST_DURATION", "not_a_number", 30 * time.Second, 30 * time.Second, true},
|
||||
{"empty string", "TEST_DURATION", "", 1 * time.Hour, 1 * time.Hour, true},
|
||||
{"float value", "TEST_DURATION", "10.5", 2 * time.Second, 2 * time.Second, true},
|
||||
{"very large value", "TEST_DURATION", "9223372036854775807", 0, 9223372036854775807 * time.Nanosecond, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Duration(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Duration() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
45
env/int.go
vendored
45
env/int.go
vendored
@@ -20,6 +20,51 @@ func Int(key string, defaultValue int) int {
|
||||
return intVal
|
||||
}
|
||||
|
||||
// Get an environment variable as an int8, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int8
|
||||
func Int8(key string, defaultValue int8) int8 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(val, 10, 8)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return int8(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as an int16, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int16
|
||||
func Int16(key string, defaultValue int16) int16 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(val, 10, 16)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return int16(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as an int32, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int32
|
||||
func Int32(key string, defaultValue int32) int32 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(val, 10, 32)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return int32(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as an int64, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int64
|
||||
func Int64(key string, defaultValue int64) int64 {
|
||||
|
||||
170
env/int_test.go
vendored
Normal file
170
env/int_test.go
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue int
|
||||
expected int
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive int", "TEST_INT", "42", 0, 42, true},
|
||||
{"valid negative int", "TEST_INT", "-42", 0, -42, true},
|
||||
{"valid zero", "TEST_INT", "0", 10, 0, true},
|
||||
{"not set", "TEST_INT_NOTSET", "", 100, 100, false},
|
||||
{"invalid value", "TEST_INT", "not_a_number", 50, 50, true},
|
||||
{"empty string", "TEST_INT", "", 75, 75, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Int(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Int() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt8(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue int8
|
||||
expected int8
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive int8", "TEST_INT8", "42", 0, 42, true},
|
||||
{"valid negative int8", "TEST_INT8", "-42", 0, -42, true},
|
||||
{"max int8", "TEST_INT8", "127", 0, 127, true},
|
||||
{"min int8", "TEST_INT8", "-128", 0, -128, true},
|
||||
{"overflow", "TEST_INT8", "128", 10, 10, true},
|
||||
{"not set", "TEST_INT8_NOTSET", "", 50, 50, false},
|
||||
{"invalid value", "TEST_INT8", "not_a_number", 25, 25, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Int8(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Int8() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue int16
|
||||
expected int16
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive int16", "TEST_INT16", "1000", 0, 1000, true},
|
||||
{"valid negative int16", "TEST_INT16", "-1000", 0, -1000, true},
|
||||
{"max int16", "TEST_INT16", "32767", 0, 32767, true},
|
||||
{"min int16", "TEST_INT16", "-32768", 0, -32768, true},
|
||||
{"overflow", "TEST_INT16", "32768", 100, 100, true},
|
||||
{"not set", "TEST_INT16_NOTSET", "", 500, 500, false},
|
||||
{"invalid value", "TEST_INT16", "invalid", 250, 250, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Int16(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Int16() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt32(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue int32
|
||||
expected int32
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive int32", "TEST_INT32", "100000", 0, 100000, true},
|
||||
{"valid negative int32", "TEST_INT32", "-100000", 0, -100000, true},
|
||||
{"max int32", "TEST_INT32", "2147483647", 0, 2147483647, true},
|
||||
{"min int32", "TEST_INT32", "-2147483648", 0, -2147483648, true},
|
||||
{"overflow", "TEST_INT32", "2147483648", 1000, 1000, true},
|
||||
{"not set", "TEST_INT32_NOTSET", "", 5000, 5000, false},
|
||||
{"invalid value", "TEST_INT32", "abc123", 2500, 2500, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Int32(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Int32() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue int64
|
||||
expected int64
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid positive int64", "TEST_INT64", "1000000000", 0, 1000000000, true},
|
||||
{"valid negative int64", "TEST_INT64", "-1000000000", 0, -1000000000, true},
|
||||
{"max int64", "TEST_INT64", "9223372036854775807", 0, 9223372036854775807, true},
|
||||
{"min int64", "TEST_INT64", "-9223372036854775808", 0, -9223372036854775808, true},
|
||||
{"overflow", "TEST_INT64", "9223372036854775808", 10000, 10000, true},
|
||||
{"not set", "TEST_INT64_NOTSET", "", 50000, 50000, false},
|
||||
{"invalid value", "TEST_INT64", "not_valid", 25000, 25000, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := Int64(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Int64() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
43
env/string_test.go
vendored
Normal file
43
env/string_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue string
|
||||
expected string
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid string", "TEST_STRING", "hello", "default", "hello", true},
|
||||
{"empty string", "TEST_STRING", "", "default", "", true},
|
||||
{"string with spaces", "TEST_STRING", "hello world", "default", "hello world", true},
|
||||
{"string with special chars", "TEST_STRING", "test@123!$%", "default", "test@123!$%", true},
|
||||
{"multiline string", "TEST_STRING", "line1\nline2\nline3", "default", "line1\nline2\nline3", true},
|
||||
{"unicode string", "TEST_STRING", "Hello 世界 🌍", "default", "Hello 世界 🌍", true},
|
||||
{"not set", "TEST_STRING_NOTSET", "", "default_value", "default_value", false},
|
||||
{"numeric string", "TEST_STRING", "12345", "default", "12345", true},
|
||||
{"boolean string", "TEST_STRING", "true", "default", "true", true},
|
||||
{"path string", "TEST_STRING", "/usr/local/bin", "default", "/usr/local/bin", true},
|
||||
{"url string", "TEST_STRING", "https://example.com", "default", "https://example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := String(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("String() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
81
env/uint.go
vendored
Normal file
81
env/uint.go
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Get an environment variable as a uint, specifying a default value if its
|
||||
// not set or can't be parsed properly into a uint
|
||||
func UInt(key string, defaultValue uint) uint {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseUint(val, 10, 0)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return uint(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as a uint8, specifying a default value if its
|
||||
// not set or can't be parsed properly into a uint8
|
||||
func UInt8(key string, defaultValue uint8) uint8 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseUint(val, 10, 8)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return uint8(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as a uint16, specifying a default value if its
|
||||
// not set or can't be parsed properly into a uint16
|
||||
func UInt16(key string, defaultValue uint16) uint16 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseUint(val, 10, 16)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return uint16(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as a uint32, specifying a default value if its
|
||||
// not set or can't be parsed properly into a uint32
|
||||
func UInt32(key string, defaultValue uint32) uint32 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseUint(val, 10, 32)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return uint32(intVal)
|
||||
}
|
||||
|
||||
// Get an environment variable as a uint64, specifying a default value if its
|
||||
// not set or can't be parsed properly into a uint64
|
||||
func UInt64(key string, defaultValue uint64) uint64 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseUint(val, 10, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
}
|
||||
171
env/uint_test.go
vendored
Normal file
171
env/uint_test.go
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue uint
|
||||
expected uint
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid uint", "TEST_UINT", "42", 0, 42, true},
|
||||
{"valid zero", "TEST_UINT", "0", 10, 0, true},
|
||||
{"large value", "TEST_UINT", "4294967295", 0, 4294967295, true},
|
||||
{"not set", "TEST_UINT_NOTSET", "", 100, 100, false},
|
||||
{"invalid value", "TEST_UINT", "not_a_number", 50, 50, true},
|
||||
{"negative value", "TEST_UINT", "-42", 75, 75, true},
|
||||
{"empty string", "TEST_UINT", "", 25, 25, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := UInt(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UInt() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUInt8(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue uint8
|
||||
expected uint8
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid uint8", "TEST_UINT8", "42", 0, 42, true},
|
||||
{"valid zero", "TEST_UINT8", "0", 10, 0, true},
|
||||
{"max uint8", "TEST_UINT8", "255", 0, 255, true},
|
||||
{"overflow", "TEST_UINT8", "256", 10, 10, true},
|
||||
{"not set", "TEST_UINT8_NOTSET", "", 50, 50, false},
|
||||
{"invalid value", "TEST_UINT8", "abc", 25, 25, true},
|
||||
{"negative value", "TEST_UINT8", "-1", 30, 30, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := UInt8(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UInt8() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue uint16
|
||||
expected uint16
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid uint16", "TEST_UINT16", "1000", 0, 1000, true},
|
||||
{"valid zero", "TEST_UINT16", "0", 100, 0, true},
|
||||
{"max uint16", "TEST_UINT16", "65535", 0, 65535, true},
|
||||
{"overflow", "TEST_UINT16", "65536", 100, 100, true},
|
||||
{"not set", "TEST_UINT16_NOTSET", "", 500, 500, false},
|
||||
{"invalid value", "TEST_UINT16", "invalid", 250, 250, true},
|
||||
{"negative value", "TEST_UINT16", "-100", 300, 300, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := UInt16(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UInt16() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUInt32(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue uint32
|
||||
expected uint32
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid uint32", "TEST_UINT32", "100000", 0, 100000, true},
|
||||
{"valid zero", "TEST_UINT32", "0", 1000, 0, true},
|
||||
{"max uint32", "TEST_UINT32", "4294967295", 0, 4294967295, true},
|
||||
{"overflow", "TEST_UINT32", "4294967296", 1000, 1000, true},
|
||||
{"not set", "TEST_UINT32_NOTSET", "", 5000, 5000, false},
|
||||
{"invalid value", "TEST_UINT32", "xyz", 2500, 2500, true},
|
||||
{"negative value", "TEST_UINT32", "-1000", 3000, 3000, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := UInt32(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UInt32() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
defaultValue uint64
|
||||
expected uint64
|
||||
shouldSet bool
|
||||
}{
|
||||
{"valid uint64", "TEST_UINT64", "1000000000", 0, 1000000000, true},
|
||||
{"valid zero", "TEST_UINT64", "0", 10000, 0, true},
|
||||
{"max uint64", "TEST_UINT64", "18446744073709551615", 0, 18446744073709551615, true},
|
||||
{"overflow", "TEST_UINT64", "18446744073709551616", 10000, 10000, true},
|
||||
{"not set", "TEST_UINT64_NOTSET", "", 50000, 50000, false},
|
||||
{"invalid value", "TEST_UINT64", "not_valid", 25000, 25000, true},
|
||||
{"negative value", "TEST_UINT64", "-5000", 30000, 30000, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.shouldSet {
|
||||
os.Setenv(tt.key, tt.value)
|
||||
defer os.Unsetenv(tt.key)
|
||||
}
|
||||
|
||||
result := UInt64(tt.key, tt.defaultValue)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UInt64() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
54
hwsauth/authenticate.go
Normal file
54
hwsauth/authenticate.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Check the cookies for token strings and attempt to authenticate them
|
||||
func (auth *Authenticator[T]) getAuthenticatedUser(
|
||||
tx *sql.Tx,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) (*authenticatedModel[T], error) {
|
||||
// Get token strings from cookies
|
||||
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||
if atStr == "" && rtStr == "" {
|
||||
return nil, errors.New("No token strings provided")
|
||||
}
|
||||
// Attempt to parse the access token
|
||||
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
|
||||
if err != nil {
|
||||
// Access token invalid, attempt to parse refresh token
|
||||
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.tokenGenerator.ValidateRefresh")
|
||||
}
|
||||
// Refresh token valid, attempt to get a new token pair
|
||||
model, err := auth.refreshAuthTokens(tx, w, r, rT)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.refreshAuthTokens")
|
||||
}
|
||||
// New token pair sent, return the authorized user
|
||||
authUser := authenticatedModel[T]{
|
||||
model: model,
|
||||
fresh: time.Now().Unix(),
|
||||
}
|
||||
return &authUser, nil
|
||||
}
|
||||
|
||||
// Access token valid
|
||||
model, err := auth.load(tx, aT.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.load")
|
||||
}
|
||||
authUser := authenticatedModel[T]{
|
||||
model: model,
|
||||
fresh: aT.Fresh,
|
||||
}
|
||||
return &authUser, nil
|
||||
}
|
||||
93
hwsauth/authenticator.go
Normal file
93
hwsauth/authenticator.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type Authenticator[T Model] struct {
|
||||
tokenGenerator *jwt.TokenGenerator
|
||||
load LoadFunc[T]
|
||||
conn *sql.DB
|
||||
ignoredPaths []string
|
||||
logger *zerolog.Logger
|
||||
server *hws.Server
|
||||
errorPage hws.ErrorPage
|
||||
SSL bool // Use SSL for JWT tokens. Default true
|
||||
TrustedHost string // TrustedHost to use for SSL verification
|
||||
SecretKey string // Secret key to use for JWT tokens
|
||||
AccessTokenExpiry int64 // Expiry time for Access tokens in minutes. Default 5
|
||||
RefreshTokenExpiry int64 // Expiry time for Refresh tokens in minutes. Default 1440 (1 day)
|
||||
TokenFreshTime int64 // Expiry time of token freshness. Default 5 minutes
|
||||
LandingPage string // Path of the desired landing page for logged in users
|
||||
}
|
||||
|
||||
// NewAuthenticator creates and returns a new Authenticator using the provided configuration.
|
||||
// All expiry times should be provided in minutes.
|
||||
// trustedHost and secretKey strings must be provided.
|
||||
func NewAuthenticator[T Model](
|
||||
load LoadFunc[T],
|
||||
server *hws.Server,
|
||||
conn *sql.DB,
|
||||
logger *zerolog.Logger,
|
||||
errorPage hws.ErrorPage,
|
||||
) (*Authenticator[T], error) {
|
||||
if load == nil {
|
||||
return nil, errors.New("No function to load model supplied")
|
||||
}
|
||||
if server == nil {
|
||||
return nil, errors.New("No hws.Server provided")
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, errors.New("No database connection supplied")
|
||||
}
|
||||
if logger == nil {
|
||||
return nil, errors.New("No logger provided")
|
||||
}
|
||||
if errorPage == nil {
|
||||
return nil, errors.New("No ErrorPage provided")
|
||||
}
|
||||
auth := Authenticator[T]{
|
||||
load: load,
|
||||
server: server,
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
errorPage: errorPage,
|
||||
AccessTokenExpiry: 5,
|
||||
RefreshTokenExpiry: 1440,
|
||||
TokenFreshTime: 5,
|
||||
SSL: true,
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// Initialise finishes the setup and prepares the Authenticator for use.
|
||||
// Any custom configuration must be set before Initialise is called
|
||||
func (auth *Authenticator[T]) Initialise() error {
|
||||
if auth.TrustedHost == "" {
|
||||
return errors.New("Trusted host must be provided")
|
||||
}
|
||||
if auth.SecretKey == "" {
|
||||
return errors.New("Secret key cannot be blank")
|
||||
}
|
||||
if auth.LandingPage == "" {
|
||||
return errors.New("No landing page specified")
|
||||
}
|
||||
tokenGen, err := jwt.CreateGenerator(
|
||||
auth.AccessTokenExpiry,
|
||||
auth.RefreshTokenExpiry,
|
||||
auth.TokenFreshTime,
|
||||
auth.TrustedHost,
|
||||
auth.SecretKey,
|
||||
auth.conn,
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.CreateGenerator")
|
||||
}
|
||||
auth.tokenGenerator = tokenGen
|
||||
return nil
|
||||
}
|
||||
19
hwsauth/go.mod
Normal file
19
hwsauth/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module git.haelnorr.com/h/golib/hwsauth
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/jwt v0.9.2
|
||||
git.haelnorr.com/h/golib/hws v0.1.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
)
|
||||
36
hwsauth/go.sum
Normal file
36
hwsauth/go.sum
Normal file
@@ -0,0 +1,36 @@
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
|
||||
git.haelnorr.com/h/golib/hws v0.1.0 h1:+0eNq1uGWrGfbS5AgHeGoGDjVfCWuaVu+1wBxgPqyOY=
|
||||
git.haelnorr.com/h/golib/hws v0.1.0/go.mod h1:b2pbkMaebzmck9TxqGBGzTJPEcB5TWcEHwFknLE7dqM=
|
||||
git.haelnorr.com/h/golib/jwt v0.9.2 h1:l1Ow7DPGACAU54CnMP/NlZjdc4nRD1wr3xZ8a7taRvU=
|
||||
git.haelnorr.com/h/golib/jwt v0.9.2/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
22
hwsauth/ignorepaths.go
Normal file
22
hwsauth/ignorepaths.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (auth *Authenticator[T]) IgnorePaths(paths ...string) error {
|
||||
for _, path := range paths {
|
||||
u, err := url.Parse(path)
|
||||
valid := err == nil &&
|
||||
u.Scheme == "" &&
|
||||
u.Host == "" &&
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
auth.ignoredPaths = paths
|
||||
return nil
|
||||
}
|
||||
22
hwsauth/login.go
Normal file
22
hwsauth/login.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (auth *Authenticator[T]) Login(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
model T,
|
||||
rememberMe bool,
|
||||
) error {
|
||||
|
||||
err := jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), true, rememberMe, auth.SSL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.SetTokenCookies")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
27
hwsauth/logout.go
Normal file
27
hwsauth/logout.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (auth *Authenticator[T]) Logout(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
|
||||
aT, rT, err := auth.getTokens(tx, r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auth.getTokens")
|
||||
}
|
||||
err = aT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
err = rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
return nil
|
||||
}
|
||||
42
hwsauth/middleware.go
Normal file
42
hwsauth/middleware.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (auth *Authenticator[T]) Authenticate() hws.Middleware {
|
||||
return auth.server.NewMiddleware(auth.authenticate())
|
||||
}
|
||||
|
||||
func (auth *Authenticator[T]) authenticate() hws.MiddlewareFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
||||
if slices.Contains(auth.ignoredPaths, r.URL.Path) {
|
||||
return r, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := auth.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, hws.NewError(http.StatusServiceUnavailable, "Unable to start transaction", err)
|
||||
}
|
||||
model, err := auth.getAuthenticatedUser(tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
auth.logger.Debug().
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Err(err).
|
||||
Msg("Failed to authenticate user")
|
||||
return r, nil
|
||||
}
|
||||
tx.Commit()
|
||||
authContext := setAuthenticatedModel(r.Context(), model)
|
||||
newReq := r.WithContext(authContext)
|
||||
return newReq, nil
|
||||
}
|
||||
}
|
||||
46
hwsauth/model.go
Normal file
46
hwsauth/model.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type authenticatedModel[T Model] struct {
|
||||
model T
|
||||
fresh int64
|
||||
}
|
||||
|
||||
func getNil[T Model]() T {
|
||||
var result T
|
||||
return result
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
ID() int
|
||||
}
|
||||
|
||||
type ContextLoader[T Model] func(ctx context.Context) T
|
||||
|
||||
type LoadFunc[T Model] func(tx *sql.Tx, id int) (T, error)
|
||||
|
||||
// Return a new context with the user added in
|
||||
func setAuthenticatedModel[T Model](ctx context.Context, m *authenticatedModel[T]) context.Context {
|
||||
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
||||
}
|
||||
|
||||
// Retrieve a user from the given context. Returns nil if not set
|
||||
func getAuthorizedModel[T Model](ctx context.Context) *authenticatedModel[T] {
|
||||
model, ok := ctx.Value("hwsauth context key authenticated-model").(*authenticatedModel[T])
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func (auth *Authenticator[T]) CurrentModel(ctx context.Context) T {
|
||||
model := getAuthorizedModel[T](ctx)
|
||||
if model == nil {
|
||||
return getNil[T]()
|
||||
}
|
||||
return model.model
|
||||
}
|
||||
43
hwsauth/protectpage.go
Normal file
43
hwsauth/protectpage.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Checks if the model is set in the context and shows 401 page if not logged in
|
||||
func (auth *Authenticator[T]) LoginReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model := getAuthorizedModel[T](r.Context())
|
||||
if model == nil {
|
||||
auth.errorPage(http.StatusUnauthorized, w, r)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Checks if the model is set in the context and redirects them to the landing page if
|
||||
// they are logged in
|
||||
func (auth *Authenticator[T]) LogoutReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model := getAuthorizedModel[T](r.Context())
|
||||
if model != nil {
|
||||
http.Redirect(w, r, auth.LandingPage, http.StatusFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (auth *Authenticator[T]) FreshReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
model := getAuthorizedModel[T](r.Context())
|
||||
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||
if !isFresh {
|
||||
w.WriteHeader(444)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
66
hwsauth/reauthenticate.go
Normal file
66
hwsauth/reauthenticate.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (auth *Authenticator[T]) RefreshAuthTokens(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
|
||||
aT, rT, err := auth.getTokens(tx, r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getTokens")
|
||||
}
|
||||
rememberMe := map[string]bool{
|
||||
"session": false,
|
||||
"exp": true,
|
||||
}[aT.TTL]
|
||||
// issue new tokens for the user
|
||||
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.SetTokenCookies")
|
||||
}
|
||||
err = revokeTokenPair(tx, aT, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "revokeTokenPair")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the tokens from the request
|
||||
func (auth *Authenticator[T]) getTokens(
|
||||
tx *sql.Tx,
|
||||
r *http.Request,
|
||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||
// get the existing tokens from the cookies
|
||||
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||
}
|
||||
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||
}
|
||||
return aT, rT, nil
|
||||
}
|
||||
|
||||
// Revoke the given token pair
|
||||
func revokeTokenPair(
|
||||
tx *sql.Tx,
|
||||
aT *jwt.AccessToken,
|
||||
rT *jwt.RefreshToken,
|
||||
) error {
|
||||
err := aT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "aT.Revoke")
|
||||
}
|
||||
err = rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
40
hwsauth/refreshtokens.go
Normal file
40
hwsauth/refreshtokens.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Attempt to use a valid refresh token to generate a new token pair
|
||||
func (auth *Authenticator[T]) refreshAuthTokens(
|
||||
tx *sql.Tx,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
rT *jwt.RefreshToken,
|
||||
) (T, error) {
|
||||
model, err := auth.load(tx, rT.SUB)
|
||||
if err != nil {
|
||||
return getNil[T](), errors.Wrap(err, "auth.load")
|
||||
}
|
||||
|
||||
rememberMe := map[string]bool{
|
||||
"session": false,
|
||||
"exp": true,
|
||||
}[rT.TTL]
|
||||
|
||||
// Set fresh to true because new tokens coming from refresh request
|
||||
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), false, rememberMe, auth.SSL)
|
||||
if err != nil {
|
||||
return getNil[T](), errors.Wrap(err, "jwt.SetTokenCookies")
|
||||
}
|
||||
// New tokens sent, revoke the old tokens
|
||||
err = rT.Revoke(tx)
|
||||
if err != nil {
|
||||
return getNil[T](), errors.Wrap(err, "rT.Revoke")
|
||||
}
|
||||
// Return the authorized user
|
||||
return model, nil
|
||||
}
|
||||
Reference in New Issue
Block a user