From 732f8510aeeb64e46361364a1160ea9dfeeee769 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 11 Feb 2025 23:46:25 +1100 Subject: [PATCH] Added authentication middleware --- contexts/keys.go | 11 ++++ contexts/user.go | 20 +++++++ cookies/tokens.go | 3 +- db/users.go | 24 +++++++++ handlers/login.go | 2 +- jwt/parse.go | 24 ++++++++- jwt/revoke.go | 6 +-- jwt/tokens.go | 24 ++++++++- main.go | 7 ++- middleware/authentication.go | 102 +++++++++++++++++++++++++++++++---- middleware/logging.go | 4 +- server/server.go | 2 +- 12 files changed, 208 insertions(+), 21 deletions(-) create mode 100644 contexts/keys.go create mode 100644 contexts/user.go diff --git a/contexts/keys.go b/contexts/keys.go new file mode 100644 index 0000000..e5a08df --- /dev/null +++ b/contexts/keys.go @@ -0,0 +1,11 @@ +package contexts + +type contextKey string + +func (c contextKey) String() string { + return "projectreshoot context key " + string(c) +} + +var ( + contextKeyAuthorizedUser = contextKey("auth-user") +) diff --git a/contexts/user.go b/contexts/user.go new file mode 100644 index 0000000..3a752a5 --- /dev/null +++ b/contexts/user.go @@ -0,0 +1,20 @@ +package contexts + +import ( + "context" + "projectreshoot/db" +) + +// Return a new context with the user added in +func SetUser(ctx context.Context, u *db.User) context.Context { + return context.WithValue(ctx, contextKeyAuthorizedUser, u) +} + +// Retrieve a user from the given context. Returns nil if not set +func GetUser(ctx context.Context) *db.User { + user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User) + if !ok { + return nil + } + return user +} diff --git a/cookies/tokens.go b/cookies/tokens.go index 3d2e7f5..a5c20a0 100644 --- a/cookies/tokens.go +++ b/cookies/tokens.go @@ -59,9 +59,10 @@ func SetTokenCookies( r *http.Request, config *config.Config, user *db.User, + fresh bool, rememberMe bool, ) error { - at, atexp, err := jwt.GenerateAccessToken(config, user, true, rememberMe) + at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe) if err != nil { return errors.Wrap(err, "jwt.GenerateAccessToken") } diff --git a/db/users.go b/db/users.go index 3cd038e..4436e96 100644 --- a/db/users.go +++ b/db/users.go @@ -69,3 +69,27 @@ func GetUserFromUsername(conn *sql.DB, username string) (User, error) { } return user, nil } + +// Queries the database for a user matching the given ID. +func GetUserFromID(conn *sql.DB, id int) (User, error) { + query := `SELECT id, username, password_hash, created_at FROM users + WHERE id = ?` + rows, err := conn.Query(query, id) + if err != nil { + return User{}, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + var user User + for rows.Next() { + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + ) + if err != nil { + return User{}, errors.Wrap(err, "rows.Scan") + } + } + return user, nil +} diff --git a/handlers/login.go b/handlers/login.go index a7afd00..c4da38b 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -63,7 +63,7 @@ func HandleLoginRequest( } rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, &user, rememberMe) + err = cookies.SetTokenCookies(w, r, config, &user, true, rememberMe) if err != nil { form.LoginForm(err.Error()).Render(r.Context(), w) logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") diff --git a/jwt/parse.go b/jwt/parse.go index 40b07d8..4dd9b45 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -1,6 +1,7 @@ package jwt import ( + "database/sql" "fmt" "time" @@ -14,7 +15,11 @@ import ( // Parse an access token and return a struct with all the claims. Does validation on // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. -func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) { +func ParseAccessToken( + config *config.Config, + conn *sql.DB, + tokenString string, +) (AccessToken, error) { claims, err := parseToken(config.SecretKey, tokenString) if err != nil { return AccessToken{}, errors.Wrap(err, "parseToken") @@ -66,13 +71,21 @@ func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, e Scope: scope, } + valid, err := CheckTokenNotRevoked(conn, token) + if err != nil || !valid { + return AccessToken{}, errors.Wrap(err, "CheckTokenNotRevoked") + } return token, nil } // Parse a refresh token and return a struct with all the claims. Does validation on // all the claims, including checking if it is expired, has a valid issuer, and // has the correct scope. -func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) { +func ParseRefreshToken( + config *config.Config, + conn *sql.DB, + tokenString string, +) (RefreshToken, error) { claims, err := parseToken(config.SecretKey, tokenString) if err != nil { return RefreshToken{}, errors.Wrap(err, "parseToken") @@ -119,6 +132,13 @@ func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, Scope: scope, } + valid, err := CheckTokenNotRevoked(conn, token) + if err != nil { + return RefreshToken{}, errors.Wrap(err, "CheckTokenNotRevoked") + } + if !valid { + return RefreshToken{}, errors.New("Token has been revoked") + } return token, nil } diff --git a/jwt/revoke.go b/jwt/revoke.go index 9b22e08..66465a7 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -18,8 +18,8 @@ func RevokeToken(conn *sql.DB, t Token) error { return nil } -// Check if a token has been revoked -func CheckRevoked(conn *sql.DB, t Token) (bool, error) { +// Check if a token has been revoked. Returns true if not revoked. +func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` rows, err := conn.Query(query, jti) @@ -27,5 +27,5 @@ func CheckRevoked(conn *sql.DB, t Token) (bool, error) { return false, errors.Wrap(err, "conn.Exec") } revoked := rows.Next() - return revoked, nil + return !revoked, nil } diff --git a/jwt/tokens.go b/jwt/tokens.go index b9fa0d3..5697f9c 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,11 +1,18 @@ package jwt -import "github.com/google/uuid" +import ( + "database/sql" + "projectreshoot/db" + + "github.com/google/uuid" + "github.com/pkg/errors" +) type Token interface { GetJTI() uuid.UUID GetEXP() int64 GetScope() string + GetUser(conn *sql.DB) (*db.User, error) } // Access token @@ -31,6 +38,21 @@ type RefreshToken struct { Scope string // Should be "refresh" } +func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { + user, err := db.GetUserFromID(conn, a.SUB) + if err != nil { + return nil, errors.Wrap(err, "db.GetUserFromID") + } + return &user, nil +} +func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { + user, err := db.GetUserFromID(conn, r.SUB) + if err != nil { + return nil, errors.Wrap(err, "db.GetUserFromID") + } + return &user, nil +} + func (a AccessToken) GetJTI() uuid.UUID { return a.JTI } diff --git a/main.go b/main.go index a174c4a..2d00ad7 100644 --- a/main.go +++ b/main.go @@ -64,8 +64,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { srv := server.NewServer(config, logger, conn) httpServer := &http.Server{ - Addr: net.JoinHostPort(config.Host, config.Port), - Handler: srv, + Addr: net.JoinHostPort(config.Host, config.Port), + Handler: srv, + ReadHeaderTimeout: 2 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, } // Runs function for testing in dev if --test flag true diff --git a/middleware/authentication.go b/middleware/authentication.go index de72c46..c29852c 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -1,21 +1,105 @@ package middleware import ( + "database/sql" "net/http" + "projectreshoot/config" + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/db" + "projectreshoot/jwt" + + "github.com/pkg/errors" "github.com/rs/zerolog" ) -// Take current request -// Get cookies from browser -// Parse the tokens -// Check if tokens blacklisted -// Trigger refresh if required -// Create context with state of user authorization -// Pass request on with context +// Attempt to use a valid refresh token to generate a new token pair +func refreshAuthTokens( + config *config.Config, + conn *sql.DB, + w http.ResponseWriter, + req *http.Request, + ref *jwt.RefreshToken, +) (*db.User, error) { + user, err := ref.GetUser(conn) + if err != nil { + return nil, errors.Wrap(err, "rT.GetUser") + } -func Authentication(logger *zerolog.Logger, next http.Handler) http.Handler { + rememberMe := map[string]bool{ + "session": false, + "exp": true, + }[ref.TTL] + + // Set fresh to true because new tokens coming from refresh request + err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe) + if err != nil { + return nil, errors.Wrap(err, "cookies.SetTokenCookies") + } + // New tokens sent, revoke the used refresh token + err = jwt.RevokeToken(conn, ref) + if err != nil { + return nil, errors.Wrap(err, "jwt.RevokeToken") + } + // Return the authorized user + return user, nil +} + +// Check the cookies for token strings and attempt to authenticate them +func getAuthenticatedUser( + config *config.Config, + conn *sql.DB, + w http.ResponseWriter, + r *http.Request, +) (*db.User, error) { + // Get token strings from cookies + atStr, rtStr := cookies.GetTokenStrings(r) + // Attempt to parse the access token + aT, err := jwt.ParseAccessToken(config, conn, atStr) + if err != nil { + // Access token invalid, attempt to parse refresh token + rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + if err != nil { + return nil, errors.Wrap(err, "jwt.ParseRefreshToken") + } + // Refresh token valid, attempt to get a new token pair + user, err := refreshAuthTokens(config, conn, w, r, &rT) + if err != nil { + return nil, errors.Wrap(err, "refreshAuthTokens") + } + // New token pair sent, return the authorized user + return user, nil + } + // Access token valid + user, err := aT.GetUser(conn) + if err != nil { + return nil, errors.Wrap(err, "rT.GetUser") + } + return user, nil +} + +// Attempt to authenticate the user and add their account details +// to the request context +func Authentication( + logger *zerolog.Logger, + config *config.Config, + conn *sql.DB, + next http.Handler, +) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) + user, err := getAuthenticatedUser(config, conn, w, r) + if err != nil { + // User auth failed, delete the cookies to avoid repeat requests + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + logger.Debug(). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("Failed to authenticate user") + } + ctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(ctx) + next.ServeHTTP(w, newReq) }) } diff --git a/middleware/logging.go b/middleware/logging.go index ef37c5c..878acf8 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -32,6 +32,8 @@ func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { Int("status", wrapped.statusCode). Str("method", r.Method). Str("resource", r.URL.Path). - Dur("time_elapsed", time.Since(start)).Msg("Served") + Dur("time_elapsed", time.Since(start)). + Str("remote_addr", r.RemoteAddr). + Msg("Served") }) } diff --git a/server/server.go b/server/server.go index 49896d8..9979389 100644 --- a/server/server.go +++ b/server/server.go @@ -27,7 +27,7 @@ func NewServer( // Add middleware here, must be added in reverse order of execution // i.e. First in list will get executed last during the request handling handler = middleware.Logging(logger, handler) - handler = middleware.Authentication(logger, handler) + handler = middleware.Authentication(logger, config, conn, handler) // Serve the favicon and exluded files before any middleware is added handler = middleware.ExcludedFiles(handler)