diff --git a/contexts/user.go b/contexts/user.go index 3a752a5..0ef041f 100644 --- a/contexts/user.go +++ b/contexts/user.go @@ -5,14 +5,19 @@ import ( "projectreshoot/db" ) +type AuthenticatedUser struct { + *db.User + Fresh int64 +} + // Return a new context with the user added in -func SetUser(ctx context.Context, u *db.User) context.Context { +func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context { return context.WithValue(ctx, contextKeyAuthorizedUser, u) } // Retrieve a user from the given context. Returns nil if not set -func GetUser(ctx context.Context) *db.User { - user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User) +func GetUser(ctx context.Context) *AuthenticatedUser { + user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser) if !ok { return nil } diff --git a/handlers/account.go b/handlers/account.go index 08ed745..727f8ee 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -12,6 +12,7 @@ import ( "github.com/rs/zerolog" ) +// Renders the account page on the 'General' subpage func HandleAccountPage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -20,6 +21,7 @@ func HandleAccountPage() http.Handler { ) } +// Handles a request to change the subpage for the Account page func HandleAccountSubpage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -30,6 +32,7 @@ func HandleAccountSubpage() http.Handler { ) } +// Handles a request to change the users username func HandleChangeUsername( logger *zerolog.Logger, conn *sql.DB, diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go new file mode 100644 index 0000000..9e188d8 --- /dev/null +++ b/handlers/reauthenticatate.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/config" + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/jwt" + "projectreshoot/view/component/form" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// Get the tokens from the request +func getTokens( + config *config.Config, + conn *sql.DB, + r *http.Request, +) (*jwt.AccessToken, *jwt.RefreshToken, error) { + // get the existing tokens from the cookies + atStr, rtStr := cookies.GetTokenStrings(r) + aT, err := jwt.ParseAccessToken(config, conn, atStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") + } + rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") + } + return aT, rT, nil +} + +// Revoke the given token pair +func revokeTokenPair( + conn *sql.DB, + aT *jwt.AccessToken, + rT *jwt.RefreshToken, +) error { + err := jwt.RevokeToken(conn, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + err = jwt.RevokeToken(conn, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +// Issue new tokens for the user, invalidating the old ones +func refreshTokens( + config *config.Config, + conn *sql.DB, + w http.ResponseWriter, + r *http.Request, +) error { + aT, rT, err := getTokens(config, conn, r) + if err != nil { + return errors.Wrap(err, "getTokens") + } + rememberMe := map[string]bool{ + "session": false, + "exp": true, + }[aT.TTL] + // issue new tokens for the user + user := contexts.GetUser(r.Context()) + err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe) + if err != nil { + return errors.Wrap(err, "cookies.SetTokenCookies") + } + err = revokeTokenPair(conn, aT, rT) + if err != nil { + return errors.Wrap(err, "revokeTokenPair") + } + + return nil +} + +// Validate the provided password +func validatePassword( + r *http.Request, +) error { + r.ParseForm() + password := r.FormValue("password") + user := contexts.GetUser(r.Context()) + err := user.CheckPassword(password) + if err != nil { + return errors.Wrap(err, "user.CheckPassword") + } + return nil +} + +// Handle request to reauthenticate (i.e. make token fresh again) +func HandleReauthenticate( + logger *zerolog.Logger, + config *config.Config, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + err := validatePassword(r) + if err != nil { + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, conn, w, r) + if err != nil { + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + ) +} diff --git a/middleware/authentication.go b/middleware/authentication.go index 38cabbe..23f40f0 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -3,6 +3,7 @@ package middleware import ( "database/sql" "net/http" + "time" "projectreshoot/config" "projectreshoot/contexts" @@ -52,7 +53,7 @@ func getAuthenticatedUser( conn *sql.DB, w http.ResponseWriter, r *http.Request, -) (*db.User, error) { +) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token @@ -69,14 +70,22 @@ func getAuthenticatedUser( return nil, errors.Wrap(err, "refreshAuthTokens") } // New token pair sent, return the authorized user - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: time.Now().Unix(), + } + return &authUser, nil } // Access token valid user, err := aT.GetUser(conn) if err != nil { return nil, errors.Wrap(err, "rT.GetUser") } - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: aT.Fresh, + } + return &authUser, nil } // Attempt to authenticate the user and add their account details diff --git a/middleware/reauthentication.go b/middleware/reauthentication.go new file mode 100644 index 0000000..41fad65 --- /dev/null +++ b/middleware/reauthentication.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "net/http" + "projectreshoot/contexts" + "time" +) + +func RequiresFresh( + next http.Handler, +) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + isFresh := time.Now().Before(time.Unix(user.Fresh, 0)) + if !isFresh { + w.WriteHeader(444) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/server/routes.go b/server/routes.go index 2bada29..2804a08 100644 --- a/server/routes.go +++ b/server/routes.go @@ -77,6 +77,12 @@ func addRoutes( )) mux.Handle("POST /change-username", middleware.RequiresLogin( - handlers.HandleChangeUsername(logger, conn), + middleware.RequiresFresh( + handlers.HandleChangeUsername(logger, conn), + ), + )) + mux.Handle("POST /reauthenticate", + middleware.RequiresLogin( + handlers.HandleReauthenticate(logger, config, conn), )) } diff --git a/view/component/form/confirmpass.templ b/view/component/form/confirmpass.templ new file mode 100644 index 0000000..6976ddf --- /dev/null +++ b/view/component/form/confirmpass.templ @@ -0,0 +1,84 @@ +package form + +import "fmt" + +templ ConfirmPassword(err string) { + {{ + xdata := fmt.Sprintf( + "{ errMsg: '%s'}", + err, + ) + }} +
+} diff --git a/view/component/popup/confirmPasswordModal.templ b/view/component/popup/confirmPasswordModal.templ new file mode 100644 index 0000000..4179a12 --- /dev/null +++ b/view/component/popup/confirmPasswordModal.templ @@ -0,0 +1,20 @@ + +package popup + +import "projectreshoot/view/component/form" + +templ ConfirmPasswordModal() { + +} diff --git a/view/component/errorPopup.templ b/view/component/popup/errorPopup.templ similarity index 99% rename from view/component/errorPopup.templ rename to view/component/popup/errorPopup.templ index 968beab..f809230 100644 --- a/view/component/errorPopup.templ +++ b/view/component/popup/errorPopup.templ @@ -1,4 +1,4 @@ -package component +package popup templ ErrorPopup() {