Added authentication middleware

This commit is contained in:
2025-02-11 23:46:25 +11:00
parent 97aabcf06f
commit 732f8510ae
12 changed files with 208 additions and 21 deletions

11
contexts/keys.go Normal file
View File

@@ -0,0 +1,11 @@
package contexts
type contextKey string
func (c contextKey) String() string {
return "projectreshoot context key " + string(c)
}
var (
contextKeyAuthorizedUser = contextKey("auth-user")
)

20
contexts/user.go Normal file
View File

@@ -0,0 +1,20 @@
package contexts
import (
"context"
"projectreshoot/db"
)
// Return a new context with the user added in
func SetUser(ctx context.Context, u *db.User) context.Context {
return context.WithValue(ctx, contextKeyAuthorizedUser, u)
}
// Retrieve a user from the given context. Returns nil if not set
func GetUser(ctx context.Context) *db.User {
user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User)
if !ok {
return nil
}
return user
}

View File

@@ -59,9 +59,10 @@ func SetTokenCookies(
r *http.Request, r *http.Request,
config *config.Config, config *config.Config,
user *db.User, user *db.User,
fresh bool,
rememberMe bool, rememberMe bool,
) error { ) error {
at, atexp, err := jwt.GenerateAccessToken(config, user, true, rememberMe) at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.GenerateAccessToken") return errors.Wrap(err, "jwt.GenerateAccessToken")
} }

View File

@@ -69,3 +69,27 @@ func GetUserFromUsername(conn *sql.DB, username string) (User, error) {
} }
return user, nil return user, nil
} }
// Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (User, error) {
query := `SELECT id, username, password_hash, created_at FROM users
WHERE id = ?`
rows, err := conn.Query(query, id)
if err != nil {
return User{}, errors.Wrap(err, "conn.Query")
}
defer rows.Close()
var user User
for rows.Next() {
err := rows.Scan(
&user.ID,
&user.Username,
&user.Password_hash,
&user.Created_at,
)
if err != nil {
return User{}, errors.Wrap(err, "rows.Scan")
}
}
return user, nil
}

View File

@@ -63,7 +63,7 @@ func HandleLoginRequest(
} }
rememberMe := checkRememberMe(r) rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, &user, rememberMe) err = cookies.SetTokenCookies(w, r, config, &user, true, rememberMe)
if err != nil { if err != nil {
form.LoginForm(err.Error()).Render(r.Context(), w) form.LoginForm(err.Error()).Render(r.Context(), w)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")

View File

@@ -1,6 +1,7 @@
package jwt package jwt
import ( import (
"database/sql"
"fmt" "fmt"
"time" "time"
@@ -14,7 +15,11 @@ import (
// Parse an access token and return a struct with all the claims. Does validation on // Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) { func ParseAccessToken(
config *config.Config,
conn *sql.DB,
tokenString string,
) (AccessToken, error) {
claims, err := parseToken(config.SecretKey, tokenString) claims, err := parseToken(config.SecretKey, tokenString)
if err != nil { if err != nil {
return AccessToken{}, errors.Wrap(err, "parseToken") return AccessToken{}, errors.Wrap(err, "parseToken")
@@ -66,13 +71,21 @@ func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, e
Scope: scope, Scope: scope,
} }
valid, err := CheckTokenNotRevoked(conn, token)
if err != nil || !valid {
return AccessToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
}
return token, nil return token, nil
} }
// Parse a refresh token and return a struct with all the claims. Does validation on // Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and // all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope. // has the correct scope.
func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) { func ParseRefreshToken(
config *config.Config,
conn *sql.DB,
tokenString string,
) (RefreshToken, error) {
claims, err := parseToken(config.SecretKey, tokenString) claims, err := parseToken(config.SecretKey, tokenString)
if err != nil { if err != nil {
return RefreshToken{}, errors.Wrap(err, "parseToken") return RefreshToken{}, errors.Wrap(err, "parseToken")
@@ -119,6 +132,13 @@ func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken,
Scope: scope, Scope: scope,
} }
valid, err := CheckTokenNotRevoked(conn, token)
if err != nil {
return RefreshToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
}
if !valid {
return RefreshToken{}, errors.New("Token has been revoked")
}
return token, nil return token, nil
} }

View File

@@ -18,8 +18,8 @@ func RevokeToken(conn *sql.DB, t Token) error {
return nil return nil
} }
// Check if a token has been revoked // Check if a token has been revoked. Returns true if not revoked.
func CheckRevoked(conn *sql.DB, t Token) (bool, error) { func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
jti := t.GetJTI() jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti) rows, err := conn.Query(query, jti)
@@ -27,5 +27,5 @@ func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
return false, errors.Wrap(err, "conn.Exec") return false, errors.Wrap(err, "conn.Exec")
} }
revoked := rows.Next() revoked := rows.Next()
return revoked, nil return !revoked, nil
} }

View File

@@ -1,11 +1,18 @@
package jwt package jwt
import "github.com/google/uuid" import (
"database/sql"
"projectreshoot/db"
"github.com/google/uuid"
"github.com/pkg/errors"
)
type Token interface { type Token interface {
GetJTI() uuid.UUID GetJTI() uuid.UUID
GetEXP() int64 GetEXP() int64
GetScope() string GetScope() string
GetUser(conn *sql.DB) (*db.User, error)
} }
// Access token // Access token
@@ -31,6 +38,21 @@ type RefreshToken struct {
Scope string // Should be "refresh" Scope string // Should be "refresh"
} }
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, a.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
}
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
}
func (a AccessToken) GetJTI() uuid.UUID { func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI return a.JTI
} }

View File

@@ -64,8 +64,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
srv := server.NewServer(config, logger, conn) srv := server.NewServer(config, logger, conn)
httpServer := &http.Server{ httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port), Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv, Handler: srv,
ReadHeaderTimeout: 2 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
} }
// Runs function for testing in dev if --test flag true // Runs function for testing in dev if --test flag true

View File

@@ -1,21 +1,105 @@
package middleware package middleware
import ( import (
"database/sql"
"net/http" "net/http"
"projectreshoot/config"
"projectreshoot/contexts"
"projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/jwt"
"github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
// Take current request // Attempt to use a valid refresh token to generate a new token pair
// Get cookies from browser func refreshAuthTokens(
// Parse the tokens config *config.Config,
// Check if tokens blacklisted conn *sql.DB,
// Trigger refresh if required w http.ResponseWriter,
// Create context with state of user authorization req *http.Request,
// Pass request on with context ref *jwt.RefreshToken,
) (*db.User, error) {
user, err := ref.GetUser(conn)
if err != nil {
return nil, errors.Wrap(err, "rT.GetUser")
}
func Authentication(logger *zerolog.Logger, next http.Handler) http.Handler { rememberMe := map[string]bool{
"session": false,
"exp": true,
}[ref.TTL]
// Set fresh to true because new tokens coming from refresh request
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe)
if err != nil {
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
}
// New tokens sent, revoke the used refresh token
err = jwt.RevokeToken(conn, ref)
if err != nil {
return nil, errors.Wrap(err, "jwt.RevokeToken")
}
// Return the authorized user
return user, nil
}
// Check the cookies for token strings and attempt to authenticate them
func getAuthenticatedUser(
config *config.Config,
conn *sql.DB,
w http.ResponseWriter,
r *http.Request,
) (*db.User, error) {
// Get token strings from cookies
atStr, rtStr := cookies.GetTokenStrings(r)
// Attempt to parse the access token
aT, err := jwt.ParseAccessToken(config, conn, atStr)
if err != nil {
// Access token invalid, attempt to parse refresh token
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
if err != nil {
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
}
// Refresh token valid, attempt to get a new token pair
user, err := refreshAuthTokens(config, conn, w, r, &rT)
if err != nil {
return nil, errors.Wrap(err, "refreshAuthTokens")
}
// New token pair sent, return the authorized user
return user, nil
}
// Access token valid
user, err := aT.GetUser(conn)
if err != nil {
return nil, errors.Wrap(err, "rT.GetUser")
}
return user, nil
}
// Attempt to authenticate the user and add their account details
// to the request context
func Authentication(
logger *zerolog.Logger,
config *config.Config,
conn *sql.DB,
next http.Handler,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r) user, err := getAuthenticatedUser(config, conn, w, r)
if err != nil {
// User auth failed, delete the cookies to avoid repeat requests
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
logger.Debug().
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
}
ctx := contexts.SetUser(r.Context(), user)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
}) })
} }

View File

@@ -32,6 +32,8 @@ func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
Int("status", wrapped.statusCode). Int("status", wrapped.statusCode).
Str("method", r.Method). Str("method", r.Method).
Str("resource", r.URL.Path). Str("resource", r.URL.Path).
Dur("time_elapsed", time.Since(start)).Msg("Served") Dur("time_elapsed", time.Since(start)).
Str("remote_addr", r.RemoteAddr).
Msg("Served")
}) })
} }

View File

@@ -27,7 +27,7 @@ func NewServer(
// Add middleware here, must be added in reverse order of execution // Add middleware here, must be added in reverse order of execution
// i.e. First in list will get executed last during the request handling // i.e. First in list will get executed last during the request handling
handler = middleware.Logging(logger, handler) handler = middleware.Logging(logger, handler)
handler = middleware.Authentication(logger, handler) handler = middleware.Authentication(logger, config, conn, handler)
// Serve the favicon and exluded files before any middleware is added // Serve the favicon and exluded files before any middleware is added
handler = middleware.ExcludedFiles(handler) handler = middleware.ExcludedFiles(handler)