From ea4dd2a407af89996b2a704a2fb3baaa656c34f2 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Fri, 14 Feb 2025 19:51:40 +1100 Subject: [PATCH] Added page protection for unauthorized access --- Makefile | 5 ++++ handlers/index.go | 6 +++- handlers/profile.go | 14 +++++++++ handlers/register.go | 6 +++- middleware/pageprotection.go | 36 ++++++++++++++++++++++++ server/routes.go | 39 ++++++++++++++++++-------- view/component/form/registerform.templ | 3 +- view/page/profile.templ | 13 +++++++++ 8 files changed, 107 insertions(+), 15 deletions(-) create mode 100644 handlers/profile.go create mode 100644 middleware/pageprotection.go create mode 100644 view/page/profile.templ diff --git a/Makefile b/Makefile index c4af9df..c056648 100644 --- a/Makefile +++ b/Makefile @@ -18,5 +18,10 @@ tester: go mod tidy && \ go run . --port 3232 --test --loglevel trace +test: + go mod tidy && \ + go test . -v + go test ./middleware -v + clean: go clean diff --git a/handlers/index.go b/handlers/index.go index 072408a..974b5be 100644 --- a/handlers/index.go +++ b/handlers/index.go @@ -12,7 +12,11 @@ func HandleRoot() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - page.Error("404", "Page not found", "The page or resource you have requested does not exist").Render(r.Context(), w) + page.Error( + "404", + "Page not found", + "The page or resource you have requested does not exist", + ).Render(r.Context(), w) return } page.Index().Render(r.Context(), w) diff --git a/handlers/profile.go b/handlers/profile.go new file mode 100644 index 0000000..91d21b2 --- /dev/null +++ b/handlers/profile.go @@ -0,0 +1,14 @@ +package handlers + +import ( + "net/http" + "projectreshoot/view/page" +) + +func HandleProfile() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + page.Profile().Render(r.Context(), w) + }, + ) +} diff --git a/handlers/register.go b/handlers/register.go index 6a274d8..895ab67 100644 --- a/handlers/register.go +++ b/handlers/register.go @@ -28,6 +28,9 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { if formPassword != formConfirmPassword { return nil, errors.New("Passwords do not match") } + if len(formPassword) > 72 { + return nil, errors.New("Password exceeds maximum length of 72 bytes") + } user, err := db.CreateNewUser(conn, formUsername, formPassword) if err != nil { return nil, errors.Wrap(err, "db.CreateNewUser") @@ -47,7 +50,8 @@ func HandleRegisterRequest( user, err := validateRegistration(conn, r) if err != nil { if err.Error() != "Username is taken" && - err.Error() != "Passwords do not match" { + err.Error() != "Passwords do not match" && + err.Error() != "Password exceeds maximum length of 72 bytes" { logger.Warn().Caller().Err(err).Msg("Registration request failed") w.WriteHeader(http.StatusInternalServerError) } else { diff --git a/middleware/pageprotection.go b/middleware/pageprotection.go new file mode 100644 index 0000000..64ef4da --- /dev/null +++ b/middleware/pageprotection.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "net/http" + "projectreshoot/contexts" + "projectreshoot/view/page" +) + +// Checks if the user is set in the context and shows 401 page if not logged in +func RequiresLogin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + if user == nil { + page.Error( + "401", + "Unauthorized", + "Please login to view this page", + ).Render(r.Context(), w) + return + } + next.ServeHTTP(w, r) + }) +} + +// Checks if the user is set in the context and redirects them to profile if +// they are logged in +func RequiresLogout(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + if user != nil { + http.Redirect(w, r, "/profile", http.StatusFound) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/server/routes.go b/server/routes.go index 2bc545b..19c87f7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -6,6 +6,7 @@ import ( "projectreshoot/config" "projectreshoot/handlers" + "projectreshoot/middleware" "projectreshoot/view/page" "github.com/rs/zerolog" @@ -31,23 +32,37 @@ func addRoutes( mux.Handle("GET /about", handlers.HandlePage(page.About())) // Login page and handlers - mux.Handle("GET /login", handlers.HandleLoginPage(config.TrustedHost)) - mux.Handle("POST /login", handlers.HandleLoginRequest( - config, - logger, - conn, - )) + mux.Handle("GET /login", + middleware.RequiresLogout( + handlers.HandleLoginPage(config.TrustedHost), + )) + mux.Handle("POST /login", + middleware.RequiresLogout( + handlers.HandleLoginRequest( + config, + logger, + conn, + ))) // Register page and handlers - mux.Handle("GET /register", handlers.HandleRegisterPage(config.TrustedHost)) - mux.Handle("POST /register", handlers.HandleRegisterRequest( - config, - logger, - conn, - )) + mux.Handle("GET /register", + middleware.RequiresLogout( + handlers.HandleRegisterPage(config.TrustedHost), + )) + mux.Handle("POST /register", + middleware.RequiresLogout( + handlers.HandleRegisterRequest( + config, + logger, + conn, + ))) // Logout mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn)) // Profile page + mux.Handle("GET /profile", + middleware.RequiresLogin( + handlers.HandleProfile(), + )) } diff --git a/view/component/form/registerform.templ b/view/component/form/registerform.templ index 9c8a46d..f2a9da4 100644 --- a/view/component/form/registerform.templ +++ b/view/component/form/registerform.templ @@ -10,7 +10,8 @@ templ RegisterForm(registerError string) { errPasswords := "false" if registerError == "Username is taken" { errUsername = "true" - } else if registerError == "Passwords do not match" { + } else if registerError == "Passwords do not match" || + registerError == "Password exceeds maximum length of 72 bytes" { errPasswords = "true" } xdata := fmt.Sprintf( diff --git a/view/page/profile.templ b/view/page/profile.templ new file mode 100644 index 0000000..959f089 --- /dev/null +++ b/view/page/profile.templ @@ -0,0 +1,13 @@ +package page + +import "projectreshoot/view/layout" +import "projectreshoot/contexts" + +templ Profile() { + {{ user := contexts.GetUser(ctx) }} + @layout.Global() { +
+ Hello, { user.Username } +
+ } +}