Compare commits

...

6 Commits

Author SHA1 Message Date
b13b783d7e created hwsauth module 2026-01-04 01:01:17 +11:00
14eec74683 created hws module 2026-01-04 00:59:24 +11:00
ade3fa0454 imported env module 2026-01-02 19:03:07 +11:00
516be905a9 imported cookies module 2026-01-02 18:25:38 +11:00
6e632267ea added cookie control to jwt 2026-01-02 18:15:49 +11:00
05aad5f11b fixed transaction issues 2026-01-01 22:44:39 +11:00
39 changed files with 1309 additions and 69 deletions

19
cookies/delete.go Normal file
View File

@@ -0,0 +1,19 @@
package cookies
import (
"net/http"
"time"
)
// Tell the browser to delete the cookie matching the name provided
// Path must match the original set cookie for it to delete
func DeleteCookie(w http.ResponseWriter, name string, path string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: path,
Expires: time.Unix(0, 0), // Expire in the past
MaxAge: -1, // Immediately expire
HttpOnly: true,
})
}

3
cookies/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/cookies
go 1.25.5

36
cookies/pagefrom.go Normal file
View File

@@ -0,0 +1,36 @@
package cookies
import (
"net/http"
"net/url"
)
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
pageFromCookie, err := r.Cookie("pagefrom")
if err != nil {
return "/"
}
pageFrom := pageFromCookie.Value
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
return pageFrom
}
// Check the referer of the request, and if it matches the trustedHost, set
// the "pagefrom" cookie as the Path of the referer
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
referer := r.Referer()
parsedURL, err := url.Parse(referer)
if err != nil {
return
}
var pageFrom string
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
pageFrom = "/"
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
return
} else {
pageFrom = parsedURL.Path
}
SetCookie(w, "pagefrom", "/", pageFrom, 0)
}

23
cookies/set.go Normal file
View File

@@ -0,0 +1,23 @@
package cookies
import (
"net/http"
)
// Set a cookie with the given name, path and value. maxAge directly relates
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
func SetCookie(
w http.ResponseWriter,
name string,
path string,
value string,
maxAge int,
) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: path,
HttpOnly: true,
MaxAge: maxAge,
})
}

35
env/boolean.go vendored Normal file
View File

@@ -0,0 +1,35 @@
package env
import (
"os"
"strings"
)
// Get an environment variable as a boolean, specifying a default value if its
// not set or can't be parsed properly into a bool
func Bool(key string, defaultValue bool) bool {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
truthy := map[string]bool{
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
"enable": true, "enabled": true, "active": true, "affirmative": true,
}
falsy := map[string]bool{
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
"disable": false, "disabled": false, "inactive": false, "negative": false,
}
normalized := strings.TrimSpace(strings.ToLower(val))
if val, ok := truthy[normalized]; ok {
return val
}
if val, ok := falsy[normalized]; ok {
return val
}
return defaultValue
}

23
env/duration.go vendored Normal file
View File

@@ -0,0 +1,23 @@
package env
import (
"os"
"strconv"
"time"
)
// Get an environment variable as a time.Duration, specifying a default value if its
// not set or can't be parsed properly
func Duration(key string, defaultValue time.Duration) time.Duration {
val, exists := os.LookupEnv(key)
if !exists {
return time.Duration(defaultValue)
}
intVal, err := strconv.Atoi(val)
if err != nil {
return time.Duration(defaultValue)
}
return time.Duration(intVal)
}

3
env/go.mod vendored Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/env
go 1.25.5

37
env/int.go vendored Normal file
View File

@@ -0,0 +1,37 @@
package env
import (
"os"
"strconv"
)
// Get an environment variable as an int, specifying a default value if its
// not set or can't be parsed properly into an int
func Int(key string, defaultValue int) int {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.Atoi(val)
if err != nil {
return defaultValue
}
return intVal
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func Int64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

14
env/string.go vendored Normal file
View File

@@ -0,0 +1,14 @@
package env
import (
"os"
)
// Get an environment variable, specifying a default value if its not set
func String(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}

39
hws/errors.go Normal file
View File

@@ -0,0 +1,39 @@
package hws
import "net/http"
type HWSError struct {
statusCode int // HTTP Status code
message string // Error message
error error // Error
}
type ErrorPage func(statusCode int, w http.ResponseWriter, r *http.Request) error
func NewError(statusCode int, msg string, err error) *HWSError {
return &HWSError{
statusCode: statusCode,
message: msg,
error: err,
}
}
func (server *Server) AddErrorPage(page ErrorPage) {
server.errorPage = page
}
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error *HWSError) {
w.WriteHeader(error.statusCode)
server.logger.logger.Error().Err(error.error).Msg(error.message)
if server.errorPage != nil {
err := server.errorPage(error.statusCode, w, r)
if err != nil {
server.logger.logger.Error().Err(err).Msg("Failed to render error page")
}
}
}
func (server *Server) ThrowWarn(w http.ResponseWriter, error *HWSError) {
w.WriteHeader(error.statusCode)
server.logger.logger.Warn().Err(error.error).Msg(error.message)
}

14
hws/go.mod Normal file
View File

@@ -0,0 +1,14 @@
module git.haelnorr.com/h/golib/hws
go 1.25.5
require (
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.34.0
)
require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
golang.org/x/sys v0.12.0 // indirect
)

16
hws/go.sum Normal file
View File

@@ -0,0 +1,16 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

31
hws/gzip.go Normal file
View File

@@ -0,0 +1,31 @@
package hws
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
func addgzip(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
next.ServeHTTP(gzw, r)
})
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

44
hws/logger.go Normal file
View File

@@ -0,0 +1,44 @@
package hws
import (
"errors"
"fmt"
"net/url"
"github.com/rs/zerolog"
)
type logger struct {
logger *zerolog.Logger
ignoredPaths []string
}
// Server.AddLogger adds a logger to the server to use for request logging.
func (server *Server) AddLogger(zlogger *zerolog.Logger) error {
if zlogger == nil {
return errors.New("Unable to add logger, no logger provided")
}
server.logger = &logger{
logger: zlogger,
}
return nil
}
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for.
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
// Useful for ignoring requests to CSS files or favicons
func (server *Server) LoggerIgnorePaths(paths ...string) error {
for _, path := range paths {
u, err := url.Parse(path)
valid := err == nil &&
u.Scheme == "" &&
u.Host == "" &&
u.RawQuery == "" &&
u.Fragment == ""
if !valid {
return fmt.Errorf("Invalid path: '%s'", path)
}
}
server.logger.ignoredPaths = paths
return nil
}

51
hws/middleware.go Normal file
View File

@@ -0,0 +1,51 @@
package hws
import (
"errors"
"net/http"
)
type Middleware func(h http.Handler) http.Handler
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
// Server.AddMiddleware registers all the middleware.
// Middleware will be run in the order that they are provided.
func (server *Server) AddMiddleware(middleware ...Middleware) error {
if !server.routes {
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
}
// RUN LOGGING MIDDLEWARE FIRST
server.server.Handler = logging(server.server.Handler, server.logger)
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
for i := len(middleware); i > 0; i-- {
server.server.Handler = middleware[i-1](server.server.Handler)
}
// RUN GZIP
if server.gzip {
server.server.Handler = addgzip(server.server.Handler)
}
// RUN TIMER MIDDLEWARE LAST
server.server.Handler = startTimer(server.server.Handler)
server.middleware = true
return nil
}
func (server *Server) NewMiddleware(
middlewareFunc MiddlewareFunc,
) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newReq, herr := middlewareFunc(w, r)
if herr != nil {
server.ThrowError(w, r, herr)
return
}
next.ServeHTTP(w, newReq)
})
}
}

38
hws/middleware_logging.go Normal file
View File

@@ -0,0 +1,38 @@
package hws
import (
"net/http"
"slices"
"time"
)
// Middleware to add logs to console with details of the request
func logging(next http.Handler, logger *logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if logger == nil {
next.ServeHTTP(w, r)
return
}
if slices.Contains(logger.ignoredPaths, r.URL.Path) {
next.ServeHTTP(w, r)
return
}
start, err := getStartTime(r.Context())
if err != nil {
logger.logger.Error().Err(err)
return
}
wrapped := &wrappedWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
logger.logger.Info().
Int("status", wrapped.statusCode).
Str("method", r.Method).
Str("resource", r.URL.Path).
Dur("time_elapsed", time.Since(start)).
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
Msg("Served")
})
}

33
hws/middleware_timer.go Normal file
View File

@@ -0,0 +1,33 @@
package hws
import (
"context"
"errors"
"net/http"
"time"
)
func startTimer(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ctx := setStart(r.Context(), start)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
},
)
}
// Set the start time of the request
func setStart(ctx context.Context, time time.Time) context.Context {
return context.WithValue(ctx, "hws context key request-timer", time)
}
// Get the start time of the request
func getStartTime(ctx context.Context) (time.Time, error) {
start, ok := ctx.Value("hws context key request-timer").(time.Time)
if !ok {
return time.Time{}, errors.New("Failed to get start time of request")
}
return start, nil
}

15
hws/responsewriter.go Normal file
View File

@@ -0,0 +1,15 @@
package hws
import "net/http"
// Wraps the http.ResponseWriter, adding a statusCode field
type wrappedWriter struct {
http.ResponseWriter
statusCode int
}
// Extends WriteHeader to the ResponseWriter to add the status code
func (w *wrappedWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
w.statusCode = statusCode
}

62
hws/routes.go Normal file
View File

@@ -0,0 +1,62 @@
package hws
import (
"errors"
"fmt"
"net/http"
)
type Route struct {
Path string // Absolute path to the requested resource
Method Method // HTTP Method
Handler http.Handler // Handler to use for the request
}
type Method string
const (
MethodGET Method = "GET"
MethodPOST Method = "POST"
MethodPUT Method = "PUT"
MethodHEAD Method = "HEAD"
MethodDELETE Method = "DELETE"
MethodCONNECT Method = "CONNECT"
MethodOPTIONS Method = "OPTIONS"
MethodTRACE Method = "TRACE"
MethodPATCH Method = "PATCH"
)
// Server.AddRoutes registers the page handlers for the server.
// At least one route must be provided.
func (server *Server) AddRoutes(routes ...Route) error {
if len(routes) == 0 {
return errors.New("No routes provided")
}
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
for _, route := range routes {
if !validMethod(route.Method) {
return fmt.Errorf("Invalid method %s for path %s", route.Method, route.Path)
}
if route.Handler == nil {
return fmt.Errorf("No handler provided for %s %s", route.Method, route.Path)
}
pattern := fmt.Sprintf("%s %s", route.Method, route.Path)
mux.Handle(pattern, route.Handler)
}
server.server.Handler = mux
server.routes = true
return nil
}
func validMethod(m Method) bool {
switch m {
case MethodGET, MethodPOST, MethodPUT, MethodHEAD,
MethodDELETE, MethodCONNECT, MethodOPTIONS, MethodTRACE, MethodPATCH:
return true
default:
return false
}
}

52
hws/safefileserver.go Normal file
View File

@@ -0,0 +1,52 @@
package hws
import (
"net/http"
"os"
"github.com/pkg/errors"
)
// Wrapper for default FileSystem
type justFilesFilesystem struct {
fs http.FileSystem
}
// Wrapper for default File
type neuteredReaddirFile struct {
http.File
}
// Modifies the behavior of FileSystem.Open to return the neutered version of File
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
f, err := fs.fs.Open(name)
if err != nil {
return nil, err
}
// Check if the requested path is a directory
// and explicitly return an error to trigger a 404
fileInfo, err := f.Stat()
if err != nil {
return nil, err
}
if fileInfo.IsDir() {
return nil, os.ErrNotExist
}
return neuteredReaddirFile{f}, nil
}
// Overrides the Readdir method of File to always return nil
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
return nil, nil
}
func SafeFileServer(fileSystem *http.FileSystem) (http.Handler, error) {
if fileSystem == nil {
return nil, errors.New("No file system provided")
}
nfs := justFilesFilesystem{*fileSystem}
fs := http.FileServer(nfs)
return fs, nil
}

84
hws/server.go Normal file
View File

@@ -0,0 +1,84 @@
package hws
import (
"context"
"fmt"
"net"
"net/http"
"time"
"github.com/pkg/errors"
)
type Server struct {
server *http.Server
logger *logger
routes bool
middleware bool
gzip bool
errorPage ErrorPage
}
// NewServer returns a new hws.Server with the specified parameters.
// The timeout options are specified in seconds
func NewServer(
host string,
port string,
readHeaderTimeout time.Duration,
writeTimeout time.Duration,
idleTimeout time.Duration,
gzip bool,
) (*Server, error) {
// TODO: test that host and port are valid values
httpServer := &http.Server{
Addr: net.JoinHostPort(host, port),
ReadHeaderTimeout: readHeaderTimeout * time.Second,
WriteTimeout: writeTimeout * time.Second,
IdleTimeout: idleTimeout * time.Second,
}
server := &Server{
server: httpServer,
routes: false,
gzip: gzip,
}
return server, nil
}
func (server *Server) Start() error {
if !server.routes {
return errors.New("Server.AddRoutes must be run before starting the server")
}
if !server.middleware {
err := server.AddMiddleware()
if err != nil {
return errors.Wrap(err, "server.AddMiddleware")
}
}
go func() {
if server.logger == nil {
fmt.Printf("Listening for requests on %s", server.server.Addr)
} else {
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
}
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if server.logger == nil {
fmt.Printf("Server encountered a fatal error: %s", err.Error())
} else {
server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error")
}
}
}()
return nil
}
func (server *Server) Shutdown(ctx context.Context) {
if err := server.server.Shutdown(ctx); err != nil {
if server.logger == nil {
fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error())
} else {
server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server")
}
}
}

54
hwsauth/authenticate.go Normal file
View File

@@ -0,0 +1,54 @@
package hwsauth
import (
"database/sql"
"net/http"
"time"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
// Check the cookies for token strings and attempt to authenticate them
func (auth *Authenticator[T]) getAuthenticatedUser(
tx *sql.Tx,
w http.ResponseWriter,
r *http.Request,
) (*authenticatedModel[T], error) {
// Get token strings from cookies
atStr, rtStr := jwt.GetTokenCookies(r)
if atStr == "" && rtStr == "" {
return nil, errors.New("No token strings provided")
}
// Attempt to parse the access token
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
if err != nil {
// Access token invalid, attempt to parse refresh token
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
if err != nil {
return nil, errors.Wrap(err, "auth.tokenGenerator.ValidateRefresh")
}
// Refresh token valid, attempt to get a new token pair
model, err := auth.refreshAuthTokens(tx, w, r, rT)
if err != nil {
return nil, errors.Wrap(err, "auth.refreshAuthTokens")
}
// New token pair sent, return the authorized user
authUser := authenticatedModel[T]{
model: model,
fresh: time.Now().Unix(),
}
return &authUser, nil
}
// Access token valid
model, err := auth.load(tx, aT.SUB)
if err != nil {
return nil, errors.Wrap(err, "auth.load")
}
authUser := authenticatedModel[T]{
model: model,
fresh: aT.Fresh,
}
return &authUser, nil
}

93
hwsauth/authenticator.go Normal file
View File

@@ -0,0 +1,93 @@
package hwsauth
import (
"database/sql"
"projectreshoot/pkg/hws"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
type Authenticator[T Model] struct {
tokenGenerator *jwt.TokenGenerator
load LoadFunc[T]
conn *sql.DB
ignoredPaths []string
logger *zerolog.Logger
server *hws.Server
errorPage hws.ErrorPage
SSL bool // Use SSL for JWT tokens. Default true
TrustedHost string // TrustedHost to use for SSL verification
SecretKey string // Secret key to use for JWT tokens
AccessTokenExpiry int64 // Expiry time for Access tokens in minutes. Default 5
RefreshTokenExpiry int64 // Expiry time for Refresh tokens in minutes. Default 1440 (1 day)
TokenFreshTime int64 // Expiry time of token freshness. Default 5 minutes
LandingPage string // Path of the desired landing page for logged in users
}
// NewAuthenticator creates and returns a new Authenticator using the provided configuration.
// All expiry times should be provided in minutes.
// trustedHost and secretKey strings must be provided.
func NewAuthenticator[T Model](
load LoadFunc[T],
server *hws.Server,
conn *sql.DB,
logger *zerolog.Logger,
errorPage hws.ErrorPage,
) (*Authenticator[T], error) {
if load == nil {
return nil, errors.New("No function to load model supplied")
}
if server == nil {
return nil, errors.New("No hws.Server provided")
}
if conn == nil {
return nil, errors.New("No database connection supplied")
}
if logger == nil {
return nil, errors.New("No logger provided")
}
if errorPage == nil {
return nil, errors.New("No ErrorPage provided")
}
auth := Authenticator[T]{
load: load,
server: server,
conn: conn,
logger: logger,
errorPage: errorPage,
AccessTokenExpiry: 5,
RefreshTokenExpiry: 1440,
TokenFreshTime: 5,
SSL: true,
}
return &auth, nil
}
// Initialise finishes the setup and prepares the Authenticator for use.
// Any custom configuration must be set before Initialise is called
func (auth *Authenticator[T]) Initialise() error {
if auth.TrustedHost == "" {
return errors.New("Trusted host must be provided")
}
if auth.SecretKey == "" {
return errors.New("Secret key cannot be blank")
}
if auth.LandingPage == "" {
return errors.New("No landing page specified")
}
tokenGen, err := jwt.CreateGenerator(
auth.AccessTokenExpiry,
auth.RefreshTokenExpiry,
auth.TokenFreshTime,
auth.TrustedHost,
auth.SecretKey,
auth.conn,
)
if err != nil {
return errors.Wrap(err, "jwt.CreateGenerator")
}
auth.tokenGenerator = tokenGen
return nil
}

18
hwsauth/go.mod Normal file
View File

@@ -0,0 +1,18 @@
module git.haelnorr.com/h/golib/hwsauth
go 1.25.5
require (
git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/jwt v0.9.2
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.34.0
)
require (
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
golang.org/x/sys v0.12.0 // indirect
)

34
hwsauth/go.sum Normal file
View File

@@ -0,0 +1,34 @@
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
git.haelnorr.com/h/golib/jwt v0.9.2 h1:l1Ow7DPGACAU54CnMP/NlZjdc4nRD1wr3xZ8a7taRvU=
git.haelnorr.com/h/golib/jwt v0.9.2/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

22
hwsauth/ignorepaths.go Normal file
View File

@@ -0,0 +1,22 @@
package hwsauth
import (
"fmt"
"net/url"
)
func (auth *Authenticator[T]) IgnorePaths(paths ...string) error {
for _, path := range paths {
u, err := url.Parse(path)
valid := err == nil &&
u.Scheme == "" &&
u.Host == "" &&
u.RawQuery == "" &&
u.Fragment == ""
if !valid {
return fmt.Errorf("Invalid path: '%s'", path)
}
}
auth.ignoredPaths = paths
return nil
}

22
hwsauth/login.go Normal file
View File

@@ -0,0 +1,22 @@
package hwsauth
import (
"net/http"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
func (auth *Authenticator[T]) Login(
w http.ResponseWriter,
r *http.Request,
model T,
rememberMe bool,
) error {
err := jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), true, rememberMe, auth.SSL)
if err != nil {
return errors.Wrap(err, "jwt.SetTokenCookies")
}
return nil
}

27
hwsauth/logout.go Normal file
View File

@@ -0,0 +1,27 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/cookies"
"github.com/pkg/errors"
)
func (auth *Authenticator[T]) Logout(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
aT, rT, err := auth.getTokens(tx, r)
if err != nil {
return errors.Wrap(err, "auth.getTokens")
}
err = aT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
err = rT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
return nil
}

42
hwsauth/middleware.go Normal file
View File

@@ -0,0 +1,42 @@
package hwsauth
import (
"context"
"net/http"
"projectreshoot/pkg/hws"
"slices"
"time"
)
func (auth *Authenticator[T]) Authenticate() hws.Middleware {
return auth.server.NewMiddleware(auth.authenticate())
}
func (auth *Authenticator[T]) authenticate() hws.MiddlewareFunc {
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
if slices.Contains(auth.ignoredPaths, r.URL.Path) {
return r, nil
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Start the transaction
tx, err := auth.conn.BeginTx(ctx, nil)
if err != nil {
return nil, hws.NewError(http.StatusServiceUnavailable, "Unable to start transaction", err)
}
model, err := auth.getAuthenticatedUser(tx, w, r)
if err != nil {
tx.Rollback()
auth.logger.Debug().
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
return r, nil
}
tx.Commit()
authContext := setAuthenticatedModel(r.Context(), model)
newReq := r.WithContext(authContext)
return newReq, nil
}
}

46
hwsauth/model.go Normal file
View File

@@ -0,0 +1,46 @@
package hwsauth
import (
"context"
"database/sql"
)
type authenticatedModel[T Model] struct {
model T
fresh int64
}
func getNil[T Model]() T {
var result T
return result
}
type Model interface {
ID() int
}
type ContextLoader[T Model] func(ctx context.Context) T
type LoadFunc[T Model] func(tx *sql.Tx, id int) (T, error)
// Return a new context with the user added in
func setAuthenticatedModel[T Model](ctx context.Context, m *authenticatedModel[T]) context.Context {
return context.WithValue(ctx, "hwsauth context key authenticated-model", m)
}
// Retrieve a user from the given context. Returns nil if not set
func getAuthorizedModel[T Model](ctx context.Context) *authenticatedModel[T] {
model, ok := ctx.Value("hwsauth context key authenticated-model").(*authenticatedModel[T])
if !ok {
return nil
}
return model
}
func (auth *Authenticator[T]) CurrentModel(ctx context.Context) T {
model := getAuthorizedModel[T](ctx)
if model == nil {
return getNil[T]()
}
return model.model
}

43
hwsauth/protectpage.go Normal file
View File

@@ -0,0 +1,43 @@
package hwsauth
import (
"net/http"
"time"
)
// Checks if the model is set in the context and shows 401 page if not logged in
func (auth *Authenticator[T]) LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
if model == nil {
auth.errorPage(http.StatusUnauthorized, w, r)
return
}
next.ServeHTTP(w, r)
})
}
// Checks if the model is set in the context and redirects them to the landing page if
// they are logged in
func (auth *Authenticator[T]) LogoutReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
if model != nil {
http.Redirect(w, r, auth.LandingPage, http.StatusFound)
return
}
next.ServeHTTP(w, r)
})
}
func (auth *Authenticator[T]) FreshReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
model := getAuthorizedModel[T](r.Context())
isFresh := time.Now().Before(time.Unix(model.fresh, 0))
if !isFresh {
w.WriteHeader(444)
return
}
next.ServeHTTP(w, r)
})
}

66
hwsauth/reauthenticate.go Normal file
View File

@@ -0,0 +1,66 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
func (auth *Authenticator[T]) RefreshAuthTokens(tx *sql.Tx, w http.ResponseWriter, r *http.Request) error {
aT, rT, err := auth.getTokens(tx, r)
if err != nil {
return errors.Wrap(err, "getTokens")
}
rememberMe := map[string]bool{
"session": false,
"exp": true,
}[aT.TTL]
// issue new tokens for the user
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, rT.SUB, true, rememberMe, auth.SSL)
if err != nil {
return errors.Wrap(err, "jwt.SetTokenCookies")
}
err = revokeTokenPair(tx, aT, rT)
if err != nil {
return errors.Wrap(err, "revokeTokenPair")
}
return nil
}
// Get the tokens from the request
func (auth *Authenticator[T]) getTokens(
tx *sql.Tx,
r *http.Request,
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies
atStr, rtStr := jwt.GetTokenCookies(r)
aT, err := auth.tokenGenerator.ValidateAccess(tx, atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateAccess")
}
rT, err := auth.tokenGenerator.ValidateRefresh(tx, rtStr)
if err != nil {
return nil, nil, errors.Wrap(err, "tokenGenerator.ValidateRefresh")
}
return aT, rT, nil
}
// Revoke the given token pair
func revokeTokenPair(
tx *sql.Tx,
aT *jwt.AccessToken,
rT *jwt.RefreshToken,
) error {
err := aT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "aT.Revoke")
}
err = rT.Revoke(tx)
if err != nil {
return errors.Wrap(err, "rT.Revoke")
}
return nil
}

40
hwsauth/refreshtokens.go Normal file
View File

@@ -0,0 +1,40 @@
package hwsauth
import (
"database/sql"
"net/http"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
// Attempt to use a valid refresh token to generate a new token pair
func (auth *Authenticator[T]) refreshAuthTokens(
tx *sql.Tx,
w http.ResponseWriter,
r *http.Request,
rT *jwt.RefreshToken,
) (T, error) {
model, err := auth.load(tx, rT.SUB)
if err != nil {
return getNil[T](), errors.Wrap(err, "auth.load")
}
rememberMe := map[string]bool{
"session": false,
"exp": true,
}[rT.TTL]
// Set fresh to true because new tokens coming from refresh request
err = jwt.SetTokenCookies(w, r, auth.tokenGenerator, model.ID(), false, rememberMe, auth.SSL)
if err != nil {
return getNil[T](), errors.Wrap(err, "jwt.SetTokenCookies")
}
// New tokens sent, revoke the old tokens
err = rT.Revoke(tx)
if err != nil {
return getNil[T](), errors.Wrap(err, "rT.Revoke")
}
// Return the authorized user
return model, nil
}

73
jwt/cookies.go Normal file
View File

@@ -0,0 +1,73 @@
package jwt
import (
"github.com/pkg/errors"
"net/http"
"time"
)
// Get the value of the access and refresh tokens
func GetTokenCookies(
r *http.Request,
) (acc string, ref string) {
accCookie, accErr := r.Cookie("access")
refCookie, refErr := r.Cookie("refresh")
var (
accStr string = ""
refStr string = ""
)
if accErr == nil {
accStr = accCookie.Value
}
if refErr == nil {
refStr = refCookie.Value
}
return accStr, refStr
}
// Set a token with the provided details
func setToken(
w http.ResponseWriter,
token string,
scope string,
exp int64,
rememberme bool,
useSSL bool,
) {
tokenCookie := &http.Cookie{
Name: scope,
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: useSSL,
}
if rememberme {
tokenCookie.Expires = time.Unix(exp, 0)
}
http.SetCookie(w, tokenCookie)
}
// Generate new tokens for the subject and set them as cookies
func SetTokenCookies(
w http.ResponseWriter,
r *http.Request,
tokenGen *TokenGenerator,
subject int,
fresh bool,
rememberMe bool,
useSSL bool,
) error {
at, atexp, err := tokenGen.NewAccess(subject, fresh, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateAccessToken")
}
rt, rtexp, err := tokenGen.NewRefresh(subject, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateRefreshToken")
}
// Don't set the cookies until we know no errors occured
setToken(w, at, "access", atexp, rememberMe, useSSL)
setToken(w, rt, "refresh", rtexp, rememberMe, useSSL)
return nil
}

View File

@@ -1,47 +1,31 @@
package jwt package jwt
import ( import (
"context" "database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Revoke a token by adding it to the database // Revoke a token by adding it to the database
func revoke(ctx context.Context, t Token) error { func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
db := t.getDB() if gen.dbConn == nil {
if db == nil {
return errors.New("No DB provided, unable to use this function") return errors.New("No DB provided, unable to use this function")
} }
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI() jti := t.GetJTI()
exp := t.GetEXP() exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err = tx.Exec(query, jti, exp) _, err := tx.Exec(query, jti, exp)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.Exec") return errors.Wrap(err, "tx.Exec")
} }
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "tx.Commit")
}
return nil return nil
} }
// Check if a token has been revoked. Returns true if not revoked. // Check if a token has been revoked. Returns true if not revoked.
func checkNotRevoked(ctx context.Context, t Token) (bool, error) { func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
db := t.getDB() if gen.dbConn == nil {
if db == nil {
return false, errors.New("No DB provided, unable to use this function") return false, errors.New("No DB provided, unable to use this function")
} }
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI() jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(query, jti) rows, err := tx.Query(query, jti)
@@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
} }
defer rows.Close() defer rows.Close()
revoked := rows.Next() revoked := rows.Next()
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
return !revoked, nil return !revoked, nil
} }

View File

@@ -2,6 +2,7 @@ package jwt
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) {
token := AccessToken{ token := AccessToken{
JTI: jti, JTI: jti,
EXP: exp, EXP: exp,
gen: &TokenGenerator{},
} }
// Revoke should fail due to no DB // Revoke should fail due to no DB
err := token.Revoke(context.Background()) err := token.Revoke(&sql.Tx{})
require.Error(t, err) require.Error(t, err)
// CheckNotRevoked should fail // CheckNotRevoked should fail
_, err = token.CheckNotRevoked(context.Background()) _, err = token.CheckNotRevoked(&sql.Tx{})
require.Error(t, err) require.Error(t, err)
} }
@@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
token := AccessToken{ token := AccessToken{
JTI: jti, JTI: jti,
EXP: exp, EXP: exp,
db: gen.dbConn, gen: gen,
} }
// Revoke expectations // Revoke expectations
@@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
mock.ExpectExec(`INSERT INTO jwtblacklist`). mock.ExpectExec(`INSERT INTO jwtblacklist`).
WithArgs(jti, exp). WithArgs(jti, exp).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := token.Revoke(context.Background())
require.NoError(t, err)
// CheckNotRevoked expectations (now revoked)
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`). mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti). WithArgs(jti).
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1)) WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit() mock.ExpectCommit()
valid, err := token.CheckNotRevoked(context.Background()) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
defer tx.Rollback()
require.NoError(t, err)
err = token.Revoke(tx)
require.NoError(t, err)
valid, err := token.CheckNotRevoked(tx)
require.NoError(t, err) require.NoError(t, err)
require.False(t, valid) require.False(t, valid)
require.NoError(t, tx.Commit())
require.NoError(t, mock.ExpectationsWereMet()) require.NoError(t, mock.ExpectationsWereMet())
} }

View File

@@ -1,7 +1,6 @@
package jwt package jwt
import ( import (
"context"
"database/sql" "database/sql"
"github.com/google/uuid" "github.com/google/uuid"
@@ -11,8 +10,8 @@ type Token interface {
GetJTI() uuid.UUID GetJTI() uuid.UUID
GetEXP() int64 GetEXP() int64
GetScope() string GetScope() string
getDB() *sql.DB Revoke(*sql.Tx) error
Revoke(context.Context) error CheckNotRevoked(*sql.Tx) (bool, error)
} }
// Access token // Access token
@@ -25,7 +24,7 @@ type AccessToken struct {
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at Fresh int64 // Time freshness expiring at
Scope string // Should be "access" Scope string // Should be "access"
db *sql.DB gen *TokenGenerator
} }
// Refresh token // Refresh token
@@ -37,7 +36,7 @@ type RefreshToken struct {
SUB int // Subject (user) ID SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh" Scope string // Should be "refresh"
db *sql.DB gen *TokenGenerator
} }
func (a AccessToken) GetJTI() uuid.UUID { func (a AccessToken) GetJTI() uuid.UUID {
@@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string {
func (r RefreshToken) GetScope() string { func (r RefreshToken) GetScope() string {
return r.Scope return r.Scope
} }
func (a AccessToken) getDB() *sql.DB { func (a AccessToken) Revoke(tx *sql.Tx) error {
return a.db return a.gen.revoke(tx, a)
} }
func (r RefreshToken) getDB() *sql.DB { func (r RefreshToken) Revoke(tx *sql.Tx) error {
return r.db return r.gen.revoke(tx, r)
} }
func (a AccessToken) Revoke(ctx context.Context) error { func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return revoke(ctx, a) return a.gen.checkNotRevoked(tx, a)
} }
func (r RefreshToken) Revoke(ctx context.Context) error { func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return revoke(ctx, r) return r.gen.checkNotRevoked(tx, r)
}
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, a)
}
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, r)
} }

View File

@@ -1,7 +1,8 @@
package jwt package jwt
import ( import (
"context" "database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -9,7 +10,7 @@ import (
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func (gen *TokenGenerator) ValidateAccess( func (gen *TokenGenerator) ValidateAccess(
ctx context.Context, tx *sql.Tx,
tokenString string, tokenString string,
) (*AccessToken, error) { ) (*AccessToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess(
Fresh: fresh, Fresh: fresh,
JTI: jti, JTI: jti,
Scope: scope, Scope: scope,
db: gen.dbConn, gen: gen,
} }
valid, err := token.CheckNotRevoked(ctx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil { if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }
@@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess(
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func (gen *TokenGenerator) ValidateRefresh( func (gen *TokenGenerator) ValidateRefresh(
ctx context.Context, tx *sql.Tx,
tokenString string, tokenString string,
) (*RefreshToken, error) { ) (*RefreshToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh(
SUB: subject, SUB: subject,
JTI: jti, JTI: jti,
Scope: scope, Scope: scope,
db: gen.dbConn, gen: gen,
} }
valid, err := token.CheckNotRevoked(ctx) valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil { if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked") return nil, errors.Wrap(err, "token.CheckNotRevoked")
} }

View File

@@ -2,6 +2,7 @@ package jwt
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
@@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg // We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateAccess(context.Background(), tokenStr) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateAccess(tx, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope) require.Equal(t, "access", token.Scope)
tx.Commit()
} }
func TestValidateAccess_NoDB(t *testing.T) { func TestValidateAccess_NoDB(t *testing.T) {
@@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewAccess(42, true, false) tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err) require.NoError(t, err)
token, err := gen.ValidateAccess(context.Background(), tokenStr) token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope) require.Equal(t, "access", token.Scope)
@@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg()) expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateRefresh(context.Background(), tokenStr) tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateRefresh(tx, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope) require.Equal(t, "refresh", token.Scope)
tx.Commit()
} }
func TestValidateRefresh_NoDB(t *testing.T) { func TestValidateRefresh_NoDB(t *testing.T) {
@@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewRefresh(42, false) tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err) require.NoError(t, err)
token, err := gen.ValidateRefresh(context.Background(), tokenStr) token, err := gen.ValidateRefresh(nil, tokenStr)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 42, token.SUB) require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope) require.Equal(t, "refresh", token.Scope)
@@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
func TestValidateAccess_EmptyToken(t *testing.T) { func TestValidateAccess_EmptyToken(t *testing.T) {
gen := newTestGenerator(t) gen := newTestGenerator(t)
_, err := gen.ValidateAccess(context.Background(), "") _, err := gen.ValidateAccess(nil, "")
require.Error(t, err) require.Error(t, err)
} }
@@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
tokenStr, _, err := gen.NewAccess(1, false, false) tokenStr, _, err := gen.NewAccess(1, false, false)
require.NoError(t, err) require.NoError(t, err)
_, err = gen.ValidateRefresh(context.Background(), tokenStr) _, err = gen.ValidateRefresh(nil, tokenStr)
require.Error(t, err) require.Error(t, err)
} }