diff --git a/contexts/user.go b/contexts/user.go index 3a752a5..0ef041f 100644 --- a/contexts/user.go +++ b/contexts/user.go @@ -5,14 +5,19 @@ import ( "projectreshoot/db" ) +type AuthenticatedUser struct { + *db.User + Fresh int64 +} + // Return a new context with the user added in -func SetUser(ctx context.Context, u *db.User) context.Context { +func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context { return context.WithValue(ctx, contextKeyAuthorizedUser, u) } // Retrieve a user from the given context. Returns nil if not set -func GetUser(ctx context.Context) *db.User { - user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User) +func GetUser(ctx context.Context) *AuthenticatedUser { + user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser) if !ok { return nil } diff --git a/cookies/delete.go b/cookies/delete.go deleted file mode 100644 index d847b65..0000000 --- a/cookies/delete.go +++ /dev/null @@ -1,18 +0,0 @@ -package cookies - -import ( - "net/http" - "time" -) - -// Tell the browser to delete the cookie matching the name provided -// Path must match the original set cookie for it to delete -func DeleteCookie(w http.ResponseWriter, name string, path string) { - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: "", - Path: path, - Expires: time.Unix(0, 0), // Expire in the past - MaxAge: -1, // Immediately expire - }) -} diff --git a/cookies/functions.go b/cookies/functions.go new file mode 100644 index 0000000..8e33212 --- /dev/null +++ b/cookies/functions.go @@ -0,0 +1,37 @@ +package cookies + +import ( + "net/http" + "time" +) + +// Tell the browser to delete the cookie matching the name provided +// Path must match the original set cookie for it to delete +func DeleteCookie(w http.ResponseWriter, name string, path string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Path: path, + Expires: time.Unix(0, 0), // Expire in the past + MaxAge: -1, // Immediately expire + HttpOnly: true, + }) +} + +// Set a cookie with the given name, path and value. maxAge directly relates +// to cookie MaxAge (0 for no max age, >0 for TTL in seconds) +func SetCookie( + w http.ResponseWriter, + name string, + path string, + value string, + maxAge int, +) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + Path: path, + HttpOnly: true, + MaxAge: maxAge, + }) +} diff --git a/cookies/pagefrom.go b/cookies/pagefrom.go index 906db61..fa4cb0b 100644 --- a/cookies/pagefrom.go +++ b/cookies/pagefrom.go @@ -32,6 +32,5 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) { } else { pageFrom = parsedURL.Path } - pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom, Path: "/"} - http.SetCookie(w, pageFromCookie) + SetCookie(w, "pagefrom", "/", pageFrom, 0) } diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..fe62349 --- /dev/null +++ b/db/user.go @@ -0,0 +1,60 @@ +package db + +import ( + "database/sql" + + "github.com/pkg/errors" + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int // Integer ID (index primary key) + Username string // Username (unique) + Password_hash string // Bcrypt password hash + Created_at int64 // Epoch timestamp when the user was added to the database + Bio string // Short byline set by the user +} + +// Uses bcrypt to set the users Password_hash from the given password +func (user *User) SetPassword(conn *sql.DB, password string) error { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return errors.Wrap(err, "bcrypt.GenerateFromPassword") + } + user.Password_hash = string(hashedPassword) + query := `UPDATE users SET password_hash = ? WHERE id = ?` + _, err = conn.Exec(query, user.Password_hash, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Uses bcrypt to check if the given password matches the users Password_hash +func (user *User) CheckPassword(password string) error { + err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) + if err != nil { + return errors.Wrap(err, "bcrypt.CompareHashAndPassword") + } + return nil +} + +// Change the user's username +func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { + query := `UPDATE users SET username = ? WHERE id = ?` + _, err := conn.Exec(query, newUsername, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Change the user's bio +func (user *User) ChangeBio(conn *sql.DB, newBio string) error { + query := `UPDATE users SET bio = ? WHERE id = ?` + _, err := conn.Exec(query, newBio, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} diff --git a/db/user_functions.go b/db/user_functions.go new file mode 100644 index 0000000..3c3623e --- /dev/null +++ b/db/user_functions.go @@ -0,0 +1,108 @@ +package db + +import ( + "database/sql" + "fmt" + + "github.com/pkg/errors" +) + +// Creates a new user in the database and returns a pointer +func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { + query := `INSERT INTO users (username) VALUES (?)` + _, err := conn.Exec(query, username) + if err != nil { + return nil, errors.Wrap(err, "conn.Exec") + } + user, err := GetUserFromUsername(conn, username) + if err != nil { + return nil, errors.Wrap(err, "GetUserFromUsername") + } + err = user.SetPassword(conn, password) + if err != nil { + return nil, errors.Wrap(err, "user.SetPassword") + } + return user, nil +} + +// Fetches data from the users table using "WHERE column = 'value'" +func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { + query := fmt.Sprintf( + `SELECT + id, + username, + password_hash, + created_at, + bio + FROM users + WHERE %s = ? COLLATE NOCASE LIMIT 1`, + column, + ) + rows, err := conn.Query(query, value) + if err != nil { + return nil, errors.Wrap(err, "conn.Query") + } + return rows, nil +} + +// Scan the next row into the provided user pointer. Calls rows.Next() and +// assumes only row in the result. Providing a rows object with more than 1 +// row may result in undefined behaviour. +func scanUserRow(user *User, rows *sql.Rows) error { + for rows.Next() { + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + &user.Bio, + ) + if err != nil { + return errors.Wrap(err, "rows.Scan") + } + } + return nil +} + +// Queries the database for a user matching the given username. +// Query is case insensitive +func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { + rows, err := fetchUserData(conn, "username", username) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Queries the database for a user matching the given ID. +func GetUserFromID(conn *sql.DB, id int) (*User, error) { + rows, err := fetchUserData(conn, "id", id) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Checks if the given username is unique. Returns true if not taken +func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { + query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` + rows, err := conn.Query(query, username) + if err != nil { + return false, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + taken := rows.Next() + return !taken, nil +} diff --git a/db/users.go b/db/users.go deleted file mode 100644 index 23a207e..0000000 --- a/db/users.go +++ /dev/null @@ -1,125 +0,0 @@ -package db - -import ( - "database/sql" - - "github.com/pkg/errors" - "golang.org/x/crypto/bcrypt" -) - -type User struct { - ID int // Integer ID (index primary key) - Username string // Username (unique) - Password_hash string // Bcrypt password hash - Created_at int64 // Epoch timestamp when the user was added to the database -} - -// Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - user.Password_hash = string(hashedPassword) - query := `UPDATE users SET password_hash = ? WHERE id = ?` - result, err := conn.Exec(query, user.Password_hash, user.ID) - if err != nil { - return errors.Wrap(err, "conn.Exec") - } - ra, err := result.RowsAffected() - if err != nil { - return errors.Wrap(err, "result.RowsAffected") - } - if ra != 1 { - return errors.New("Password was not updated") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users Password_hash -func (user *User) CheckPassword(password string) error { - err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) - if err != nil { - return errors.Wrap(err, "bcrypt.CompareHashAndPassword") - } - return nil -} - -// Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { - query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Exec") - } - user, err := GetUserFromUsername(conn, username) - if err != nil { - return nil, errors.Wrap(err, "GetUserFromUsername") - } - err = user.SetPassword(conn, password) - if err != nil { - return nil, errors.Wrap(err, "user.SetPassword") - } - return user, nil -} - -// Queries the database for a user matching the given username. -// Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE username = ? COLLATE NOCASE` - rows, err := conn.Query(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE id = ?` - rows, err := conn.Query(query, id) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { - query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) - if err != nil { - return false, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - taken := rows.Next() - return !taken, nil -} diff --git a/handlers/account.go b/handlers/account.go new file mode 100644 index 0000000..be79910 --- /dev/null +++ b/handlers/account.go @@ -0,0 +1,137 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/db" + "projectreshoot/view/component/account" + "projectreshoot/view/page" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// Renders the account page on the 'General' subpage +func HandleAccountPage() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("subpage") + subpage := cookie.Value + if err != nil { + subpage = "General" + } + page.Account(subpage).Render(r.Context(), w) + }, + ) +} + +// Handles a request to change the subpage for the Accou/accountnt page +func HandleAccountSubpage() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + subpage := r.FormValue("subpage") + cookies.SetCookie(w, "subpage", "/account", subpage, 300) + account.AccountContainer(subpage).Render(r.Context(), w) + }, + ) +} + +// Handles a request to change the users username +func HandleChangeUsername( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newUsername := r.FormValue("username") + + unique, err := db.CheckUsernameUnique(conn, newUsername) + if err != nil { + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !unique { + account.ChangeUsername("Username is taken", newUsername). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeUsername(conn, newUsername) + if err != nil { + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} + +// Handles a request to change the users bio +func HandleChangeBio( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newBio := r.FormValue("bio") + leng := len([]rune(newBio)) + if leng > 128 { + account.ChangeBio("Bio limited to 128 characters", newBio). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err := user.ChangeBio(conn, newBio) + if err != nil { + logger.Error().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} +func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { + r.ParseForm() + formPassword := r.FormValue("password") + formConfirmPassword := r.FormValue("confirm-password") + if formPassword != formConfirmPassword { + return "", errors.New("Passwords do not match") + } + if len(formPassword) > 72 { + return "", errors.New("Password exceeds maximum length of 72 bytes") + } + return formPassword, nil +} + +// Handles a request to change the users password +func HandleChangePassword( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + newPass, err := validateChangePassword(conn, r) + if err != nil { + account.ChangePassword(err.Error()).Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.SetPassword(conn, newPass) + if err != nil { + logger.Error().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} diff --git a/handlers/profile.go b/handlers/profile.go index 91d21b2..91f381f 100644 --- a/handlers/profile.go +++ b/handlers/profile.go @@ -5,7 +5,7 @@ import ( "projectreshoot/view/page" ) -func HandleProfile() http.Handler { +func HandleProfilePage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { page.Profile().Render(r.Context(), w) diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go new file mode 100644 index 0000000..9e188d8 --- /dev/null +++ b/handlers/reauthenticatate.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/config" + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/jwt" + "projectreshoot/view/component/form" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// Get the tokens from the request +func getTokens( + config *config.Config, + conn *sql.DB, + r *http.Request, +) (*jwt.AccessToken, *jwt.RefreshToken, error) { + // get the existing tokens from the cookies + atStr, rtStr := cookies.GetTokenStrings(r) + aT, err := jwt.ParseAccessToken(config, conn, atStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") + } + rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") + } + return aT, rT, nil +} + +// Revoke the given token pair +func revokeTokenPair( + conn *sql.DB, + aT *jwt.AccessToken, + rT *jwt.RefreshToken, +) error { + err := jwt.RevokeToken(conn, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + err = jwt.RevokeToken(conn, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +// Issue new tokens for the user, invalidating the old ones +func refreshTokens( + config *config.Config, + conn *sql.DB, + w http.ResponseWriter, + r *http.Request, +) error { + aT, rT, err := getTokens(config, conn, r) + if err != nil { + return errors.Wrap(err, "getTokens") + } + rememberMe := map[string]bool{ + "session": false, + "exp": true, + }[aT.TTL] + // issue new tokens for the user + user := contexts.GetUser(r.Context()) + err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe) + if err != nil { + return errors.Wrap(err, "cookies.SetTokenCookies") + } + err = revokeTokenPair(conn, aT, rT) + if err != nil { + return errors.Wrap(err, "revokeTokenPair") + } + + return nil +} + +// Validate the provided password +func validatePassword( + r *http.Request, +) error { + r.ParseForm() + password := r.FormValue("password") + user := contexts.GetUser(r.Context()) + err := user.CheckPassword(password) + if err != nil { + return errors.Wrap(err, "user.CheckPassword") + } + return nil +} + +// Handle request to reauthenticate (i.e. make token fresh again) +func HandleReauthenticate( + logger *zerolog.Logger, + config *config.Config, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + err := validatePassword(r) + if err != nil { + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, conn, w, r) + if err != nil { + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + ) +} diff --git a/middleware/authentication.go b/middleware/authentication.go index 38cabbe..23f40f0 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -3,6 +3,7 @@ package middleware import ( "database/sql" "net/http" + "time" "projectreshoot/config" "projectreshoot/contexts" @@ -52,7 +53,7 @@ func getAuthenticatedUser( conn *sql.DB, w http.ResponseWriter, r *http.Request, -) (*db.User, error) { +) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token @@ -69,14 +70,22 @@ func getAuthenticatedUser( return nil, errors.Wrap(err, "refreshAuthTokens") } // New token pair sent, return the authorized user - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: time.Now().Unix(), + } + return &authUser, nil } // Access token valid user, err := aT.GetUser(conn) if err != nil { return nil, errors.Wrap(err, "rT.GetUser") } - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: aT.Fresh, + } + return &authUser, nil } // Attempt to authenticate the user and add their account details diff --git a/middleware/reauthentication.go b/middleware/reauthentication.go new file mode 100644 index 0000000..41fad65 --- /dev/null +++ b/middleware/reauthentication.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "net/http" + "projectreshoot/contexts" + "time" +) + +func RequiresFresh( + next http.Handler, +) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + isFresh := time.Now().Before(time.Unix(user.Fresh, 0)) + if !isFresh { + w.WriteHeader(444) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/schema.sql b/schema.sql index 80f9970..986d312 100644 --- a/schema.sql +++ b/schema.sql @@ -8,7 +8,8 @@ CREATE TABLE IF NOT EXISTS "users" ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT DEFAULT "", - created_at INTEGER DEFAULT (unixepoch()) + created_at INTEGER DEFAULT (unixepoch()), + bio TEXT DEFAULT "" ) STRICT; CREATE TRIGGER cleanup_expired_tokens AFTER INSERT ON jwtblacklist diff --git a/server/routes.go b/server/routes.go index 19c87f7..606036e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -60,9 +60,41 @@ func addRoutes( // Logout mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn)) + // Reauthentication request + mux.Handle("POST /reauthenticate", + middleware.RequiresLogin( + handlers.HandleReauthenticate(logger, config, conn), + )) + // Profile page mux.Handle("GET /profile", middleware.RequiresLogin( - handlers.HandleProfile(), + handlers.HandleProfilePage(), + )) + + // Account page + mux.Handle("GET /account", + middleware.RequiresLogin( + handlers.HandleAccountPage(), + )) + mux.Handle("POST /account-select-page", + middleware.RequiresLogin( + handlers.HandleAccountSubpage(), + )) + mux.Handle("POST /change-username", + middleware.RequiresLogin( + middleware.RequiresFresh( + handlers.HandleChangeUsername(logger, conn), + ), + )) + mux.Handle("POST /change-bio", + middleware.RequiresLogin( + handlers.HandleChangeBio(logger, conn), + )) + mux.Handle("POST /change-password", + middleware.RequiresLogin( + middleware.RequiresFresh( + handlers.HandleChangePassword(logger, conn), + ), )) } diff --git a/view/component/account/changebio.templ b/view/component/account/changebio.templ new file mode 100644 index 0000000..cb85013 --- /dev/null +++ b/view/component/account/changebio.templ @@ -0,0 +1,117 @@ +package account + +import "projectreshoot/contexts" + +templ ChangeBio(err string, bio string) { + {{ + user := contexts.GetUser(ctx) + if bio == "" { + bio = user.Bio + } + }} +
+} diff --git a/view/component/account/changepassword.templ b/view/component/account/changepassword.templ new file mode 100644 index 0000000..6c368d3 --- /dev/null +++ b/view/component/account/changepassword.templ @@ -0,0 +1,141 @@ +package account + +templ ChangePassword(err string) { + +} diff --git a/view/component/account/changeusername.templ b/view/component/account/changeusername.templ new file mode 100644 index 0000000..25cc151 --- /dev/null +++ b/view/component/account/changeusername.templ @@ -0,0 +1,108 @@ +package account + +import "projectreshoot/contexts" + +templ ChangeUsername(err string, username string) { + {{ + user := contexts.GetUser(ctx) + if username == "" { + username = user.Username + } + }} + +} diff --git a/view/component/account/container.templ b/view/component/account/container.templ new file mode 100644 index 0000000..f065949 --- /dev/null +++ b/view/component/account/container.templ @@ -0,0 +1,26 @@ +package account + +templ AccountContainer(subpage string) { +