Compare commits
11 Commits
hlog/v0.9.
...
hlog/v0.9.
| Author | SHA1 | Date | |
|---|---|---|---|
| 557e9812e6 | |||
| f3312f7aef | |||
| 61d519399f | |||
| b13b783d7e | |||
| 14eec74683 | |||
| ade3fa0454 | |||
| 516be905a9 | |||
| 6e632267ea | |||
| 05aad5f11b | |||
| c4574e32c7 | |||
| c466cd3163 |
19
cookies/delete.go
Normal file
19
cookies/delete.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package cookies
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tell the browser to delete the cookie matching the name provided
|
||||||
|
// Path must match the original set cookie for it to delete
|
||||||
|
func DeleteCookie(w http.ResponseWriter, name string, path string) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: "",
|
||||||
|
Path: path,
|
||||||
|
Expires: time.Unix(0, 0), // Expire in the past
|
||||||
|
MaxAge: -1, // Immediately expire
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
3
cookies/go.mod
Normal file
3
cookies/go.mod
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
module git.haelnorr.com/h/golib/cookies
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
36
cookies/pagefrom.go
Normal file
36
cookies/pagefrom.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package cookies
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
|
||||||
|
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
|
||||||
|
pageFromCookie, err := r.Cookie("pagefrom")
|
||||||
|
if err != nil {
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
pageFrom := pageFromCookie.Value
|
||||||
|
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
|
||||||
|
return pageFrom
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the referer of the request, and if it matches the trustedHost, set
|
||||||
|
// the "pagefrom" cookie as the Path of the referer
|
||||||
|
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
|
||||||
|
referer := r.Referer()
|
||||||
|
parsedURL, err := url.Parse(referer)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pageFrom string
|
||||||
|
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
|
||||||
|
pageFrom = "/"
|
||||||
|
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
pageFrom = parsedURL.Path
|
||||||
|
}
|
||||||
|
SetCookie(w, "pagefrom", "/", pageFrom, 0)
|
||||||
|
}
|
||||||
23
cookies/set.go
Normal file
23
cookies/set.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package cookies
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set a cookie with the given name, path and value. maxAge directly relates
|
||||||
|
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
|
||||||
|
func SetCookie(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
name string,
|
||||||
|
path string,
|
||||||
|
value string,
|
||||||
|
maxAge int,
|
||||||
|
) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: path,
|
||||||
|
HttpOnly: true,
|
||||||
|
MaxAge: maxAge,
|
||||||
|
})
|
||||||
|
}
|
||||||
35
env/boolean.go
vendored
Normal file
35
env/boolean.go
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get an environment variable as a boolean, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a bool
|
||||||
|
func Bool(key string, defaultValue bool) bool {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
truthy := map[string]bool{
|
||||||
|
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
|
||||||
|
"enable": true, "enabled": true, "active": true, "affirmative": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
falsy := map[string]bool{
|
||||||
|
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
|
||||||
|
"disable": false, "disabled": false, "inactive": false, "negative": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := strings.TrimSpace(strings.ToLower(val))
|
||||||
|
|
||||||
|
if val, ok := truthy[normalized]; ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
if val, ok := falsy[normalized]; ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
91
env/boolean_test.go
vendored
Normal file
91
env/boolean_test.go
vendored
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBool(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue bool
|
||||||
|
expected bool
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
// Truthy values
|
||||||
|
{"true lowercase", "TEST_BOOL", "true", false, true, true},
|
||||||
|
{"true uppercase", "TEST_BOOL", "TRUE", false, true, true},
|
||||||
|
{"true mixed case", "TEST_BOOL", "TrUe", false, true, true},
|
||||||
|
{"t", "TEST_BOOL", "t", false, true, true},
|
||||||
|
{"T", "TEST_BOOL", "T", false, true, true},
|
||||||
|
{"yes", "TEST_BOOL", "yes", false, true, true},
|
||||||
|
{"YES", "TEST_BOOL", "YES", false, true, true},
|
||||||
|
{"y", "TEST_BOOL", "y", false, true, true},
|
||||||
|
{"Y", "TEST_BOOL", "Y", false, true, true},
|
||||||
|
{"on", "TEST_BOOL", "on", false, true, true},
|
||||||
|
{"ON", "TEST_BOOL", "ON", false, true, true},
|
||||||
|
{"1", "TEST_BOOL", "1", false, true, true},
|
||||||
|
{"enable", "TEST_BOOL", "enable", false, true, true},
|
||||||
|
{"ENABLE", "TEST_BOOL", "ENABLE", false, true, true},
|
||||||
|
{"enabled", "TEST_BOOL", "enabled", false, true, true},
|
||||||
|
{"ENABLED", "TEST_BOOL", "ENABLED", false, true, true},
|
||||||
|
{"active", "TEST_BOOL", "active", false, true, true},
|
||||||
|
{"ACTIVE", "TEST_BOOL", "ACTIVE", false, true, true},
|
||||||
|
{"affirmative", "TEST_BOOL", "affirmative", false, true, true},
|
||||||
|
{"AFFIRMATIVE", "TEST_BOOL", "AFFIRMATIVE", false, true, true},
|
||||||
|
|
||||||
|
// Falsy values
|
||||||
|
{"false lowercase", "TEST_BOOL", "false", true, false, true},
|
||||||
|
{"false uppercase", "TEST_BOOL", "FALSE", true, false, true},
|
||||||
|
{"false mixed case", "TEST_BOOL", "FaLsE", true, false, true},
|
||||||
|
{"f", "TEST_BOOL", "f", true, false, true},
|
||||||
|
{"F", "TEST_BOOL", "F", true, false, true},
|
||||||
|
{"no", "TEST_BOOL", "no", true, false, true},
|
||||||
|
{"NO", "TEST_BOOL", "NO", true, false, true},
|
||||||
|
{"n", "TEST_BOOL", "n", true, false, true},
|
||||||
|
{"N", "TEST_BOOL", "N", true, false, true},
|
||||||
|
{"off", "TEST_BOOL", "off", true, false, true},
|
||||||
|
{"OFF", "TEST_BOOL", "OFF", true, false, true},
|
||||||
|
{"0", "TEST_BOOL", "0", true, false, true},
|
||||||
|
{"disable", "TEST_BOOL", "disable", true, false, true},
|
||||||
|
{"DISABLE", "TEST_BOOL", "DISABLE", true, false, true},
|
||||||
|
{"disabled", "TEST_BOOL", "disabled", true, false, true},
|
||||||
|
{"DISABLED", "TEST_BOOL", "DISABLED", true, false, true},
|
||||||
|
{"inactive", "TEST_BOOL", "inactive", true, false, true},
|
||||||
|
{"INACTIVE", "TEST_BOOL", "INACTIVE", true, false, true},
|
||||||
|
{"negative", "TEST_BOOL", "negative", true, false, true},
|
||||||
|
{"NEGATIVE", "TEST_BOOL", "NEGATIVE", true, false, true},
|
||||||
|
|
||||||
|
// Whitespace handling
|
||||||
|
{"true with spaces", "TEST_BOOL", " true ", false, true, true},
|
||||||
|
{"false with spaces", "TEST_BOOL", " false ", true, false, true},
|
||||||
|
|
||||||
|
// Default values
|
||||||
|
{"not set default true", "TEST_BOOL_NOTSET", "", true, true, false},
|
||||||
|
{"not set default false", "TEST_BOOL_NOTSET", "", false, false, false},
|
||||||
|
|
||||||
|
// Invalid values should return default
|
||||||
|
{"invalid value default true", "TEST_BOOL", "invalid", true, true, true},
|
||||||
|
{"invalid value default false", "TEST_BOOL", "invalid", false, false, true},
|
||||||
|
{"empty string default true", "TEST_BOOL", "", true, true, true},
|
||||||
|
{"empty string default false", "TEST_BOOL", "", false, false, true},
|
||||||
|
{"random text default true", "TEST_BOOL", "maybe", true, true, true},
|
||||||
|
{"random text default false", "TEST_BOOL", "maybe", false, false, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Bool(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Bool() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
23
env/duration.go
vendored
Normal file
23
env/duration.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get an environment variable as a time.Duration, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly
|
||||||
|
func Duration(key string, defaultValue time.Duration) time.Duration {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return time.Duration(defaultValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.Atoi(val)
|
||||||
|
if err != nil {
|
||||||
|
return time.Duration(defaultValue)
|
||||||
|
}
|
||||||
|
return time.Duration(intVal)
|
||||||
|
|
||||||
|
}
|
||||||
42
env/duration_test.go
vendored
Normal file
42
env/duration_test.go
vendored
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue time.Duration
|
||||||
|
expected time.Duration
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive duration", "TEST_DURATION", "100", 0, 100 * time.Nanosecond, true},
|
||||||
|
{"valid zero", "TEST_DURATION", "0", 10 * time.Second, 0, true},
|
||||||
|
{"large value", "TEST_DURATION", "1000000000", 0, 1 * time.Second, true},
|
||||||
|
{"valid negative duration", "TEST_DURATION", "-100", 0, -100 * time.Nanosecond, true},
|
||||||
|
{"not set", "TEST_DURATION_NOTSET", "", 5 * time.Minute, 5 * time.Minute, false},
|
||||||
|
{"invalid value", "TEST_DURATION", "not_a_number", 30 * time.Second, 30 * time.Second, true},
|
||||||
|
{"empty string", "TEST_DURATION", "", 1 * time.Hour, 1 * time.Hour, true},
|
||||||
|
{"float value", "TEST_DURATION", "10.5", 2 * time.Second, 2 * time.Second, true},
|
||||||
|
{"very large value", "TEST_DURATION", "9223372036854775807", 0, 9223372036854775807 * time.Nanosecond, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Duration(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Duration() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
3
env/go.mod
vendored
Normal file
3
env/go.mod
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
module git.haelnorr.com/h/golib/env
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
82
env/int.go
vendored
Normal file
82
env/int.go
vendored
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get an environment variable as an int, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into an int
|
||||||
|
func Int(key string, defaultValue int) int {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.Atoi(val)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return intVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as an int8, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into an int8
|
||||||
|
func Int8(key string, defaultValue int8) int8 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(val, 10, 8)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return int8(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as an int16, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into an int16
|
||||||
|
func Int16(key string, defaultValue int16) int16 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(val, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return int16(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as an int32, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into an int32
|
||||||
|
func Int32(key string, defaultValue int32) int32 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(val, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return int32(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as an int64, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into an int64
|
||||||
|
func Int64(key string, defaultValue int64) int64 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return intVal
|
||||||
|
|
||||||
|
}
|
||||||
170
env/int_test.go
vendored
Normal file
170
env/int_test.go
vendored
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue int
|
||||||
|
expected int
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive int", "TEST_INT", "42", 0, 42, true},
|
||||||
|
{"valid negative int", "TEST_INT", "-42", 0, -42, true},
|
||||||
|
{"valid zero", "TEST_INT", "0", 10, 0, true},
|
||||||
|
{"not set", "TEST_INT_NOTSET", "", 100, 100, false},
|
||||||
|
{"invalid value", "TEST_INT", "not_a_number", 50, 50, true},
|
||||||
|
{"empty string", "TEST_INT", "", 75, 75, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Int(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Int() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt8(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue int8
|
||||||
|
expected int8
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive int8", "TEST_INT8", "42", 0, 42, true},
|
||||||
|
{"valid negative int8", "TEST_INT8", "-42", 0, -42, true},
|
||||||
|
{"max int8", "TEST_INT8", "127", 0, 127, true},
|
||||||
|
{"min int8", "TEST_INT8", "-128", 0, -128, true},
|
||||||
|
{"overflow", "TEST_INT8", "128", 10, 10, true},
|
||||||
|
{"not set", "TEST_INT8_NOTSET", "", 50, 50, false},
|
||||||
|
{"invalid value", "TEST_INT8", "not_a_number", 25, 25, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Int8(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Int8() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt16(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue int16
|
||||||
|
expected int16
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive int16", "TEST_INT16", "1000", 0, 1000, true},
|
||||||
|
{"valid negative int16", "TEST_INT16", "-1000", 0, -1000, true},
|
||||||
|
{"max int16", "TEST_INT16", "32767", 0, 32767, true},
|
||||||
|
{"min int16", "TEST_INT16", "-32768", 0, -32768, true},
|
||||||
|
{"overflow", "TEST_INT16", "32768", 100, 100, true},
|
||||||
|
{"not set", "TEST_INT16_NOTSET", "", 500, 500, false},
|
||||||
|
{"invalid value", "TEST_INT16", "invalid", 250, 250, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Int16(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Int16() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt32(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue int32
|
||||||
|
expected int32
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive int32", "TEST_INT32", "100000", 0, 100000, true},
|
||||||
|
{"valid negative int32", "TEST_INT32", "-100000", 0, -100000, true},
|
||||||
|
{"max int32", "TEST_INT32", "2147483647", 0, 2147483647, true},
|
||||||
|
{"min int32", "TEST_INT32", "-2147483648", 0, -2147483648, true},
|
||||||
|
{"overflow", "TEST_INT32", "2147483648", 1000, 1000, true},
|
||||||
|
{"not set", "TEST_INT32_NOTSET", "", 5000, 5000, false},
|
||||||
|
{"invalid value", "TEST_INT32", "abc123", 2500, 2500, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Int32(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Int32() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInt64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue int64
|
||||||
|
expected int64
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid positive int64", "TEST_INT64", "1000000000", 0, 1000000000, true},
|
||||||
|
{"valid negative int64", "TEST_INT64", "-1000000000", 0, -1000000000, true},
|
||||||
|
{"max int64", "TEST_INT64", "9223372036854775807", 0, 9223372036854775807, true},
|
||||||
|
{"min int64", "TEST_INT64", "-9223372036854775808", 0, -9223372036854775808, true},
|
||||||
|
{"overflow", "TEST_INT64", "9223372036854775808", 10000, 10000, true},
|
||||||
|
{"not set", "TEST_INT64_NOTSET", "", 50000, 50000, false},
|
||||||
|
{"invalid value", "TEST_INT64", "not_valid", 25000, 25000, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Int64(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Int64() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
14
env/string.go
vendored
Normal file
14
env/string.go
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get an environment variable, specifying a default value if its not set
|
||||||
|
func String(key string, defaultValue string) string {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
43
env/string_test.go
vendored
Normal file
43
env/string_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue string
|
||||||
|
expected string
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid string", "TEST_STRING", "hello", "default", "hello", true},
|
||||||
|
{"empty string", "TEST_STRING", "", "default", "", true},
|
||||||
|
{"string with spaces", "TEST_STRING", "hello world", "default", "hello world", true},
|
||||||
|
{"string with special chars", "TEST_STRING", "test@123!$%", "default", "test@123!$%", true},
|
||||||
|
{"multiline string", "TEST_STRING", "line1\nline2\nline3", "default", "line1\nline2\nline3", true},
|
||||||
|
{"unicode string", "TEST_STRING", "Hello 世界 🌍", "default", "Hello 世界 🌍", true},
|
||||||
|
{"not set", "TEST_STRING_NOTSET", "", "default_value", "default_value", false},
|
||||||
|
{"numeric string", "TEST_STRING", "12345", "default", "12345", true},
|
||||||
|
{"boolean string", "TEST_STRING", "true", "default", "true", true},
|
||||||
|
{"path string", "TEST_STRING", "/usr/local/bin", "default", "/usr/local/bin", true},
|
||||||
|
{"url string", "TEST_STRING", "https://example.com", "default", "https://example.com", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := String(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("String() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
81
env/uint.go
vendored
Normal file
81
env/uint.go
vendored
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get an environment variable as a uint, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a uint
|
||||||
|
func UInt(key string, defaultValue uint) uint {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseUint(val, 10, 0)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return uint(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as a uint8, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a uint8
|
||||||
|
func UInt8(key string, defaultValue uint8) uint8 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseUint(val, 10, 8)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return uint8(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as a uint16, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a uint16
|
||||||
|
func UInt16(key string, defaultValue uint16) uint16 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseUint(val, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return uint16(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as a uint32, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a uint32
|
||||||
|
func UInt32(key string, defaultValue uint32) uint32 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseUint(val, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return uint32(intVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get an environment variable as a uint64, specifying a default value if its
|
||||||
|
// not set or can't be parsed properly into a uint64
|
||||||
|
func UInt64(key string, defaultValue uint64) uint64 {
|
||||||
|
val, exists := os.LookupEnv(key)
|
||||||
|
if !exists {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseUint(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return intVal
|
||||||
|
}
|
||||||
171
env/uint_test.go
vendored
Normal file
171
env/uint_test.go
vendored
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUInt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue uint
|
||||||
|
expected uint
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid uint", "TEST_UINT", "42", 0, 42, true},
|
||||||
|
{"valid zero", "TEST_UINT", "0", 10, 0, true},
|
||||||
|
{"large value", "TEST_UINT", "4294967295", 0, 4294967295, true},
|
||||||
|
{"not set", "TEST_UINT_NOTSET", "", 100, 100, false},
|
||||||
|
{"invalid value", "TEST_UINT", "not_a_number", 50, 50, true},
|
||||||
|
{"negative value", "TEST_UINT", "-42", 75, 75, true},
|
||||||
|
{"empty string", "TEST_UINT", "", 25, 25, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := UInt(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UInt() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUInt8(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue uint8
|
||||||
|
expected uint8
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid uint8", "TEST_UINT8", "42", 0, 42, true},
|
||||||
|
{"valid zero", "TEST_UINT8", "0", 10, 0, true},
|
||||||
|
{"max uint8", "TEST_UINT8", "255", 0, 255, true},
|
||||||
|
{"overflow", "TEST_UINT8", "256", 10, 10, true},
|
||||||
|
{"not set", "TEST_UINT8_NOTSET", "", 50, 50, false},
|
||||||
|
{"invalid value", "TEST_UINT8", "abc", 25, 25, true},
|
||||||
|
{"negative value", "TEST_UINT8", "-1", 30, 30, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := UInt8(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UInt8() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUInt16(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue uint16
|
||||||
|
expected uint16
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid uint16", "TEST_UINT16", "1000", 0, 1000, true},
|
||||||
|
{"valid zero", "TEST_UINT16", "0", 100, 0, true},
|
||||||
|
{"max uint16", "TEST_UINT16", "65535", 0, 65535, true},
|
||||||
|
{"overflow", "TEST_UINT16", "65536", 100, 100, true},
|
||||||
|
{"not set", "TEST_UINT16_NOTSET", "", 500, 500, false},
|
||||||
|
{"invalid value", "TEST_UINT16", "invalid", 250, 250, true},
|
||||||
|
{"negative value", "TEST_UINT16", "-100", 300, 300, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := UInt16(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UInt16() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUInt32(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue uint32
|
||||||
|
expected uint32
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid uint32", "TEST_UINT32", "100000", 0, 100000, true},
|
||||||
|
{"valid zero", "TEST_UINT32", "0", 1000, 0, true},
|
||||||
|
{"max uint32", "TEST_UINT32", "4294967295", 0, 4294967295, true},
|
||||||
|
{"overflow", "TEST_UINT32", "4294967296", 1000, 1000, true},
|
||||||
|
{"not set", "TEST_UINT32_NOTSET", "", 5000, 5000, false},
|
||||||
|
{"invalid value", "TEST_UINT32", "xyz", 2500, 2500, true},
|
||||||
|
{"negative value", "TEST_UINT32", "-1000", 3000, 3000, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := UInt32(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UInt32() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUInt64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
defaultValue uint64
|
||||||
|
expected uint64
|
||||||
|
shouldSet bool
|
||||||
|
}{
|
||||||
|
{"valid uint64", "TEST_UINT64", "1000000000", 0, 1000000000, true},
|
||||||
|
{"valid zero", "TEST_UINT64", "0", 10000, 0, true},
|
||||||
|
{"max uint64", "TEST_UINT64", "18446744073709551615", 0, 18446744073709551615, true},
|
||||||
|
{"overflow", "TEST_UINT64", "18446744073709551616", 10000, 10000, true},
|
||||||
|
{"not set", "TEST_UINT64_NOTSET", "", 50000, 50000, false},
|
||||||
|
{"invalid value", "TEST_UINT64", "not_valid", 25000, 25000, true},
|
||||||
|
{"negative value", "TEST_UINT64", "-5000", 30000, 30000, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldSet {
|
||||||
|
os.Setenv(tt.key, tt.value)
|
||||||
|
defer os.Unsetenv(tt.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := UInt64(tt.key, tt.defaultValue)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UInt64() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
package hlog
|
package hlog
|
||||||
|
|
||||||
import "github.com/rs/zerolog"
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
type Level = zerolog.Level
|
type Level = zerolog.Level
|
||||||
|
|
||||||
// Takes a log level as string and converts it to a Level interface.
|
// Takes a log level as string and converts it to a Level interface.
|
||||||
// If the string is not a valid input it will return InfoLevel
|
// If the string is not a valid input it will return InfoLevel
|
||||||
func LogLevel(level string) Level {
|
// Valid levels: trace, debug, info, warn, error, fatal, panic
|
||||||
|
func LogLevel(level string) (Level, error) {
|
||||||
levels := map[string]zerolog.Level{
|
levels := map[string]zerolog.Level{
|
||||||
"trace": zerolog.TraceLevel,
|
"trace": zerolog.TraceLevel,
|
||||||
"debug": zerolog.DebugLevel,
|
"debug": zerolog.DebugLevel,
|
||||||
@@ -18,7 +22,7 @@ func LogLevel(level string) Level {
|
|||||||
}
|
}
|
||||||
logLevel, valid := levels[level]
|
logLevel, valid := levels[level]
|
||||||
if !valid {
|
if !valid {
|
||||||
return zerolog.InfoLevel
|
return 0, errors.New("Invalid log level specified.")
|
||||||
}
|
}
|
||||||
return logLevel
|
return logLevel, nil
|
||||||
}
|
}
|
||||||
|
|||||||
39
hws/errors.go
Normal file
39
hws/errors.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type HWSError struct {
|
||||||
|
statusCode int // HTTP Status code
|
||||||
|
message string // Error message
|
||||||
|
error error // Error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorPage func(statusCode int, w http.ResponseWriter, r *http.Request) error
|
||||||
|
|
||||||
|
func NewError(statusCode int, msg string, err error) *HWSError {
|
||||||
|
return &HWSError{
|
||||||
|
statusCode: statusCode,
|
||||||
|
message: msg,
|
||||||
|
error: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) AddErrorPage(page ErrorPage) {
|
||||||
|
server.errorPage = page
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error *HWSError) {
|
||||||
|
w.WriteHeader(error.statusCode)
|
||||||
|
server.logger.logger.Error().Err(error.error).Msg(error.message)
|
||||||
|
if server.errorPage != nil {
|
||||||
|
err := server.errorPage(error.statusCode, w, r)
|
||||||
|
if err != nil {
|
||||||
|
server.logger.logger.Error().Err(err).Msg("Failed to render error page")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) ThrowWarn(w http.ResponseWriter, error *HWSError) {
|
||||||
|
w.WriteHeader(error.statusCode)
|
||||||
|
server.logger.logger.Warn().Err(error.error).Msg(error.message)
|
||||||
|
}
|
||||||
14
hws/go.mod
Normal file
14
hws/go.mod
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
module git.haelnorr.com/h/golib/hws
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/rs/zerolog v1.34.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
golang.org/x/sys v0.12.0 // indirect
|
||||||
|
)
|
||||||
16
hws/go.sum
Normal file
16
hws/go.sum
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
31
hws/gzip.go
Normal file
31
hws/gzip.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addgzip(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
|
gz := gzip.NewWriter(w)
|
||||||
|
defer gz.Close()
|
||||||
|
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
||||||
|
next.ServeHTTP(gzw, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type gzipResponseWriter struct {
|
||||||
|
io.Writer
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w gzipResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
return w.Writer.Write(b)
|
||||||
|
}
|
||||||
44
hws/logger.go
Normal file
44
hws/logger.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type logger struct {
|
||||||
|
logger *zerolog.Logger
|
||||||
|
ignoredPaths []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server.AddLogger adds a logger to the server to use for request logging.
|
||||||
|
func (server *Server) AddLogger(zlogger *zerolog.Logger) error {
|
||||||
|
if zlogger == nil {
|
||||||
|
return errors.New("Unable to add logger, no logger provided")
|
||||||
|
}
|
||||||
|
server.logger = &logger{
|
||||||
|
logger: zlogger,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||||
|
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||||
|
// Useful for ignoring requests to CSS files or favicons
|
||||||
|
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||||
|
for _, path := range paths {
|
||||||
|
u, err := url.Parse(path)
|
||||||
|
valid := err == nil &&
|
||||||
|
u.Scheme == "" &&
|
||||||
|
u.Host == "" &&
|
||||||
|
u.RawQuery == "" &&
|
||||||
|
u.Fragment == ""
|
||||||
|
if !valid {
|
||||||
|
return fmt.Errorf("Invalid path: '%s'", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
server.logger.ignoredPaths = paths
|
||||||
|
return nil
|
||||||
|
}
|
||||||
51
hws/middleware.go
Normal file
51
hws/middleware.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Middleware func(h http.Handler) http.Handler
|
||||||
|
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||||
|
|
||||||
|
// Server.AddMiddleware registers all the middleware.
|
||||||
|
// Middleware will be run in the order that they are provided.
|
||||||
|
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||||
|
if !server.routes {
|
||||||
|
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RUN LOGGING MIDDLEWARE FIRST
|
||||||
|
server.server.Handler = logging(server.server.Handler, server.logger)
|
||||||
|
|
||||||
|
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||||
|
for i := len(middleware); i > 0; i-- {
|
||||||
|
server.server.Handler = middleware[i-1](server.server.Handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RUN GZIP
|
||||||
|
if server.gzip {
|
||||||
|
server.server.Handler = addgzip(server.server.Handler)
|
||||||
|
}
|
||||||
|
// RUN TIMER MIDDLEWARE LAST
|
||||||
|
server.server.Handler = startTimer(server.server.Handler)
|
||||||
|
|
||||||
|
server.middleware = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) NewMiddleware(
|
||||||
|
middlewareFunc MiddlewareFunc,
|
||||||
|
) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
newReq, herr := middlewareFunc(w, r)
|
||||||
|
if herr != nil {
|
||||||
|
server.ThrowError(w, r, herr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, newReq)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
38
hws/middleware_logging.go
Normal file
38
hws/middleware_logging.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Middleware to add logs to console with details of the request
|
||||||
|
func logging(next http.Handler, logger *logger) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if logger == nil {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if slices.Contains(logger.ignoredPaths, r.URL.Path) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
start, err := getStartTime(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
logger.logger.Error().Err(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wrapped := &wrappedWriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
}
|
||||||
|
next.ServeHTTP(wrapped, r)
|
||||||
|
logger.logger.Info().
|
||||||
|
Int("status", wrapped.statusCode).
|
||||||
|
Str("method", r.Method).
|
||||||
|
Str("resource", r.URL.Path).
|
||||||
|
Dur("time_elapsed", time.Since(start)).
|
||||||
|
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
|
||||||
|
Msg("Served")
|
||||||
|
})
|
||||||
|
}
|
||||||
33
hws/middleware_timer.go
Normal file
33
hws/middleware_timer.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func startTimer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
ctx := setStart(r.Context(), start)
|
||||||
|
newReq := r.WithContext(ctx)
|
||||||
|
next.ServeHTTP(w, newReq)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the start time of the request
|
||||||
|
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||||
|
return context.WithValue(ctx, "hws context key request-timer", time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the start time of the request
|
||||||
|
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||||
|
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
||||||
|
if !ok {
|
||||||
|
return time.Time{}, errors.New("Failed to get start time of request")
|
||||||
|
}
|
||||||
|
return start, nil
|
||||||
|
}
|
||||||
15
hws/responsewriter.go
Normal file
15
hws/responsewriter.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Wraps the http.ResponseWriter, adding a statusCode field
|
||||||
|
type wrappedWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extends WriteHeader to the ResponseWriter to add the status code
|
||||||
|
func (w *wrappedWriter) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
w.statusCode = statusCode
|
||||||
|
}
|
||||||
62
hws/routes.go
Normal file
62
hws/routes.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
Path string // Absolute path to the requested resource
|
||||||
|
Method Method // HTTP Method
|
||||||
|
Handler http.Handler // Handler to use for the request
|
||||||
|
}
|
||||||
|
|
||||||
|
type Method string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MethodGET Method = "GET"
|
||||||
|
MethodPOST Method = "POST"
|
||||||
|
MethodPUT Method = "PUT"
|
||||||
|
MethodHEAD Method = "HEAD"
|
||||||
|
MethodDELETE Method = "DELETE"
|
||||||
|
MethodCONNECT Method = "CONNECT"
|
||||||
|
MethodOPTIONS Method = "OPTIONS"
|
||||||
|
MethodTRACE Method = "TRACE"
|
||||||
|
MethodPATCH Method = "PATCH"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server.AddRoutes registers the page handlers for the server.
|
||||||
|
// At least one route must be provided.
|
||||||
|
func (server *Server) AddRoutes(routes ...Route) error {
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return errors.New("No routes provided")
|
||||||
|
}
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
|
||||||
|
for _, route := range routes {
|
||||||
|
if !validMethod(route.Method) {
|
||||||
|
return fmt.Errorf("Invalid method %s for path %s", route.Method, route.Path)
|
||||||
|
}
|
||||||
|
if route.Handler == nil {
|
||||||
|
return fmt.Errorf("No handler provided for %s %s", route.Method, route.Path)
|
||||||
|
}
|
||||||
|
pattern := fmt.Sprintf("%s %s", route.Method, route.Path)
|
||||||
|
mux.Handle(pattern, route.Handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.server.Handler = mux
|
||||||
|
server.routes = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validMethod(m Method) bool {
|
||||||
|
switch m {
|
||||||
|
case MethodGET, MethodPOST, MethodPUT, MethodHEAD,
|
||||||
|
MethodDELETE, MethodCONNECT, MethodOPTIONS, MethodTRACE, MethodPATCH:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
52
hws/safefileserver.go
Normal file
52
hws/safefileserver.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Wrapper for default FileSystem
|
||||||
|
type justFilesFilesystem struct {
|
||||||
|
fs http.FileSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrapper for default File
|
||||||
|
type neuteredReaddirFile struct {
|
||||||
|
http.File
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modifies the behavior of FileSystem.Open to return the neutered version of File
|
||||||
|
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
|
||||||
|
f, err := fs.fs.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the requested path is a directory
|
||||||
|
// and explicitly return an error to trigger a 404
|
||||||
|
fileInfo, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if fileInfo.IsDir() {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
return neuteredReaddirFile{f}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overrides the Readdir method of File to always return nil
|
||||||
|
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SafeFileServer(fileSystem *http.FileSystem) (http.Handler, error) {
|
||||||
|
if fileSystem == nil {
|
||||||
|
return nil, errors.New("No file system provided")
|
||||||
|
}
|
||||||
|
nfs := justFilesFilesystem{*fileSystem}
|
||||||
|
fs := http.FileServer(nfs)
|
||||||
|
return fs, nil
|
||||||
|
}
|
||||||
84
hws/server.go
Normal file
84
hws/server.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package hws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
server *http.Server
|
||||||
|
logger *logger
|
||||||
|
routes bool
|
||||||
|
middleware bool
|
||||||
|
gzip bool
|
||||||
|
errorPage ErrorPage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer returns a new hws.Server with the specified parameters.
|
||||||
|
// The timeout options are specified in seconds
|
||||||
|
func NewServer(
|
||||||
|
host string,
|
||||||
|
port string,
|
||||||
|
readHeaderTimeout time.Duration,
|
||||||
|
writeTimeout time.Duration,
|
||||||
|
idleTimeout time.Duration,
|
||||||
|
gzip bool,
|
||||||
|
) (*Server, error) {
|
||||||
|
// TODO: test that host and port are valid values
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: net.JoinHostPort(host, port),
|
||||||
|
ReadHeaderTimeout: readHeaderTimeout * time.Second,
|
||||||
|
WriteTimeout: writeTimeout * time.Second,
|
||||||
|
IdleTimeout: idleTimeout * time.Second,
|
||||||
|
}
|
||||||
|
server := &Server{
|
||||||
|
server: httpServer,
|
||||||
|
routes: false,
|
||||||
|
gzip: gzip,
|
||||||
|
}
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) Start() error {
|
||||||
|
if !server.routes {
|
||||||
|
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||||
|
}
|
||||||
|
if !server.middleware {
|
||||||
|
err := server.AddMiddleware()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "server.AddMiddleware")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if server.logger == nil {
|
||||||
|
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
||||||
|
} else {
|
||||||
|
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
||||||
|
}
|
||||||
|
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
if server.logger == nil {
|
||||||
|
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) Shutdown(ctx context.Context) {
|
||||||
|
if err := server.server.Shutdown(ctx); err != nil {
|
||||||
|
if server.logger == nil {
|
||||||
|
fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
54
hwsauth/authenticate.go
Normal file
54
hwsauth/authenticate.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the cookies for token strings and attempt to authenticate them
|
||||||
|
func (auth *Authenticator[T]) getAuthenticatedUser(
|
||||||
|
tx *sql.Tx,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) (*authenticatedModel[T], error) {
|
||||||
|
// Get token strings from cookies
|
||||||
|
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||||
|
if atStr == "" && rtStr == "" {
|
||||||
|
return nil, errors.New("No token strings provided")
|
||||||
|
}
|
||||||
|
// Attempt to parse the access token
|
||||||
|
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
|
||||||
|
if err != nil {
|
||||||
|
// Access token invalid, attempt to parse refresh token
|
||||||
|
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "auth.tokenGenerator.ValidateRefresh")
|
||||||
|
}
|
||||||
|
// Refresh token valid, attempt to get a new token pair
|
||||||
|
model, err := auth.refreshAuthTokens(tx, w, r, rT)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "auth.refreshAuthTokens")
|
||||||
|
}
|
||||||
|
// New token pair sent, return the authorized user
|
||||||
|
authUser := authenticatedModel[T]{
|
||||||
|
model: model,
|
||||||
|
fresh: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
return &authUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access token valid
|
||||||
|
model, err := auth.load(tx, aT.SUB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "auth.load")
|
||||||
|
}
|
||||||
|
authUser := authenticatedModel[T]{
|
||||||
|
model: model,
|
||||||
|
fresh: aT.Fresh,
|
||||||
|
}
|
||||||
|
return &authUser, nil
|
||||||
|
}
|
||||||
93
hwsauth/authenticator.go
Normal file
93
hwsauth/authenticator.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Authenticator[T Model] struct {
|
||||||
|
tokenGenerator *jwt.TokenGenerator
|
||||||
|
load LoadFunc[T]
|
||||||
|
conn *sql.DB
|
||||||
|
ignoredPaths []string
|
||||||
|
logger *zerolog.Logger
|
||||||
|
server *hws.Server
|
||||||
|
errorPage hws.ErrorPage
|
||||||
|
SSL bool // Use SSL for JWT tokens. Default true
|
||||||
|
TrustedHost string // TrustedHost to use for SSL verification
|
||||||
|
SecretKey string // Secret key to use for JWT tokens
|
||||||
|
AccessTokenExpiry int64 // Expiry time for Access tokens in minutes. Default 5
|
||||||
|
RefreshTokenExpiry int64 // Expiry time for Refresh tokens in minutes. Default 1440 (1 day)
|
||||||
|
TokenFreshTime int64 // Expiry time of token freshness. Default 5 minutes
|
||||||
|
LandingPage string // Path of the desired landing page for logged in users
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthenticator creates and returns a new Authenticator using the provided configuration.
|
||||||
|
// All expiry times should be provided in minutes.
|
||||||
|
// trustedHost and secretKey strings must be provided.
|
||||||
|
func NewAuthenticator[T Model](
|
||||||
|
load LoadFunc[T],
|
||||||
|
server *hws.Server,
|
||||||
|
conn *sql.DB,
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
errorPage hws.ErrorPage,
|
||||||
|
) (*Authenticator[T], error) {
|
||||||
|
if load == nil {
|
||||||
|
return nil, errors.New("No function to load model supplied")
|
||||||
|
}
|
||||||
|
if server == nil {
|
||||||
|
return nil, errors.New("No hws.Server provided")
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return nil, errors.New("No database connection supplied")
|
||||||
|
}
|
||||||
|
if logger == nil {
|
||||||
|
return nil, errors.New("No logger provided")
|
||||||
|
}
|
||||||
|
if errorPage == nil {
|
||||||
|
return nil, errors.New("No ErrorPage provided")
|
||||||
|
}
|
||||||
|
auth := Authenticator[T]{
|
||||||
|
load: load,
|
||||||
|
server: server,
|
||||||
|
conn: conn,
|
||||||
|
logger: logger,
|
||||||
|
errorPage: errorPage,
|
||||||
|
AccessTokenExpiry: 5,
|
||||||
|
RefreshTokenExpiry: 1440,
|
||||||
|
TokenFreshTime: 5,
|
||||||
|
SSL: true,
|
||||||
|
}
|
||||||
|
return &auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialise finishes the setup and prepares the Authenticator for use.
|
||||||
|
// Any custom configuration must be set before Initialise is called
|
||||||
|
func (auth *Authenticator[T]) Initialise() error {
|
||||||
|
if auth.TrustedHost == "" {
|
||||||
|
return errors.New("Trusted host must be provided")
|
||||||
|
}
|
||||||
|
if auth.SecretKey == "" {
|
||||||
|
return errors.New("Secret key cannot be blank")
|
||||||
|
}
|
||||||
|
if auth.LandingPage == "" {
|
||||||
|
return errors.New("No landing page specified")
|
||||||
|
}
|
||||||
|
tokenGen, err := jwt.CreateGenerator(
|
||||||
|
auth.AccessTokenExpiry,
|
||||||
|
auth.RefreshTokenExpiry,
|
||||||
|
auth.TokenFreshTime,
|
||||||
|
auth.TrustedHost,
|
||||||
|
auth.SecretKey,
|
||||||
|
auth.conn,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.CreateGenerator")
|
||||||
|
}
|
||||||
|
auth.tokenGenerator = tokenGen
|
||||||
|
return nil
|
||||||
|
}
|
||||||
19
hwsauth/go.mod
Normal file
19
hwsauth/go.mod
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
module git.haelnorr.com/h/golib/hwsauth
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require (
|
||||||
|
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/rs/zerolog v1.34.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
golang.org/x/sys v0.12.0 // indirect
|
||||||
|
)
|
||||||
36
hwsauth/go.sum
Normal file
36
hwsauth/go.sum
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
|
||||||
|
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0 h1:+0eNq1uGWrGfbS5AgHeGoGDjVfCWuaVu+1wBxgPqyOY=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0/go.mod h1:b2pbkMaebzmck9TxqGBGzTJPEcB5TWcEHwFknLE7dqM=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2 h1:l1Ow7DPGACAU54CnMP/NlZjdc4nRD1wr3xZ8a7taRvU=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
22
hwsauth/ignorepaths.go
Normal file
22
hwsauth/ignorepaths.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) IgnorePaths(paths ...string) error {
|
||||||
|
for _, path := range paths {
|
||||||
|
u, err := url.Parse(path)
|
||||||
|
valid := err == nil &&
|
||||||
|
u.Scheme == "" &&
|
||||||
|
u.Host == "" &&
|
||||||
|
u.RawQuery == "" &&
|
||||||
|
u.Fragment == ""
|
||||||
|
if !valid {
|
||||||
|
return fmt.Errorf("Invalid path: '%s'", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auth.ignoredPaths = paths
|
||||||
|
return nil
|
||||||
|
}
|
||||||
22
hwsauth/login.go
Normal file
22
hwsauth/login.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) Login(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
model T,
|
||||||
|
rememberMe bool,
|
||||||
|
) error {
|
||||||
|
|
||||||
|
err := jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), true, rememberMe, auth.SSL)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.SetTokenCookies")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
27
hwsauth/logout.go
Normal file
27
hwsauth/logout.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) Logout(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
aT, rT, err := auth.getTokens(tx, r)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "auth.getTokens")
|
||||||
|
}
|
||||||
|
err = aT.Revoke(tx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "aT.Revoke")
|
||||||
|
}
|
||||||
|
err = rT.Revoke(tx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "rT.Revoke")
|
||||||
|
}
|
||||||
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
hwsauth/middleware.go
Normal file
42
hwsauth/middleware.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) Authenticate() hws.Middleware {
|
||||||
|
return auth.server.NewMiddleware(auth.authenticate())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) authenticate() hws.MiddlewareFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
||||||
|
if slices.Contains(auth.ignoredPaths, r.URL.Path) {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := auth.conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, hws.NewError(http.StatusServiceUnavailable, "Unable to start transaction", err)
|
||||||
|
}
|
||||||
|
model, err := auth.getAuthenticatedUser(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
auth.logger.Debug().
|
||||||
|
Str("remote_addr", r.RemoteAddr).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to authenticate user")
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
authContext := setAuthenticatedModel(r.Context(), model)
|
||||||
|
newReq := r.WithContext(authContext)
|
||||||
|
return newReq, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
46
hwsauth/model.go
Normal file
46
hwsauth/model.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authenticatedModel[T Model] struct {
|
||||||
|
model T
|
||||||
|
fresh int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNil[T Model]() T {
|
||||||
|
var result T
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model interface {
|
||||||
|
ID() int
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContextLoader[T Model] func(ctx context.Context) T
|
||||||
|
|
||||||
|
type LoadFunc[T Model] func(tx *sql.Tx, id int) (T, error)
|
||||||
|
|
||||||
|
// Return a new context with the user added in
|
||||||
|
func setAuthenticatedModel[T Model](ctx context.Context, m *authenticatedModel[T]) context.Context {
|
||||||
|
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve a user from the given context. Returns nil if not set
|
||||||
|
func getAuthorizedModel[T Model](ctx context.Context) *authenticatedModel[T] {
|
||||||
|
model, ok := ctx.Value("hwsauth context key authenticated-model").(*authenticatedModel[T])
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) CurrentModel(ctx context.Context) T {
|
||||||
|
model := getAuthorizedModel[T](ctx)
|
||||||
|
if model == nil {
|
||||||
|
return getNil[T]()
|
||||||
|
}
|
||||||
|
return model.model
|
||||||
|
}
|
||||||
43
hwsauth/protectpage.go
Normal file
43
hwsauth/protectpage.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Checks if the model is set in the context and shows 401 page if not logged in
|
||||||
|
func (auth *Authenticator[T]) LoginReq(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
model := getAuthorizedModel[T](r.Context())
|
||||||
|
if model == nil {
|
||||||
|
auth.errorPage(http.StatusUnauthorized, w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if the model is set in the context and redirects them to the landing page if
|
||||||
|
// they are logged in
|
||||||
|
func (auth *Authenticator[T]) LogoutReq(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
model := getAuthorizedModel[T](r.Context())
|
||||||
|
if model != nil {
|
||||||
|
http.Redirect(w, r, auth.LandingPage, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) FreshReq(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
model := getAuthorizedModel[T](r.Context())
|
||||||
|
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
|
||||||
|
if !isFresh {
|
||||||
|
w.WriteHeader(444)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
66
hwsauth/reauthenticate.go
Normal file
66
hwsauth/reauthenticate.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authenticator[T]) RefreshAuthTokens(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
aT, rT, err := auth.getTokens(tx, r)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "getTokens")
|
||||||
|
}
|
||||||
|
rememberMe := map[string]bool{
|
||||||
|
"session": false,
|
||||||
|
"exp": true,
|
||||||
|
}[aT.TTL]
|
||||||
|
// issue new tokens for the user
|
||||||
|
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.SetTokenCookies")
|
||||||
|
}
|
||||||
|
err = revokeTokenPair(tx, aT, rT)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "revokeTokenPair")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the tokens from the request
|
||||||
|
func (auth *Authenticator[T]) getTokens(
|
||||||
|
tx *sql.Tx,
|
||||||
|
r *http.Request,
|
||||||
|
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||||
|
// get the existing tokens from the cookies
|
||||||
|
atStr, rtStr := jwt.GetTokenCookies(r)
|
||||||
|
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
|
||||||
|
}
|
||||||
|
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
|
||||||
|
}
|
||||||
|
return aT, rT, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the given token pair
|
||||||
|
func revokeTokenPair(
|
||||||
|
tx *sql.Tx,
|
||||||
|
aT *jwt.AccessToken,
|
||||||
|
rT *jwt.RefreshToken,
|
||||||
|
) error {
|
||||||
|
err := aT.Revoke(tx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "aT.Revoke")
|
||||||
|
}
|
||||||
|
err = rT.Revoke(tx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "rT.Revoke")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
40
hwsauth/refreshtokens.go
Normal file
40
hwsauth/refreshtokens.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package hwsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Attempt to use a valid refresh token to generate a new token pair
|
||||||
|
func (auth *Authenticator[T]) refreshAuthTokens(
|
||||||
|
tx *sql.Tx,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
rT *jwt.RefreshToken,
|
||||||
|
) (T, error) {
|
||||||
|
model, err := auth.load(tx, rT.SUB)
|
||||||
|
if err != nil {
|
||||||
|
return getNil[T](), errors.Wrap(err, "auth.load")
|
||||||
|
}
|
||||||
|
|
||||||
|
rememberMe := map[string]bool{
|
||||||
|
"session": false,
|
||||||
|
"exp": true,
|
||||||
|
}[rT.TTL]
|
||||||
|
|
||||||
|
// Set fresh to true because new tokens coming from refresh request
|
||||||
|
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), false, rememberMe, auth.SSL)
|
||||||
|
if err != nil {
|
||||||
|
return getNil[T](), errors.Wrap(err, "jwt.SetTokenCookies")
|
||||||
|
}
|
||||||
|
// New tokens sent, revoke the old tokens
|
||||||
|
err = rT.Revoke(tx)
|
||||||
|
if err != nil {
|
||||||
|
return getNil[T](), errors.Wrap(err, "rT.Revoke")
|
||||||
|
}
|
||||||
|
// Return the authorized user
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
73
jwt/cookies.go
Normal file
73
jwt/cookies.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get the value of the access and refresh tokens
|
||||||
|
func GetTokenCookies(
|
||||||
|
r *http.Request,
|
||||||
|
) (acc string, ref string) {
|
||||||
|
accCookie, accErr := r.Cookie("access")
|
||||||
|
refCookie, refErr := r.Cookie("refresh")
|
||||||
|
var (
|
||||||
|
accStr string = ""
|
||||||
|
refStr string = ""
|
||||||
|
)
|
||||||
|
if accErr == nil {
|
||||||
|
accStr = accCookie.Value
|
||||||
|
}
|
||||||
|
if refErr == nil {
|
||||||
|
refStr = refCookie.Value
|
||||||
|
}
|
||||||
|
return accStr, refStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a token with the provided details
|
||||||
|
func setToken(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
token string,
|
||||||
|
scope string,
|
||||||
|
exp int64,
|
||||||
|
rememberme bool,
|
||||||
|
useSSL bool,
|
||||||
|
) {
|
||||||
|
tokenCookie := &http.Cookie{
|
||||||
|
Name: scope,
|
||||||
|
Value: token,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Secure: useSSL,
|
||||||
|
}
|
||||||
|
if rememberme {
|
||||||
|
tokenCookie.Expires = time.Unix(exp, 0)
|
||||||
|
}
|
||||||
|
http.SetCookie(w, tokenCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new tokens for the subject and set them as cookies
|
||||||
|
func SetTokenCookies(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
tokenGen *TokenGenerator,
|
||||||
|
subject int,
|
||||||
|
fresh bool,
|
||||||
|
rememberMe bool,
|
||||||
|
useSSL bool,
|
||||||
|
) error {
|
||||||
|
at, atexp, err := tokenGen.NewAccess(subject, fresh, rememberMe)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
||||||
|
}
|
||||||
|
rt, rtexp, err := tokenGen.NewRefresh(subject, rememberMe)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.GenerateRefreshToken")
|
||||||
|
}
|
||||||
|
// Don't set the cookies until we know no errors occured
|
||||||
|
setToken(w, at, "access", atexp, rememberMe, useSSL)
|
||||||
|
setToken(w, rt, "refresh", rtexp, rememberMe, useSSL)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
62
jwt/generator.go
Normal file
62
jwt/generator.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenGenerator struct {
|
||||||
|
accessExpireAfter int64 // Access Token expiry time in minutes
|
||||||
|
refreshExpireAfter int64 // Refresh Token expiry time in minutes
|
||||||
|
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||||
|
trustedHost string // Trusted hostname to use for the tokens
|
||||||
|
secretKey string // Secret key to use for token hashing
|
||||||
|
dbConn *sql.DB // Database handle for token blacklisting
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||||
|
// All expiry times should be provided in minutes.
|
||||||
|
// trustedHost and secretKey strings must be provided.
|
||||||
|
// dbConn can be nil, but doing this will disable token revocation
|
||||||
|
func CreateGenerator(
|
||||||
|
accessExpireAfter int64,
|
||||||
|
refreshExpireAfter int64,
|
||||||
|
freshExpireAfter int64,
|
||||||
|
trustedHost string,
|
||||||
|
secretKey string,
|
||||||
|
dbConn *sql.DB,
|
||||||
|
) (gen *TokenGenerator, err error) {
|
||||||
|
if accessExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if refreshExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("refreshExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if freshExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("freshExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if trustedHost == "" {
|
||||||
|
return nil, errors.New("trustedHost cannot be an empty string")
|
||||||
|
}
|
||||||
|
if secretKey == "" {
|
||||||
|
return nil, errors.New("secretKey cannot be an empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbConn != nil {
|
||||||
|
err := dbConn.Ping()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Failed to ping database")
|
||||||
|
}
|
||||||
|
// TODO: check if jwtblacklist table exists
|
||||||
|
// TODO: create jwtblacklist table if not existing
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TokenGenerator{
|
||||||
|
accessExpireAfter: accessExpireAfter,
|
||||||
|
refreshExpireAfter: refreshExpireAfter,
|
||||||
|
freshExpireAfter: freshExpireAfter,
|
||||||
|
trustedHost: trustedHost,
|
||||||
|
secretKey: secretKey,
|
||||||
|
dbConn: dbConn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
90
jwt/generator_test.go
Normal file
90
jwt/generator_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"secret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, gen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"secret",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, gen)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func() error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"access expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"refresh expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fresh expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty trustedHost",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty secretKey",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Error(t, tt.fn())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
17
jwt/go.mod
Normal file
17
jwt/go.mod
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
module git.haelnorr.com/h/golib/jwt
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
19
jwt/go.sum
Normal file
19
jwt/go.sum
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
38
jwt/revoke.go
Normal file
38
jwt/revoke.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Revoke a token by adding it to the database
|
||||||
|
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||||
|
if gen.dbConn == nil {
|
||||||
|
return errors.New("No DB provided, unable to use this function")
|
||||||
|
}
|
||||||
|
jti := t.GetJTI()
|
||||||
|
exp := t.GetEXP()
|
||||||
|
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||||
|
_, err := tx.Exec(query, jti, exp)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token has been revoked. Returns true if not revoked.
|
||||||
|
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||||
|
if gen.dbConn == nil {
|
||||||
|
return false, errors.New("No DB provided, unable to use this function")
|
||||||
|
}
|
||||||
|
jti := t.GetJTI()
|
||||||
|
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||||
|
rows, err := tx.Query(query, jti)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tx.Query")
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
revoked := rows.Next()
|
||||||
|
return !revoked, nil
|
||||||
|
}
|
||||||
83
jwt/revoke_test.go
Normal file
83
jwt/revoke_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return gen
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoDBFail(t *testing.T) {
|
||||||
|
jti := uuid.New()
|
||||||
|
exp := time.Now().Add(time.Hour).Unix()
|
||||||
|
|
||||||
|
token := AccessToken{
|
||||||
|
JTI: jti,
|
||||||
|
EXP: exp,
|
||||||
|
gen: &TokenGenerator{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke should fail due to no DB
|
||||||
|
err := token.Revoke(&sql.Tx{})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// CheckNotRevoked should fail
|
||||||
|
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
jti := uuid.New()
|
||||||
|
exp := time.Now().Add(time.Hour).Unix()
|
||||||
|
|
||||||
|
token := AccessToken{
|
||||||
|
JTI: jti,
|
||||||
|
EXP: exp,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke expectations
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||||
|
WithArgs(jti, exp).
|
||||||
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||||
|
WithArgs(jti).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
defer tx.Rollback()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = token.Revoke(tx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, valid)
|
||||||
|
|
||||||
|
require.NoError(t, tx.Commit())
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
79
jwt/tokengen.go
Normal file
79
jwt/tokengen.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generates an access token for the provided subject
|
||||||
|
func (gen *TokenGenerator) NewAccess(
|
||||||
|
subjectID int,
|
||||||
|
fresh bool,
|
||||||
|
rememberMe bool,
|
||||||
|
) (tokenString string, expiresIn int64, err error) {
|
||||||
|
issuedAt := time.Now().Unix()
|
||||||
|
expiresAt := issuedAt + (gen.accessExpireAfter * 60)
|
||||||
|
var freshExpiresAt int64
|
||||||
|
if fresh {
|
||||||
|
freshExpiresAt = issuedAt + (gen.freshExpireAfter * 60)
|
||||||
|
} else {
|
||||||
|
freshExpiresAt = issuedAt
|
||||||
|
}
|
||||||
|
var ttl string
|
||||||
|
if rememberMe {
|
||||||
|
ttl = "exp"
|
||||||
|
} else {
|
||||||
|
ttl = "session"
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
|
jwt.MapClaims{
|
||||||
|
"iss": gen.trustedHost,
|
||||||
|
"scope": "access",
|
||||||
|
"ttl": ttl,
|
||||||
|
"jti": uuid.New(),
|
||||||
|
"iat": issuedAt,
|
||||||
|
"exp": expiresAt,
|
||||||
|
"fresh": freshExpiresAt,
|
||||||
|
"sub": subjectID,
|
||||||
|
})
|
||||||
|
|
||||||
|
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||||
|
}
|
||||||
|
return signedToken, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates a refresh token for the provided user
|
||||||
|
func (gen *TokenGenerator) NewRefresh(
|
||||||
|
subjectID int,
|
||||||
|
rememberMe bool,
|
||||||
|
) (tokenStr string, exp int64, err error) {
|
||||||
|
issuedAt := time.Now().Unix()
|
||||||
|
expiresAt := issuedAt + (gen.refreshExpireAfter * 60)
|
||||||
|
var ttl string
|
||||||
|
if rememberMe {
|
||||||
|
ttl = "exp"
|
||||||
|
} else {
|
||||||
|
ttl = "session"
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
|
jwt.MapClaims{
|
||||||
|
"iss": gen.trustedHost,
|
||||||
|
"scope": "refresh",
|
||||||
|
"ttl": ttl,
|
||||||
|
"jti": uuid.New(),
|
||||||
|
"iat": issuedAt,
|
||||||
|
"exp": expiresAt,
|
||||||
|
"sub": subjectID,
|
||||||
|
})
|
||||||
|
|
||||||
|
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||||
|
}
|
||||||
|
return signedToken, expiresAt, nil
|
||||||
|
}
|
||||||
38
jwt/tokengen_test.go
Normal file
38
jwt/tokengen_test.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestGenerator(t *testing.T) *TokenGenerator {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return gen
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAccessToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
tokenStr, exp, err := gen.NewAccess(123, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, tokenStr)
|
||||||
|
require.Greater(t, exp, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRefreshToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
tokenStr, exp, err := gen.NewRefresh(123, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, tokenStr)
|
||||||
|
require.Greater(t, exp, int64(0))
|
||||||
|
}
|
||||||
71
jwt/tokens.go
Normal file
71
jwt/tokens.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token interface {
|
||||||
|
GetJTI() uuid.UUID
|
||||||
|
GetEXP() int64
|
||||||
|
GetScope() string
|
||||||
|
Revoke(*sql.Tx) error
|
||||||
|
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access token
|
||||||
|
type AccessToken struct {
|
||||||
|
ISS string // Issuer, generally TrustedHost
|
||||||
|
IAT int64 // Time issued at
|
||||||
|
EXP int64 // Time expiring at
|
||||||
|
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||||
|
SUB int // Subject (user) ID
|
||||||
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
|
Fresh int64 // Time freshness expiring at
|
||||||
|
Scope string // Should be "access"
|
||||||
|
gen *TokenGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh token
|
||||||
|
type RefreshToken struct {
|
||||||
|
ISS string // Issuer, generally TrustedHost
|
||||||
|
IAT int64 // Time issued at
|
||||||
|
EXP int64 // Time expiring at
|
||||||
|
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||||
|
SUB int // Subject (user) ID
|
||||||
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
|
Scope string // Should be "refresh"
|
||||||
|
gen *TokenGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AccessToken) GetJTI() uuid.UUID {
|
||||||
|
return a.JTI
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetJTI() uuid.UUID {
|
||||||
|
return r.JTI
|
||||||
|
}
|
||||||
|
func (a AccessToken) GetEXP() int64 {
|
||||||
|
return a.EXP
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetEXP() int64 {
|
||||||
|
return r.EXP
|
||||||
|
}
|
||||||
|
func (a AccessToken) GetScope() string {
|
||||||
|
return a.Scope
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetScope() string {
|
||||||
|
return r.Scope
|
||||||
|
}
|
||||||
|
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||||
|
return a.gen.revoke(tx, a)
|
||||||
|
}
|
||||||
|
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||||
|
return r.gen.revoke(tx, r)
|
||||||
|
}
|
||||||
|
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
|
return a.gen.checkNotRevoked(tx, a)
|
||||||
|
}
|
||||||
|
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
|
return r.gen.checkNotRevoked(tx, r)
|
||||||
|
}
|
||||||
123
jwt/util.go
Normal file
123
jwt/util.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse a token, validating its signing sigature and returning the claims
|
||||||
|
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(secretKey), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "jwt.Parse")
|
||||||
|
}
|
||||||
|
// Token decoded, parse the claims
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Failed to parse claims")
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token is expired. Returns the expiry if not expired
|
||||||
|
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||||
|
// Coerce the expiry to a float64 to avoid scientific notation
|
||||||
|
expFloat, ok := expiry.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||||
|
}
|
||||||
|
// Convert to the int64 time we expect :)
|
||||||
|
expiryTime := int64(expFloat)
|
||||||
|
|
||||||
|
// Check if its expired
|
||||||
|
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||||
|
if isExpired {
|
||||||
|
return 0, errors.New("Token has expired")
|
||||||
|
}
|
||||||
|
return expiryTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||||
|
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||||
|
issuerVal, ok := issuer.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'iss' claim")
|
||||||
|
}
|
||||||
|
if issuer != trustedHost {
|
||||||
|
return "", errors.New("Issuer does not matched trusted host")
|
||||||
|
}
|
||||||
|
return issuerVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the scope matches the expected scope. Returns scope if true
|
||||||
|
func getTokenScope(scope interface{}) (string, error) {
|
||||||
|
scopeStr, ok := scope.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'scope' claim")
|
||||||
|
}
|
||||||
|
return scopeStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the TTL of the token, either "session" or "exp"
|
||||||
|
func getTokenTTL(ttl interface{}) (string, error) {
|
||||||
|
ttlStr, ok := ttl.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'ttl' claim")
|
||||||
|
}
|
||||||
|
if ttlStr != "exp" && ttlStr != "session" {
|
||||||
|
return "", errors.New("TTL value is not recognised")
|
||||||
|
}
|
||||||
|
return ttlStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the time the token was issued at
|
||||||
|
func getIssuedTime(issued interface{}) (int64, error) {
|
||||||
|
// Same float64 -> int64 trick as expiry
|
||||||
|
issuedFloat, ok := issued.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||||
|
}
|
||||||
|
issuedAt := int64(issuedFloat)
|
||||||
|
return issuedAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the freshness expiry timestamp
|
||||||
|
func getFreshTime(fresh interface{}) (int64, error) {
|
||||||
|
freshUntil, ok := fresh.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||||
|
}
|
||||||
|
return int64(freshUntil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the subject of the token
|
||||||
|
func getTokenSubject(sub interface{}) (int, error) {
|
||||||
|
subject, ok := sub.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||||
|
}
|
||||||
|
return int(subject), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the JTI of the token
|
||||||
|
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||||
|
jtiStr, ok := jti.(string)
|
||||||
|
if !ok {
|
||||||
|
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||||
|
}
|
||||||
|
jtiUUID, err := uuid.Parse(jtiStr)
|
||||||
|
if err != nil {
|
||||||
|
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||||
|
}
|
||||||
|
return jtiUUID, nil
|
||||||
|
}
|
||||||
146
jwt/validate.go
Normal file
146
jwt/validate.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse an access token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func (gen *TokenGenerator) ValidateAccess(
|
||||||
|
tx *sql.Tx,
|
||||||
|
tokenString string,
|
||||||
|
) (*AccessToken, error) {
|
||||||
|
if tokenString == "" {
|
||||||
|
return nil, errors.New("Access token string not provided")
|
||||||
|
}
|
||||||
|
claims, err := parseToken(gen.secretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
ttl, err := getTokenTTL(claims["ttl"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenTTL")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "access" {
|
||||||
|
return nil, errors.New("Token is not an Access token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
fresh, err := getFreshTime(claims["fresh"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getFreshTime")
|
||||||
|
}
|
||||||
|
jti, err := getTokenJTI(claims["jti"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenJTI")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &AccessToken{
|
||||||
|
ISS: issuer,
|
||||||
|
TTL: ttl,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
Fresh: fresh,
|
||||||
|
JTI: jti,
|
||||||
|
Scope: scope,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
if err != nil && gen.dbConn != nil {
|
||||||
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
|
}
|
||||||
|
if !valid && gen.dbConn != nil {
|
||||||
|
return nil, errors.New("Token has been revoked")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func (gen *TokenGenerator) ValidateRefresh(
|
||||||
|
tx *sql.Tx,
|
||||||
|
tokenString string,
|
||||||
|
) (*RefreshToken, error) {
|
||||||
|
if tokenString == "" {
|
||||||
|
return nil, errors.New("Refresh token string not provided")
|
||||||
|
}
|
||||||
|
claims, err := parseToken(gen.secretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
ttl, err := getTokenTTL(claims["ttl"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenTTL")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "refresh" {
|
||||||
|
return nil, errors.New("Token is not an Refresh token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
jti, err := getTokenJTI(claims["jti"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenJTI")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &RefreshToken{
|
||||||
|
ISS: issuer,
|
||||||
|
TTL: ttl,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
JTI: jti,
|
||||||
|
Scope: scope,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
if err != nil && gen.dbConn != nil {
|
||||||
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
|
}
|
||||||
|
if !valid && gen.dbConn != nil {
|
||||||
|
return nil, errors.New("Token has been revoked")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
118
jwt/validate_test.go
Normal file
118
jwt/validate_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return gen, mock, func() { db.Close() }
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||||
|
WithArgs(jti).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{}))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_Success(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// We don't know the JTI beforehand; match any arg
|
||||||
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "access", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_NoDB(t *testing.T) {
|
||||||
|
gen := newGeneratorWithNoDB(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "access", token.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_Success(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "refresh", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_NoDB(t *testing.T) {
|
||||||
|
gen := newGeneratorWithNoDB(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "refresh", token.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
_, err := gen.ValidateAccess(nil, "")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_WrongScope(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
// Create access token but validate as refresh
|
||||||
|
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
32
tmdb/config.go
Normal file
32
tmdb/config.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Image Image `json:"images"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Image struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
SecureBaseURL string `json:"secure_base_url"`
|
||||||
|
BackdropSizes []string `json:"backdrop_sizes"`
|
||||||
|
LogoSizes []string `json:"logo_sizes"`
|
||||||
|
PosterSizes []string `json:"poster_sizes"`
|
||||||
|
ProfileSizes []string `json:"profile_sizes"`
|
||||||
|
StillSizes []string `json:"still_sizes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConfig(token string) (*Config, error) {
|
||||||
|
url := "https://api.themoviedb.org/3/configuration"
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
config := Config{}
|
||||||
|
json.Unmarshal(data, &config)
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
54
tmdb/credits.go
Normal file
54
tmdb/credits.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Credits struct {
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
Cast []Cast `json:"cast"`
|
||||||
|
Crew []Crew `json:"crew"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cast struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Gender int `json:"gender"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
KnownFor string `json:"known_for_department"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginalName string `json:"original_name"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
Profile string `json:"profile_path"`
|
||||||
|
CastID int32 `json:"cast_id"`
|
||||||
|
Character string `json:"character"`
|
||||||
|
CreditID string `json:"credit_id"`
|
||||||
|
Order int `json:"order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Crew struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Gender int `json:"gender"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
KnownFor string `json:"known_for_department"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginalName string `json:"original_name"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
Profile string `json:"profile_path"`
|
||||||
|
CreditID string `json:"credit_id"`
|
||||||
|
Department string `json:"department"`
|
||||||
|
Job string `json:"job"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCredits(movieid int32, token string) (*Credits, error) {
|
||||||
|
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
credits := Credits{}
|
||||||
|
json.Unmarshal(data, &credits)
|
||||||
|
return &credits, nil
|
||||||
|
}
|
||||||
41
tmdb/crew_functions.go
Normal file
41
tmdb/crew_functions.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
type BilledCrew struct {
|
||||||
|
Name string
|
||||||
|
Roles []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (credits *Credits) BilledCrew() []BilledCrew {
|
||||||
|
crewmap := make(map[string][]string)
|
||||||
|
billedcrew := []BilledCrew{}
|
||||||
|
for _, crew := range credits.Crew {
|
||||||
|
if crew.Job == "Director" ||
|
||||||
|
crew.Job == "Screenplay" ||
|
||||||
|
crew.Job == "Writer" ||
|
||||||
|
crew.Job == "Novel" ||
|
||||||
|
crew.Job == "Story" {
|
||||||
|
crewmap[crew.Name] = append(crewmap[crew.Name], crew.Job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, jobs := range crewmap {
|
||||||
|
billedcrew = append(billedcrew, BilledCrew{Name: name, Roles: jobs})
|
||||||
|
}
|
||||||
|
for i := range billedcrew {
|
||||||
|
sort.Strings(billedcrew[i].Roles)
|
||||||
|
}
|
||||||
|
sort.Slice(billedcrew, func(i, j int) bool {
|
||||||
|
return billedcrew[i].Roles[0] < billedcrew[j].Roles[0]
|
||||||
|
})
|
||||||
|
return billedcrew
|
||||||
|
}
|
||||||
|
|
||||||
|
func (billedcrew *BilledCrew) FRoles() string {
|
||||||
|
jobs := ""
|
||||||
|
for _, job := range billedcrew.Roles {
|
||||||
|
jobs += job + ", "
|
||||||
|
}
|
||||||
|
return jobs[:len(jobs)-2]
|
||||||
|
}
|
||||||
5
tmdb/go.mod
Normal file
5
tmdb/go.mod
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
module git.haelnorr.com/h/golib/tmdb
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require github.com/pkg/errors v0.9.1
|
||||||
2
tmdb/go.sum
Normal file
2
tmdb/go.sum
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
45
tmdb/movie.go
Normal file
45
tmdb/movie.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Movie struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Backdrop string `json:"backdrop_path"`
|
||||||
|
Collection string `json:"belongs_to_collection"`
|
||||||
|
Budget int `json:"budget"`
|
||||||
|
Genres []Genre `json:"genres"`
|
||||||
|
Homepage string `json:"homepage"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
IMDbID string `json:"imdb_id"`
|
||||||
|
OriginalLanguage string `json:"original_language"`
|
||||||
|
OriginalTitle string `json:"original_title"`
|
||||||
|
Overview string `json:"overview"`
|
||||||
|
Popularity float32 `json:"popularity"`
|
||||||
|
Poster string `json:"poster_path"`
|
||||||
|
ProductionCompanies []ProductionCompany `json:"production_companies"`
|
||||||
|
ProductionCountries []ProductionCountry `json:"production_countries"`
|
||||||
|
ReleaseDate string `json:"release_date"`
|
||||||
|
Revenue int `json:"revenue"`
|
||||||
|
Runtime int `json:"runtime"`
|
||||||
|
SpokenLanguages []SpokenLanguage `json:"spoken_languages"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Tagline string `json:"tagline"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Video bool `json:"video"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMovie(id int32, token string) (*Movie, error) {
|
||||||
|
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v?language=en-US", id)
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
movie := Movie{}
|
||||||
|
json.Unmarshal(data, &movie)
|
||||||
|
return &movie, nil
|
||||||
|
}
|
||||||
42
tmdb/movie_functions.go
Normal file
42
tmdb/movie_functions.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (movie *Movie) FRuntime() string {
|
||||||
|
hours := movie.Runtime / 60
|
||||||
|
mins := movie.Runtime % 60
|
||||||
|
return fmt.Sprintf("%dh %02dm", hours, mins)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) GetPoster(image *Image, size string) string {
|
||||||
|
base, err := url.Parse(image.SecureBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fullPath := path.Join(base.Path, size, movie.Poster)
|
||||||
|
base.Path = fullPath
|
||||||
|
return base.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) ReleaseYear() string {
|
||||||
|
if movie.ReleaseDate == "" {
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
return "(" + movie.ReleaseDate[:4] + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) FGenres() string {
|
||||||
|
genres := ""
|
||||||
|
for _, genre := range movie.Genres {
|
||||||
|
genres += genre.Name + ", "
|
||||||
|
}
|
||||||
|
if len(genres) > 2 {
|
||||||
|
return genres[:len(genres)-2]
|
||||||
|
}
|
||||||
|
return genres
|
||||||
|
}
|
||||||
28
tmdb/request.go
Normal file
28
tmdb/request.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tmdbGet(url string, token string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "http.NewRequest")
|
||||||
|
}
|
||||||
|
req.Header.Add("accept", "application/json")
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "http.DefaultClient.Do")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "io.ReadAll")
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
79
tmdb/search.go
Normal file
79
tmdb/search.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Page int `json:"page"`
|
||||||
|
TotalPages int `json:"total_pages"`
|
||||||
|
TotalResults int `json:"total_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResultMovies struct {
|
||||||
|
Result
|
||||||
|
Results []ResultMovie `json:"results"`
|
||||||
|
}
|
||||||
|
type ResultMovie struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
BackdropPath string `json:"backdrop_path"`
|
||||||
|
GenreIDs []int `json:"genre_ids"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
OriginalLanguage string `json:"original_language"`
|
||||||
|
OriginalTitle string `json:"original_title"`
|
||||||
|
Overview string `json:"overview"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
PosterPath string `json:"poster_path"`
|
||||||
|
ReleaseDate string `json:"release_date"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Video bool `json:"video"`
|
||||||
|
VoteAverage int `json:"vote_average"`
|
||||||
|
VoteCount int `json:"vote_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *ResultMovie) GetPoster(image *Image, size string) string {
|
||||||
|
base, err := url.Parse(image.SecureBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fullPath := path.Join(base.Path, size, movie.PosterPath)
|
||||||
|
base.Path = fullPath
|
||||||
|
return base.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *ResultMovie) ReleaseYear() string {
|
||||||
|
if movie.ReleaseDate == "" {
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
return "(" + movie.ReleaseDate[:4] + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: genres list https://developer.themoviedb.org/reference/genre-movie-list
|
||||||
|
// func (movie *ResultMovie) FGenres() string {
|
||||||
|
// genres := ""
|
||||||
|
// for _, genre := range movie.Genres {
|
||||||
|
// genres += genre.Name + ", "
|
||||||
|
// }
|
||||||
|
// return genres[:len(genres)-2]
|
||||||
|
// }
|
||||||
|
|
||||||
|
func SearchMovies(token string, query string, adult bool, page int) (*ResultMovies, error) {
|
||||||
|
url := "https://api.themoviedb.org/3/search/movie" +
|
||||||
|
fmt.Sprintf("?query=%s", url.QueryEscape(query)) +
|
||||||
|
fmt.Sprintf("&include_adult=%t", adult) +
|
||||||
|
fmt.Sprintf("&page=%v", page) +
|
||||||
|
"&language=en-US"
|
||||||
|
response, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
var results ResultMovies
|
||||||
|
json.Unmarshal(response, &results)
|
||||||
|
return &results, nil
|
||||||
|
}
|
||||||
24
tmdb/structs.go
Normal file
24
tmdb/structs.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
type Genre struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProductionCompany struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Logo string `json:"logo_path"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginCountry string `json:"origin_country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProductionCountry struct {
|
||||||
|
ISO_3166_1 string `json:"iso_3166_1"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SpokenLanguage struct {
|
||||||
|
EnglishName string `json:"english_name"`
|
||||||
|
ISO_639_1 string `json:"iso_639_1"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user