diff --git a/cookies/pagefrom.go b/cookies/pagefrom.go index 45e4cd4..906db61 100644 --- a/cookies/pagefrom.go +++ b/cookies/pagefrom.go @@ -24,12 +24,11 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) { if err != nil { return } - // NOTE: its possible this could cause an infinite redirect - // if that happens, will need to add a way to 'blacklist' certain paths - // from being set here var pageFrom string if parsedURL.Path == "" || parsedURL.Host != trustedHost { pageFrom = "/" + } else if parsedURL.Path == "/login" || parsedURL.Path == "/register" { + return } else { pageFrom = parsedURL.Path } diff --git a/db/users.go b/db/users.go index 4436e96..23a207e 100644 --- a/db/users.go +++ b/db/users.go @@ -45,38 +45,32 @@ func (user *User) CheckPassword(password string) error { return nil } -// Queries the database for a user matching the given username. -// Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE username = ? COLLATE NOCASE` - rows, err := conn.Query(query, username) +// Creates a new user in the database and returns a pointer +func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { + query := `INSERT INTO users (username) VALUES (?)` + _, err := conn.Exec(query, username) if err != nil { - return User{}, errors.Wrap(err, "conn.Query") + return nil, errors.Wrap(err, "conn.Exec") } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return User{}, errors.Wrap(err, "rows.Scan") - } + user, err := GetUserFromUsername(conn, username) + if err != nil { + return nil, errors.Wrap(err, "GetUserFromUsername") + } + err = user.SetPassword(conn, password) + if err != nil { + return nil, errors.Wrap(err, "user.SetPassword") } return user, nil } -// Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (User, error) { +// Queries the database for a user matching the given username. +// Query is case insensitive +func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { query := `SELECT id, username, password_hash, created_at FROM users - WHERE id = ?` - rows, err := conn.Query(query, id) + WHERE username = ? COLLATE NOCASE` + rows, err := conn.Query(query, username) if err != nil { - return User{}, errors.Wrap(err, "conn.Query") + return nil, errors.Wrap(err, "conn.Query") } defer rows.Close() var user User @@ -88,8 +82,44 @@ func GetUserFromID(conn *sql.DB, id int) (User, error) { &user.Created_at, ) if err != nil { - return User{}, errors.Wrap(err, "rows.Scan") + return nil, errors.Wrap(err, "rows.Scan") } } - return user, nil + return &user, nil +} + +// Queries the database for a user matching the given ID. +func GetUserFromID(conn *sql.DB, id int) (*User, error) { + query := `SELECT id, username, password_hash, created_at FROM users + WHERE id = ?` + rows, err := conn.Query(query, id) + if err != nil { + return nil, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + var user User + for rows.Next() { + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + ) + if err != nil { + return nil, errors.Wrap(err, "rows.Scan") + } + } + return &user, nil +} + +// Checks if the given username is unique. Returns true if not taken +func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { + query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` + rows, err := conn.Query(query, username) + if err != nil { + return false, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + taken := rows.Next() + return !taken, nil } diff --git a/handlers/login.go b/handlers/login.go index 2b5f4a3..7af3901 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -16,17 +16,17 @@ import ( // Validates the username matches a user in the database and the password // is correct. Returns the corresponding user -func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) { +func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") user, err := db.GetUserFromUsername(conn, formUsername) if err != nil { - return db.User{}, errors.Wrap(err, "db.GetUserFromUsername") + return nil, errors.Wrap(err, "db.GetUserFromUsername") } err = user.CheckPassword(formPassword) if err != nil { - return db.User{}, errors.New("Username or password incorrect") + return nil, errors.New("Username or password incorrect") } return user, nil } @@ -54,17 +54,19 @@ func HandleLoginRequest( r.ParseForm() user, err := validateLogin(conn, r) if err != nil { - form.LoginForm(err.Error()).Render(r.Context(), w) if err.Error() != "Username or password incorrect" { logger.Warn().Caller().Err(err).Msg("Login request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.LoginForm(err.Error()).Render(r.Context(), w) } return } rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, &user, true, rememberMe) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) if err != nil { - form.LoginForm(err.Error()).Render(r.Context(), w) + w.WriteHeader(http.StatusInternalServerError) logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") } diff --git a/handlers/register.go b/handlers/register.go new file mode 100644 index 0000000..6a274d8 --- /dev/null +++ b/handlers/register.go @@ -0,0 +1,81 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/config" + "projectreshoot/cookies" + "projectreshoot/db" + "projectreshoot/view/component/form" + "projectreshoot/view/page" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { + formUsername := r.FormValue("username") + formPassword := r.FormValue("password") + formConfirmPassword := r.FormValue("confirm-password") + unique, err := db.CheckUsernameUnique(conn, formUsername) + if err != nil { + return nil, errors.Wrap(err, "db.CheckUsernameUnique") + } + if !unique { + return nil, errors.New("Username is taken") + } + if formPassword != formConfirmPassword { + return nil, errors.New("Passwords do not match") + } + user, err := db.CreateNewUser(conn, formUsername, formPassword) + if err != nil { + return nil, errors.Wrap(err, "db.CreateNewUser") + } + + return user, nil +} + +func HandleRegisterRequest( + config *config.Config, + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + user, err := validateRegistration(conn, r) + if err != nil { + if err.Error() != "Username is taken" && + err.Error() != "Passwords do not match" { + logger.Warn().Caller().Err(err).Msg("Registration request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.RegisterForm(err.Error()).Render(r.Context(), w) + } + return + } + + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + } + + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + }, + ) +} + +// Handles a request to view the login page. Will attempt to set "pagefrom" +// cookie so a successful login can redirect the user to the page they came +func HandleRegisterPage(trustedHost string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookies.SetPageFrom(w, r, trustedHost) + page.Register().Render(r.Context(), w) + }, + ) +} diff --git a/jwt/tokens.go b/jwt/tokens.go index 5697f9c..d76e952 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -43,14 +43,14 @@ func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } - return &user, nil + return user, nil } func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { user, err := db.GetUserFromID(conn, r.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } - return &user, nil + return user, nil } func (a AccessToken) GetJTI() uuid.UUID { diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index c608127..8465ca0 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -50,13 +50,13 @@ func TestAuthenticationMiddleware(t *testing.T) { require.NoError(t, err) // Good tokens - atStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false) + atStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false) require.NoError(t, err) - rtStr, _, err := jwt.GenerateRefreshToken(cfg, &user, false) + rtStr, _, err := jwt.GenerateRefreshToken(cfg, user, false) require.NoError(t, err) // Create a token and revoke it for testing - expStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false) + expStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false) require.NoError(t, err) expT, err := jwt.ParseAccessToken(cfg, conn, expStr) require.NoError(t, err) diff --git a/schema.sql b/schema.sql index 3e55fc5..80f9970 100644 --- a/schema.sql +++ b/schema.sql @@ -7,7 +7,7 @@ exp INTEGER NOT NULL CREATE TABLE IF NOT EXISTS "users" ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, - password_hash TEXT, + password_hash TEXT DEFAULT "", created_at INTEGER DEFAULT (unixepoch()) ) STRICT; CREATE TRIGGER cleanup_expired_tokens diff --git a/server/routes.go b/server/routes.go index 8576910..2bc545b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,7 +38,16 @@ func addRoutes( conn, )) + // Register page and handlers + mux.Handle("GET /register", handlers.HandleRegisterPage(config.TrustedHost)) + mux.Handle("POST /register", handlers.HandleRegisterRequest( + config, + logger, + conn, + )) + // Logout mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn)) + // Profile page } diff --git a/view/component/form/loginform.templ b/view/component/form/loginform.templ index 9442272..3380320 100644 --- a/view/component/form/loginform.templ +++ b/view/component/form/loginform.templ @@ -8,11 +8,9 @@ import "fmt" // error icons on the username and password field templ LoginForm(loginError string) { {{ - var errCreds string + errCreds := "false" if loginError == "Username or password incorrect" { errCreds = "true" - } else { - errCreds = "false" } xdata := fmt.Sprintf( "{credentialError: %s, errorMessage: '%s'}", diff --git a/view/component/form/registerform.templ b/view/component/form/registerform.templ new file mode 100644 index 0000000..9c8a46d --- /dev/null +++ b/view/component/form/registerform.templ @@ -0,0 +1,200 @@ +package form + +import "fmt" + +// Login Form. If loginError is not an empty string, it will display the +// contents of loginError to the user. +templ RegisterForm(registerError string) { + {{ + errUsername := "false" + errPasswords := "false" + if registerError == "Username is taken" { + errUsername = "true" + } else if registerError == "Passwords do not match" { + errPasswords = "true" + } + xdata := fmt.Sprintf( + "{errUsername: %s, errPasswords: %s, errorMessage: '%s'}", + errUsername, + errPasswords, + registerError, + ) + }} +
+} diff --git a/view/page/register.templ b/view/page/register.templ new file mode 100644 index 0000000..75ded68 --- /dev/null +++ b/view/page/register.templ @@ -0,0 +1,42 @@ +package page + +import "projectreshoot/view/layout" +import "projectreshoot/view/component/form" + +// Returns the login page +templ Register() { + @layout.Global() { ++ Already have an account? + + Login here + +
+