Added authentication middleware
This commit is contained in:
11
contexts/keys.go
Normal file
11
contexts/keys.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package contexts
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
func (c contextKey) String() string {
|
||||||
|
return "projectreshoot context key " + string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
contextKeyAuthorizedUser = contextKey("auth-user")
|
||||||
|
)
|
||||||
20
contexts/user.go
Normal file
20
contexts/user.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package contexts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"projectreshoot/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Return a new context with the user added in
|
||||||
|
func SetUser(ctx context.Context, u *db.User) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyAuthorizedUser, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve a user from the given context. Returns nil if not set
|
||||||
|
func GetUser(ctx context.Context) *db.User {
|
||||||
|
user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return user
|
||||||
|
}
|
||||||
@@ -59,9 +59,10 @@ func SetTokenCookies(
|
|||||||
r *http.Request,
|
r *http.Request,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
user *db.User,
|
user *db.User,
|
||||||
|
fresh bool,
|
||||||
rememberMe bool,
|
rememberMe bool,
|
||||||
) error {
|
) error {
|
||||||
at, atexp, err := jwt.GenerateAccessToken(config, user, true, rememberMe)
|
at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
||||||
}
|
}
|
||||||
|
|||||||
24
db/users.go
24
db/users.go
@@ -69,3 +69,27 @@ func GetUserFromUsername(conn *sql.DB, username string) (User, error) {
|
|||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Queries the database for a user matching the given ID.
|
||||||
|
func GetUserFromID(conn *sql.DB, id int) (User, error) {
|
||||||
|
query := `SELECT id, username, password_hash, created_at FROM users
|
||||||
|
WHERE id = ?`
|
||||||
|
rows, err := conn.Query(query, id)
|
||||||
|
if err != nil {
|
||||||
|
return User{}, errors.Wrap(err, "conn.Query")
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var user User
|
||||||
|
for rows.Next() {
|
||||||
|
err := rows.Scan(
|
||||||
|
&user.ID,
|
||||||
|
&user.Username,
|
||||||
|
&user.Password_hash,
|
||||||
|
&user.Created_at,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return User{}, errors.Wrap(err, "rows.Scan")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ func HandleLoginRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := checkRememberMe(r)
|
rememberMe := checkRememberMe(r)
|
||||||
err = cookies.SetTokenCookies(w, r, config, &user, rememberMe)
|
err = cookies.SetTokenCookies(w, r, config, &user, true, rememberMe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
form.LoginForm(err.Error()).Render(r.Context(), w)
|
form.LoginForm(err.Error()).Render(r.Context(), w)
|
||||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||||
|
|||||||
24
jwt/parse.go
24
jwt/parse.go
@@ -1,6 +1,7 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,7 +15,11 @@ import (
|
|||||||
// Parse an access token and return a struct with all the claims. Does validation on
|
// Parse an access token and return a struct with all the claims. Does validation on
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, error) {
|
func ParseAccessToken(
|
||||||
|
config *config.Config,
|
||||||
|
conn *sql.DB,
|
||||||
|
tokenString string,
|
||||||
|
) (AccessToken, error) {
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
claims, err := parseToken(config.SecretKey, tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AccessToken{}, errors.Wrap(err, "parseToken")
|
return AccessToken{}, errors.Wrap(err, "parseToken")
|
||||||
@@ -66,13 +71,21 @@ func ParseAccessToken(config *config.Config, tokenString string) (AccessToken, e
|
|||||||
Scope: scope,
|
Scope: scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
valid, err := CheckTokenNotRevoked(conn, token)
|
||||||
|
if err != nil || !valid {
|
||||||
|
return AccessToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||||
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken, error) {
|
func ParseRefreshToken(
|
||||||
|
config *config.Config,
|
||||||
|
conn *sql.DB,
|
||||||
|
tokenString string,
|
||||||
|
) (RefreshToken, error) {
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
claims, err := parseToken(config.SecretKey, tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return RefreshToken{}, errors.Wrap(err, "parseToken")
|
return RefreshToken{}, errors.Wrap(err, "parseToken")
|
||||||
@@ -119,6 +132,13 @@ func ParseRefreshToken(config *config.Config, tokenString string) (RefreshToken,
|
|||||||
Scope: scope,
|
Scope: scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
valid, err := CheckTokenNotRevoked(conn, token)
|
||||||
|
if err != nil {
|
||||||
|
return RefreshToken{}, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return RefreshToken{}, errors.New("Token has been revoked")
|
||||||
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ func RevokeToken(conn *sql.DB, t Token) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a token has been revoked
|
// Check if a token has been revoked. Returns true if not revoked.
|
||||||
func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
|
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
|
||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||||
rows, err := conn.Query(query, jti)
|
rows, err := conn.Query(query, jti)
|
||||||
@@ -27,5 +27,5 @@ func CheckRevoked(conn *sql.DB, t Token) (bool, error) {
|
|||||||
return false, errors.Wrap(err, "conn.Exec")
|
return false, errors.Wrap(err, "conn.Exec")
|
||||||
}
|
}
|
||||||
revoked := rows.Next()
|
revoked := rows.Next()
|
||||||
return revoked, nil
|
return !revoked, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"projectreshoot/db"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
type Token interface {
|
type Token interface {
|
||||||
GetJTI() uuid.UUID
|
GetJTI() uuid.UUID
|
||||||
GetEXP() int64
|
GetEXP() int64
|
||||||
GetScope() string
|
GetScope() string
|
||||||
|
GetUser(conn *sql.DB) (*db.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Access token
|
// Access token
|
||||||
@@ -31,6 +38,21 @@ type RefreshToken struct {
|
|||||||
Scope string // Should be "refresh"
|
Scope string // Should be "refresh"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||||
|
user, err := db.GetUserFromID(conn, a.SUB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||||
|
user, err := db.GetUserFromID(conn, r.SUB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a AccessToken) GetJTI() uuid.UUID {
|
func (a AccessToken) GetJTI() uuid.UUID {
|
||||||
return a.JTI
|
return a.JTI
|
||||||
}
|
}
|
||||||
|
|||||||
7
main.go
7
main.go
@@ -64,8 +64,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
|
|
||||||
srv := server.NewServer(config, logger, conn)
|
srv := server.NewServer(config, logger, conn)
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
|
ReadHeaderTimeout: 2 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs function for testing in dev if --test flag true
|
// Runs function for testing in dev if --test flag true
|
||||||
|
|||||||
@@ -1,21 +1,105 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"projectreshoot/config"
|
||||||
|
"projectreshoot/contexts"
|
||||||
|
"projectreshoot/cookies"
|
||||||
|
"projectreshoot/db"
|
||||||
|
"projectreshoot/jwt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Take current request
|
// Attempt to use a valid refresh token to generate a new token pair
|
||||||
// Get cookies from browser
|
func refreshAuthTokens(
|
||||||
// Parse the tokens
|
config *config.Config,
|
||||||
// Check if tokens blacklisted
|
conn *sql.DB,
|
||||||
// Trigger refresh if required
|
w http.ResponseWriter,
|
||||||
// Create context with state of user authorization
|
req *http.Request,
|
||||||
// Pass request on with context
|
ref *jwt.RefreshToken,
|
||||||
|
) (*db.User, error) {
|
||||||
|
user, err := ref.GetUser(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "rT.GetUser")
|
||||||
|
}
|
||||||
|
|
||||||
func Authentication(logger *zerolog.Logger, next http.Handler) http.Handler {
|
rememberMe := map[string]bool{
|
||||||
|
"session": false,
|
||||||
|
"exp": true,
|
||||||
|
}[ref.TTL]
|
||||||
|
|
||||||
|
// Set fresh to true because new tokens coming from refresh request
|
||||||
|
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
||||||
|
}
|
||||||
|
// New tokens sent, revoke the used refresh token
|
||||||
|
err = jwt.RevokeToken(conn, ref)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
||||||
|
}
|
||||||
|
// Return the authorized user
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the cookies for token strings and attempt to authenticate them
|
||||||
|
func getAuthenticatedUser(
|
||||||
|
config *config.Config,
|
||||||
|
conn *sql.DB,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) (*db.User, error) {
|
||||||
|
// Get token strings from cookies
|
||||||
|
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||||
|
// Attempt to parse the access token
|
||||||
|
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||||
|
if err != nil {
|
||||||
|
// Access token invalid, attempt to parse refresh token
|
||||||
|
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||||
|
}
|
||||||
|
// Refresh token valid, attempt to get a new token pair
|
||||||
|
user, err := refreshAuthTokens(config, conn, w, r, &rT)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "refreshAuthTokens")
|
||||||
|
}
|
||||||
|
// New token pair sent, return the authorized user
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
// Access token valid
|
||||||
|
user, err := aT.GetUser(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "rT.GetUser")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to authenticate the user and add their account details
|
||||||
|
// to the request context
|
||||||
|
func Authentication(
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
config *config.Config,
|
||||||
|
conn *sql.DB,
|
||||||
|
next http.Handler,
|
||||||
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
next.ServeHTTP(w, r)
|
user, err := getAuthenticatedUser(config, conn, w, r)
|
||||||
|
if err != nil {
|
||||||
|
// User auth failed, delete the cookies to avoid repeat requests
|
||||||
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
|
logger.Debug().
|
||||||
|
Str("remote_addr", r.RemoteAddr).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to authenticate user")
|
||||||
|
}
|
||||||
|
ctx := contexts.SetUser(r.Context(), user)
|
||||||
|
newReq := r.WithContext(ctx)
|
||||||
|
next.ServeHTTP(w, newReq)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
|||||||
Int("status", wrapped.statusCode).
|
Int("status", wrapped.statusCode).
|
||||||
Str("method", r.Method).
|
Str("method", r.Method).
|
||||||
Str("resource", r.URL.Path).
|
Str("resource", r.URL.Path).
|
||||||
Dur("time_elapsed", time.Since(start)).Msg("Served")
|
Dur("time_elapsed", time.Since(start)).
|
||||||
|
Str("remote_addr", r.RemoteAddr).
|
||||||
|
Msg("Served")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewServer(
|
|||||||
// Add middleware here, must be added in reverse order of execution
|
// Add middleware here, must be added in reverse order of execution
|
||||||
// i.e. First in list will get executed last during the request handling
|
// i.e. First in list will get executed last during the request handling
|
||||||
handler = middleware.Logging(logger, handler)
|
handler = middleware.Logging(logger, handler)
|
||||||
handler = middleware.Authentication(logger, handler)
|
handler = middleware.Authentication(logger, config, conn, handler)
|
||||||
|
|
||||||
// Serve the favicon and exluded files before any middleware is added
|
// Serve the favicon and exluded files before any middleware is added
|
||||||
handler = middleware.ExcludedFiles(handler)
|
handler = middleware.ExcludedFiles(handler)
|
||||||
|
|||||||
Reference in New Issue
Block a user