Added user registration

This commit is contained in:
2025-02-14 18:49:46 +11:00
parent 5c8bec0ad2
commit 5616b8a248
11 changed files with 405 additions and 44 deletions

View File

@@ -24,12 +24,11 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
if err != nil {
return
}
// NOTE: its possible this could cause an infinite redirect
// if that happens, will need to add a way to 'blacklist' certain paths
// from being set here
var pageFrom string
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
pageFrom = "/"
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
return
} else {
pageFrom = parsedURL.Path
}

View File

@@ -45,38 +45,32 @@ func (user *User) CheckPassword(password string) error {
return nil
}
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (User, error) {
query := `SELECT id, username, password_hash, created_at FROM users
WHERE username = ? COLLATE NOCASE`
rows, err := conn.Query(query, username)
// Creates a new user in the database and returns a pointer
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
query := `INSERT INTO users (username) VALUES (?)`
_, err := conn.Exec(query, username)
if err != nil {
return User{}, errors.Wrap(err, "conn.Query")
return nil, errors.Wrap(err, "conn.Exec")
}
defer rows.Close()
var user User
for rows.Next() {
err := rows.Scan(
&user.ID,
&user.Username,
&user.Password_hash,
&user.Created_at,
)
if err != nil {
return User{}, errors.Wrap(err, "rows.Scan")
}
user, err := GetUserFromUsername(conn, username)
if err != nil {
return nil, errors.Wrap(err, "GetUserFromUsername")
}
err = user.SetPassword(conn, password)
if err != nil {
return nil, errors.Wrap(err, "user.SetPassword")
}
return user, nil
}
// Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (User, error) {
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
query := `SELECT id, username, password_hash, created_at FROM users
WHERE id = ?`
rows, err := conn.Query(query, id)
WHERE username = ? COLLATE NOCASE`
rows, err := conn.Query(query, username)
if err != nil {
return User{}, errors.Wrap(err, "conn.Query")
return nil, errors.Wrap(err, "conn.Query")
}
defer rows.Close()
var user User
@@ -88,8 +82,44 @@ func GetUserFromID(conn *sql.DB, id int) (User, error) {
&user.Created_at,
)
if err != nil {
return User{}, errors.Wrap(err, "rows.Scan")
return nil, errors.Wrap(err, "rows.Scan")
}
}
return user, nil
return &user, nil
}
// Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
query := `SELECT id, username, password_hash, created_at FROM users
WHERE id = ?`
rows, err := conn.Query(query, id)
if err != nil {
return nil, errors.Wrap(err, "conn.Query")
}
defer rows.Close()
var user User
for rows.Next() {
err := rows.Scan(
&user.ID,
&user.Username,
&user.Password_hash,
&user.Created_at,
)
if err != nil {
return nil, errors.Wrap(err, "rows.Scan")
}
}
return &user, nil
}
// Checks if the given username is unique. Returns true if not taken
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
rows, err := conn.Query(query, username)
if err != nil {
return false, errors.Wrap(err, "conn.Query")
}
defer rows.Close()
taken := rows.Next()
return !taken, nil
}

View File

@@ -16,17 +16,17 @@ import (
// Validates the username matches a user in the database and the password
// is correct. Returns the corresponding user
func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) {
func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
user, err := db.GetUserFromUsername(conn, formUsername)
if err != nil {
return db.User{}, errors.Wrap(err, "db.GetUserFromUsername")
return nil, errors.Wrap(err, "db.GetUserFromUsername")
}
err = user.CheckPassword(formPassword)
if err != nil {
return db.User{}, errors.New("Username or password incorrect")
return nil, errors.New("Username or password incorrect")
}
return user, nil
}
@@ -54,17 +54,19 @@ func HandleLoginRequest(
r.ParseForm()
user, err := validateLogin(conn, r)
if err != nil {
form.LoginForm(err.Error()).Render(r.Context(), w)
if err.Error() != "Username or password incorrect" {
logger.Warn().Caller().Err(err).Msg("Login request failed")
w.WriteHeader(http.StatusInternalServerError)
} else {
form.LoginForm(err.Error()).Render(r.Context(), w)
}
return
}
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, &user, true, rememberMe)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
form.LoginForm(err.Error()).Render(r.Context(), w)
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
}

81
handlers/register.go Normal file
View File

@@ -0,0 +1,81 @@
package handlers
import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/view/component/form"
"projectreshoot/view/page"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password")
unique, err := db.CheckUsernameUnique(conn, formUsername)
if err != nil {
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
}
if !unique {
return nil, errors.New("Username is taken")
}
if formPassword != formConfirmPassword {
return nil, errors.New("Passwords do not match")
}
user, err := db.CreateNewUser(conn, formUsername, formPassword)
if err != nil {
return nil, errors.Wrap(err, "db.CreateNewUser")
}
return user, nil
}
func HandleRegisterRequest(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
user, err := validateRegistration(conn, r)
if err != nil {
if err.Error() != "Username is taken" &&
err.Error() != "Passwords do not match" {
logger.Warn().Caller().Err(err).Msg("Registration request failed")
w.WriteHeader(http.StatusInternalServerError)
} else {
form.RegisterForm(err.Error()).Render(r.Context(), w)
}
return
}
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
}
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
},
)
}
// Handles a request to view the login page. Will attempt to set "pagefrom"
// cookie so a successful login can redirect the user to the page they came
func HandleRegisterPage(trustedHost string) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
cookies.SetPageFrom(w, r, trustedHost)
page.Register().Render(r.Context(), w)
},
)
}

View File

@@ -43,14 +43,14 @@ func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
return user, nil
}
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return &user, nil
return user, nil
}
func (a AccessToken) GetJTI() uuid.UUID {

View File

@@ -50,13 +50,13 @@ func TestAuthenticationMiddleware(t *testing.T) {
require.NoError(t, err)
// Good tokens
atStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
atStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false)
require.NoError(t, err)
rtStr, _, err := jwt.GenerateRefreshToken(cfg, &user, false)
rtStr, _, err := jwt.GenerateRefreshToken(cfg, user, false)
require.NoError(t, err)
// Create a token and revoke it for testing
expStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
expStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false)
require.NoError(t, err)
expT, err := jwt.ParseAccessToken(cfg, conn, expStr)
require.NoError(t, err)

View File

@@ -7,7 +7,7 @@ exp INTEGER NOT NULL
CREATE TABLE IF NOT EXISTS "users" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT,
password_hash TEXT DEFAULT "",
created_at INTEGER DEFAULT (unixepoch())
) STRICT;
CREATE TRIGGER cleanup_expired_tokens

View File

@@ -38,7 +38,16 @@ func addRoutes(
conn,
))
// Register page and handlers
mux.Handle("GET /register", handlers.HandleRegisterPage(config.TrustedHost))
mux.Handle("POST /register", handlers.HandleRegisterRequest(
config,
logger,
conn,
))
// Logout
mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn))
// Profile page
}

View File

@@ -8,11 +8,9 @@ import "fmt"
// error icons on the username and password field
templ LoginForm(loginError string) {
{{
var errCreds string
errCreds := "false"
if loginError == "Username or password incorrect" {
errCreds = "true"
} else {
errCreds = "false"
}
xdata := fmt.Sprintf(
"{credentialError: %s, errorMessage: '%s'}",

View File

@@ -0,0 +1,200 @@
package form
import "fmt"
// Login Form. If loginError is not an empty string, it will display the
// contents of loginError to the user.
templ RegisterForm(registerError string) {
{{
errUsername := "false"
errPasswords := "false"
if registerError == "Username is taken" {
errUsername = "true"
} else if registerError == "Passwords do not match" {
errPasswords = "true"
}
xdata := fmt.Sprintf(
"{errUsername: %s, errPasswords: %s, errorMessage: '%s'}",
errUsername,
errPasswords,
registerError,
)
}}
<form
hx-post="/register"
x-data="{ submitted: false, buttontext: 'Login' }"
x-on:htmx:xhr:loadstart="submitted=true;buttontext='Loading...'"
>
<div
class="grid gap-y-4"
x-data={ xdata }
>
<!-- Form Group -->
<div>
<label
for="email"
class="block text-sm mb-2"
>Username</label>
<div class="relative">
<input
type="text"
idnutanix="username"
name="username"
class="py-3 px-4 block w-full rounded-lg text-sm
focus:border-blue focus:ring-blue bg-base
disabled:opacity-50
disabled:pointer-events-none"
required
aria-describedby="username-error"
/>
<div
class="absolute inset-y-0 end-0
pointer-events-none pe-3 pt-3"
x-show="errUsername"
x-cloak
>
<svg
class="size-5 text-red"
width="16"
height="16"
fill="currentColor"
viewBox="0 0 16 16"
aria-hidden="true"
>
<path
d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zM8
4a.905.905 0 0 0-.9.995l.35 3.507a.552.552 0 0
0 1.1 0l.35-3.507A.905.905 0 0 0 8 4zm.002 6a1
1 0 1 0 0 2 1 1 0 0 0 0-2z"
></path>
</svg>
</div>
<p
class="text-center text-xs text-red mt-2"
id="username-error"
x-show="errUsername"
x-cloak
x-text="if (errUsername) return errorMessage;"
></p>
</div>
</div>
<div>
<div class="flex justify-between items-center">
<label
for="password"
class="block text-sm mb-2"
>Password</label>
</div>
<div class="relative">
<input
type="password"
id="password"
name="password"
class="py-3 px-4 block w-full rounded-lg text-sm
focus:border-blue focus:ring-blue bg-base
disabled:opacity-50 disabled:pointer-events-none"
required
aria-describedby="password-error"
/>
<div
class="absolute inset-y-0 end-0
pointer-events-none pe-3 pt-3"
x-show="errPasswords"
x-cloak
>
<svg
class="size-5 text-red"
width="16"
height="16"
fill="currentColor"
viewBox="0 0 16 16"
aria-hidden="true"
>
<path
d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zM8
4a.905.905 0 0 0-.9.995l.35 3.507a.552.552 0 0
0 1.1 0l.35-3.507A.905.905 0 0 0 8 4zm.002 6a1
1 0 1 0 0 2 1 1 0 0 0 0-2z"
></path>
</svg>
</div>
</div>
</div>
<div>
<div class="flex justify-between items-center">
<label
for="confirm-password"
class="block text-sm mb-2"
>Confirm Password</label>
</div>
<div class="relative">
<input
type="password"
id="confirm-password"
name="confirm-password"
class="py-3 px-4 block w-full rounded-lg text-sm
focus:border-blue focus:ring-blue bg-base
disabled:opacity-50 disabled:pointer-events-none"
required
aria-describedby="confirm-password-error"
/>
<div
class="absolute inset-y-0 end-0
pointer-events-none pe-3 pt-3"
x-show="errPasswords"
x-cloak
>
<svg
class="size-5 text-red"
width="16"
height="16"
fill="currentColor"
viewBox="0 0 16 16"
aria-hidden="true"
>
<path
d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zM8
4a.905.905 0 0 0-.9.995l.35 3.507a.552.552 0 0
0 1.1 0l.35-3.507A.905.905 0 0 0 8 4zm.002 6a1
1 0 1 0 0 2 1 1 0 0 0 0-2z"
></path>
</svg>
</div>
</div>
<p
class="text-center text-xs text-red mt-2"
id="password-error"
x-show="errPasswords"
x-cloak
x-text="if (errPasswords) return errorMessage;"
></p>
</div>
<div class="flex items-center">
<div class="flex">
<input
id="remember-me"
name="remember-me"
type="checkbox"
class="shrink-0 mt-0.5 border-gray-200 rounded
text-blue focus:ring-blue-500"
/>
</div>
<div class="ms-3">
<label
for="remember-me"
class="text-sm"
>Remember me</label>
</div>
</div>
<button
x-bind:disabled="submitted"
x-text="buttontext"
type="submit"
class="w-full py-3 px-4 inline-flex justify-center items-center
gap-x-2 rounded-lg border border-transparent transition
bg-green hover:bg-green/75 text-mantle hover:cursor-pointer
disabled:bg-green/60 disabled:cursor-default"
></button>
</div>
</form>
}

42
view/page/register.templ Normal file
View File

@@ -0,0 +1,42 @@
package page
import "projectreshoot/view/layout"
import "projectreshoot/view/component/form"
// Returns the login page
templ Register() {
@layout.Global() {
<div class="max-w-100 mx-auto px-2">
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
<div class="p-4 sm:p-7">
<div class="text-center">
<h1
class="block text-2xl font-bold"
>Register</h1>
<p
class="mt-2 text-sm text-subtext0"
>
Already have an account?
<a
class="text-blue decoration-2 hover:underline
focus:outline-none focus:underline"
href="/login"
>
Login here
</a>
</p>
</div>
<div class="mt-5">
<div
class="py-3 flex items-center text-xs text-subtext0
uppercase before:flex-1 before:border-t
before:border-overlay1 before:me-6 after:flex-1
after:border-t after:border-overlay1 after:ms-6"
>Or</div>
@form.RegisterForm("")
</div>
</div>
</div>
</div>
}
}