diff --git a/.github/workflows/deploy_production.yaml b/.github/workflows/deploy_production.yaml new file mode 100644 index 0000000..576ee4d --- /dev/null +++ b/.github/workflows/deploy_production.yaml @@ -0,0 +1,55 @@ +name: Deploy Staging to Server + +on: + push: + branches: + - master + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24.x' + + - name: Install Templ + run: go install github.com/a-h/templ/cmd/templ@latest + + - name: Install tailwindcsscli + run: | + curl -fsSL -o tailwindcss https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 + chmod +x tailwindcss + sudo mv tailwindcss /usr/local/bin/ + + - name: Run tests + run: make test + + - name: Build the binary + run: make build SUFFIX=-production-$GITHUB_SHA + + - name: Deploy to Server + env: + USER: deploy + HOST: projectreshoot.com + DIR: /home/deploy/releases/production + DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }} + run: | + mkdir -p ~/.ssh + echo "$DEPLOY_SSH_PRIVATE_KEY" > ~/.ssh/id_ed25519 + chmod 600 ~/.ssh/id_ed25519 + + echo "Host *" > ~/.ssh/config + echo " StrictHostKeyChecking no" >> ~/.ssh/config + echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR + + scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_production.sh $GITHUB_SHA diff --git a/.github/workflows/deploy_staging.yaml b/.github/workflows/deploy_staging.yaml new file mode 100644 index 0000000..aab9831 --- /dev/null +++ b/.github/workflows/deploy_staging.yaml @@ -0,0 +1,55 @@ +name: Deploy Staging to Server + +on: + push: + branches: + - staging + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24.x' + + - name: Install Templ + run: go install github.com/a-h/templ/cmd/templ@latest + + - name: Install tailwindcsscli + run: | + curl -fsSL -o tailwindcss https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 + chmod +x tailwindcss + sudo mv tailwindcss /usr/local/bin/ + + - name: Run tests + run: make test + + - name: Build the binary + run: make build SUFFIX=-staging-$GITHUB_SHA + + - name: Deploy to Server + env: + USER: deploy + HOST: projectreshoot.com + DIR: /home/deploy/releases/staging + DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }} + run: | + mkdir -p ~/.ssh + echo "$DEPLOY_SSH_PRIVATE_KEY" > ~/.ssh/id_ed25519 + chmod 600 ~/.ssh/id_ed25519 + + echo "Host *" > ~/.ssh/config + echo " StrictHostKeyChecking no" >> ~/.ssh/config + echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR + + scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA diff --git a/Makefile b/Makefile index c056648..72a5804 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,11 @@ BINARY_NAME=projectreshoot build: + tailwindcss -i ./static/css/input.css -o ./static/css/output.css && \ go mod tidy && \ templ generate && \ go generate && \ - go build -ldflags="-w -s" -o ${BINARY_NAME} + go build -ldflags="-w -s" -o ${BINARY_NAME}${SUFFIX} dev: templ generate --watch &\ @@ -19,9 +20,11 @@ tester: go run . --port 3232 --test --loglevel trace test: + rm -f **/.projectreshoot-test-database.db && \ go mod tidy && \ - go test . -v - go test ./middleware -v + templ generate && \ + go generate && \ + go test ./middleware clean: go clean diff --git a/config/config.go b/config/config.go index da62d0d..25654d6 100644 --- a/config/config.go +++ b/config/config.go @@ -53,7 +53,7 @@ func GetConfig(args map[string]string) (*Config, error) { if args["port"] != "" { port = args["port"] } else { - port = GetEnvDefault("PORT", "3333") + port = GetEnvDefault("PORT", "3010") } if args["loglevel"] != "" { logLevel = logging.GetLogLevel(args["loglevel"]) diff --git a/contexts/user.go b/contexts/user.go index 3a752a5..0ef041f 100644 --- a/contexts/user.go +++ b/contexts/user.go @@ -5,14 +5,19 @@ import ( "projectreshoot/db" ) +type AuthenticatedUser struct { + *db.User + Fresh int64 +} + // Return a new context with the user added in -func SetUser(ctx context.Context, u *db.User) context.Context { +func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context { return context.WithValue(ctx, contextKeyAuthorizedUser, u) } // Retrieve a user from the given context. Returns nil if not set -func GetUser(ctx context.Context) *db.User { - user, ok := ctx.Value(contextKeyAuthorizedUser).(*db.User) +func GetUser(ctx context.Context) *AuthenticatedUser { + user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser) if !ok { return nil } diff --git a/cookies/delete.go b/cookies/delete.go deleted file mode 100644 index d847b65..0000000 --- a/cookies/delete.go +++ /dev/null @@ -1,18 +0,0 @@ -package cookies - -import ( - "net/http" - "time" -) - -// Tell the browser to delete the cookie matching the name provided -// Path must match the original set cookie for it to delete -func DeleteCookie(w http.ResponseWriter, name string, path string) { - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: "", - Path: path, - Expires: time.Unix(0, 0), // Expire in the past - MaxAge: -1, // Immediately expire - }) -} diff --git a/cookies/functions.go b/cookies/functions.go new file mode 100644 index 0000000..8e33212 --- /dev/null +++ b/cookies/functions.go @@ -0,0 +1,37 @@ +package cookies + +import ( + "net/http" + "time" +) + +// Tell the browser to delete the cookie matching the name provided +// Path must match the original set cookie for it to delete +func DeleteCookie(w http.ResponseWriter, name string, path string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Path: path, + Expires: time.Unix(0, 0), // Expire in the past + MaxAge: -1, // Immediately expire + HttpOnly: true, + }) +} + +// Set a cookie with the given name, path and value. maxAge directly relates +// to cookie MaxAge (0 for no max age, >0 for TTL in seconds) +func SetCookie( + w http.ResponseWriter, + name string, + path string, + value string, + maxAge int, +) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + Path: path, + HttpOnly: true, + MaxAge: maxAge, + }) +} diff --git a/cookies/pagefrom.go b/cookies/pagefrom.go index 906db61..fa4cb0b 100644 --- a/cookies/pagefrom.go +++ b/cookies/pagefrom.go @@ -32,6 +32,5 @@ func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) { } else { pageFrom = parsedURL.Path } - pageFromCookie := &http.Cookie{Name: "pagefrom", Value: pageFrom, Path: "/"} - http.SetCookie(w, pageFromCookie) + SetCookie(w, "pagefrom", "/", pageFrom, 0) } diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..fe62349 --- /dev/null +++ b/db/user.go @@ -0,0 +1,60 @@ +package db + +import ( + "database/sql" + + "github.com/pkg/errors" + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int // Integer ID (index primary key) + Username string // Username (unique) + Password_hash string // Bcrypt password hash + Created_at int64 // Epoch timestamp when the user was added to the database + Bio string // Short byline set by the user +} + +// Uses bcrypt to set the users Password_hash from the given password +func (user *User) SetPassword(conn *sql.DB, password string) error { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return errors.Wrap(err, "bcrypt.GenerateFromPassword") + } + user.Password_hash = string(hashedPassword) + query := `UPDATE users SET password_hash = ? WHERE id = ?` + _, err = conn.Exec(query, user.Password_hash, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Uses bcrypt to check if the given password matches the users Password_hash +func (user *User) CheckPassword(password string) error { + err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) + if err != nil { + return errors.Wrap(err, "bcrypt.CompareHashAndPassword") + } + return nil +} + +// Change the user's username +func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { + query := `UPDATE users SET username = ? WHERE id = ?` + _, err := conn.Exec(query, newUsername, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Change the user's bio +func (user *User) ChangeBio(conn *sql.DB, newBio string) error { + query := `UPDATE users SET bio = ? WHERE id = ?` + _, err := conn.Exec(query, newBio, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} diff --git a/db/user_functions.go b/db/user_functions.go new file mode 100644 index 0000000..3c3623e --- /dev/null +++ b/db/user_functions.go @@ -0,0 +1,108 @@ +package db + +import ( + "database/sql" + "fmt" + + "github.com/pkg/errors" +) + +// Creates a new user in the database and returns a pointer +func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { + query := `INSERT INTO users (username) VALUES (?)` + _, err := conn.Exec(query, username) + if err != nil { + return nil, errors.Wrap(err, "conn.Exec") + } + user, err := GetUserFromUsername(conn, username) + if err != nil { + return nil, errors.Wrap(err, "GetUserFromUsername") + } + err = user.SetPassword(conn, password) + if err != nil { + return nil, errors.Wrap(err, "user.SetPassword") + } + return user, nil +} + +// Fetches data from the users table using "WHERE column = 'value'" +func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { + query := fmt.Sprintf( + `SELECT + id, + username, + password_hash, + created_at, + bio + FROM users + WHERE %s = ? COLLATE NOCASE LIMIT 1`, + column, + ) + rows, err := conn.Query(query, value) + if err != nil { + return nil, errors.Wrap(err, "conn.Query") + } + return rows, nil +} + +// Scan the next row into the provided user pointer. Calls rows.Next() and +// assumes only row in the result. Providing a rows object with more than 1 +// row may result in undefined behaviour. +func scanUserRow(user *User, rows *sql.Rows) error { + for rows.Next() { + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + &user.Bio, + ) + if err != nil { + return errors.Wrap(err, "rows.Scan") + } + } + return nil +} + +// Queries the database for a user matching the given username. +// Query is case insensitive +func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { + rows, err := fetchUserData(conn, "username", username) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Queries the database for a user matching the given ID. +func GetUserFromID(conn *sql.DB, id int) (*User, error) { + rows, err := fetchUserData(conn, "id", id) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Checks if the given username is unique. Returns true if not taken +func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { + query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` + rows, err := conn.Query(query, username) + if err != nil { + return false, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + taken := rows.Next() + return !taken, nil +} diff --git a/db/users.go b/db/users.go deleted file mode 100644 index 23a207e..0000000 --- a/db/users.go +++ /dev/null @@ -1,125 +0,0 @@ -package db - -import ( - "database/sql" - - "github.com/pkg/errors" - "golang.org/x/crypto/bcrypt" -) - -type User struct { - ID int // Integer ID (index primary key) - Username string // Username (unique) - Password_hash string // Bcrypt password hash - Created_at int64 // Epoch timestamp when the user was added to the database -} - -// Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - user.Password_hash = string(hashedPassword) - query := `UPDATE users SET password_hash = ? WHERE id = ?` - result, err := conn.Exec(query, user.Password_hash, user.ID) - if err != nil { - return errors.Wrap(err, "conn.Exec") - } - ra, err := result.RowsAffected() - if err != nil { - return errors.Wrap(err, "result.RowsAffected") - } - if ra != 1 { - return errors.New("Password was not updated") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users Password_hash -func (user *User) CheckPassword(password string) error { - err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) - if err != nil { - return errors.Wrap(err, "bcrypt.CompareHashAndPassword") - } - return nil -} - -// Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { - query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Exec") - } - user, err := GetUserFromUsername(conn, username) - if err != nil { - return nil, errors.Wrap(err, "GetUserFromUsername") - } - err = user.SetPassword(conn, password) - if err != nil { - return nil, errors.Wrap(err, "user.SetPassword") - } - return user, nil -} - -// Queries the database for a user matching the given username. -// Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE username = ? COLLATE NOCASE` - rows, err := conn.Query(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE id = ?` - rows, err := conn.Query(query, id) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { - query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) - if err != nil { - return false, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - taken := rows.Next() - return !taken, nil -} diff --git a/deploy/caddy/Caddyfile b/deploy/caddy/Caddyfile new file mode 100644 index 0000000..7f29089 --- /dev/null +++ b/deploy/caddy/Caddyfile @@ -0,0 +1,12 @@ +projectreshoot.com { + reverse_proxy localhost:3000 localhost:3001 localhost:3002 { + health_uri /healthz + fail_duration 30s + } +} +staging.projectreshoot.com { + reverse_proxy localhost:3005 localhost:3006 localhost:3007 { + health_uri /healthz + fail_duration 30s + } +} diff --git a/deploy/deploy_production.sh b/deploy/deploy_production.sh new file mode 100644 index 0000000..bc47915 --- /dev/null +++ b/deploy/deploy_production.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Exit on error +set -e + +# Check if commit hash is passed as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +COMMIT_HASH=$1 +RELEASES_DIR="/home/deploy/releases/production" +DEPLOY_BIN="/home/deploy/production/projectreshoot" +SERVICE_NAME="projectreshoot" +BINARY_NAME="projectreshoot-production-${COMMIT_HASH}" +declare -a PORTS=("3000" "3001" "3002") + +# Check if the binary exists +if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then + echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}" + exit 1 +fi + +# Keep a reference to the previous binary from the symlink +if [ -L "${DEPLOY_BIN}" ]; then + PREVIOUS=$(readlink -f $DEPLOY_BIN) + echo "Current binary is ${PREVIOUS}, saved for rollback." +else + echo "No symbolic link found, no previous binary to backup." + PREVIOUS="" +fi + +rollback_deployment() { + if [ -n "$PREVIOUS" ]; then + echo "Rolling back to previous binary: ${PREVIOUS}" + ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}" + else + echo "No previous binary to roll back to." + fi + + # wait to restart the services + sleep 10 + + # Restart all services with the previous binary + for port in "${PORTS[@]}"; do + SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting $SERVICE..." + sudo systemctl restart $SERVICE + done + + echo "Rollback completed." +} + +# Copy the binary to the deployment directory +echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..." +ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}" + +WAIT_TIME=5 +restart_service() { + local port=$1 + local SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting ${SERVICE}..." + + # Restart the service + if ! sudo systemctl restart "$SERVICE"; then + echo "Error: Failed to restart ${SERVICE}. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + # Wait a few seconds to allow the service to fully start + echo "Waiting for ${SERVICE} to fully start..." + sleep $WAIT_TIME + + # Check the status of the service + if ! systemctl is-active --quiet "${SERVICE}"; then + echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + echo "${SERVICE}.service restarted successfully." +} + +for port in "${PORTS[@]}"; do + restart_service $port +done + +echo "Deployment completed successfully." diff --git a/deploy/deploy_staging.sh b/deploy/deploy_staging.sh new file mode 100644 index 0000000..3ada4c5 --- /dev/null +++ b/deploy/deploy_staging.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Exit on error +set -e + +# Check if commit hash is passed as an argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +COMMIT_HASH=$1 +RELEASES_DIR="/home/deploy/releases/staging" +DEPLOY_BIN="/home/deploy/staging/projectreshoot" +SERVICE_NAME="staging.projectreshoot" +BINARY_NAME="projectreshoot-staging-${COMMIT_HASH}" +declare -a PORTS=("3005" "3006" "3007") + +# Check if the binary exists +if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then + echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}" + exit 1 +fi + +# Keep a reference to the previous binary from the symlink +if [ -L "${DEPLOY_BIN}" ]; then + PREVIOUS=$(readlink -f $DEPLOY_BIN) + echo "Current binary is ${PREVIOUS}, saved for rollback." +else + echo "No symbolic link found, no previous binary to backup." + PREVIOUS="" +fi + +rollback_deployment() { + if [ -n "$PREVIOUS" ]; then + echo "Rolling back to previous binary: ${PREVIOUS}" + ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}" + else + echo "No previous binary to roll back to." + fi + + # wait to restart the services + sleep 10 + + # Restart all services with the previous binary + for port in "${PORTS[@]}"; do + SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting $SERVICE..." + sudo systemctl restart $SERVICE + done + + echo "Rollback completed." +} + +# Copy the binary to the deployment directory +echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..." +ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}" + +WAIT_TIME=5 +restart_service() { + local port=$1 + local SERVICE="${SERVICE_NAME}@${port}.service" + echo "Restarting ${SERVICE}..." + + # Restart the service + if ! sudo systemctl restart "$SERVICE"; then + echo "Error: Failed to restart ${SERVICE}. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + # Wait a few seconds to allow the service to fully start + echo "Waiting for ${SERVICE} to fully start..." + sleep $WAIT_TIME + + # Check the status of the service + if ! systemctl is-active --quiet "${SERVICE}"; then + echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment." + + # Call the rollback function + rollback_deployment + exit 1 + fi + + echo "${SERVICE}.service restarted successfully." +} + +for port in "${PORTS[@]}"; do + restart_service $port +done + +echo "Deployment completed successfully." diff --git a/deploy/systemd/production@.service b/deploy/systemd/production@.service new file mode 100644 index 0000000..7a49b82 --- /dev/null +++ b/deploy/systemd/production@.service @@ -0,0 +1,27 @@ +[Unit] +Description=Project Reshoot %i +After=network.target + +[Service] +ExecStart=/home/deploy/production/projectreshoot +WorkingDirectory=/home/deploy/production +User=deploy +Group=deploy +EnvironmentFile=/etc/env/projectreshoot.env +Environment="HOST=127.0.0.1" +Environment="PORT=%i" +Environment="TRUSTED_HOST=projectreshoot.com" +Environment="SSL=true" +Environment="GZIP=true" +Environment="LOG_LEVEL=info" +Environment="LOG_OUTPUT=file" +Environment="LOG_DIR=/home/deploy/production/logs" +LimitNOFILE=65536 +Restart=on-failure +TimeoutSec=30 +PrivateTmp=true +StandardOutput=journal +StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/deploy/systemd/staging@.service b/deploy/systemd/staging@.service new file mode 100644 index 0000000..8dad36f --- /dev/null +++ b/deploy/systemd/staging@.service @@ -0,0 +1,29 @@ +[Unit] +Description=Project Reshoot Staging %i +After=network.target + +[Service] +ExecStart=/home/deploy/staging/projectreshoot +WorkingDirectory=/home/deploy/staging +User=deploy +Group=deploy +EnvironmentFile=/etc/env/staging.projectreshoot.env +Environment="HOST=127.0.0.1" +Environment="PORT=%i" +Environment="TRUSTED_HOST=staging.projectreshoot.com" +Environment="SSL=true" +Environment="GZIP=true" +Environment="LOG_LEVEL=debug" +Environment="LOG_OUTPUT=both" +Environment="LOG_DIR=/home/deploy/staging/logs" +LimitNOFILE=65536 +Restart=on-failure +TimeoutSec=30 +PrivateTmp=true +ProtectSystem=full +ProtectHome=yes +StandardOutput=journal +StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/handlers/account.go b/handlers/account.go new file mode 100644 index 0000000..365a283 --- /dev/null +++ b/handlers/account.go @@ -0,0 +1,137 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/db" + "projectreshoot/view/component/account" + "projectreshoot/view/page" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// Renders the account page on the 'General' subpage +func HandleAccountPage() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("subpage") + subpage := "General" + if err == nil { + subpage = cookie.Value + } + page.Account(subpage).Render(r.Context(), w) + }, + ) +} + +// Handles a request to change the subpage for the Accou/accountnt page +func HandleAccountSubpage() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + subpage := r.FormValue("subpage") + cookies.SetCookie(w, "subpage", "/account", subpage, 300) + account.AccountContainer(subpage).Render(r.Context(), w) + }, + ) +} + +// Handles a request to change the users username +func HandleChangeUsername( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newUsername := r.FormValue("username") + + unique, err := db.CheckUsernameUnique(conn, newUsername) + if err != nil { + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !unique { + account.ChangeUsername("Username is taken", newUsername). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeUsername(conn, newUsername) + if err != nil { + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} + +// Handles a request to change the users bio +func HandleChangeBio( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newBio := r.FormValue("bio") + leng := len([]rune(newBio)) + if leng > 128 { + account.ChangeBio("Bio limited to 128 characters", newBio). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err := user.ChangeBio(conn, newBio) + if err != nil { + logger.Error().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} +func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { + r.ParseForm() + formPassword := r.FormValue("password") + formConfirmPassword := r.FormValue("confirm-password") + if formPassword != formConfirmPassword { + return "", errors.New("Passwords do not match") + } + if len(formPassword) > 72 { + return "", errors.New("Password exceeds maximum length of 72 bytes") + } + return formPassword, nil +} + +// Handles a request to change the users password +func HandleChangePassword( + logger *zerolog.Logger, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + newPass, err := validateChangePassword(conn, r) + if err != nil { + account.ChangePassword(err.Error()).Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.SetPassword(conn, newPass) + if err != nil { + logger.Error().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("HX-Refresh", "true") + }, + ) +} diff --git a/handlers/profile.go b/handlers/profile.go index 91d21b2..91f381f 100644 --- a/handlers/profile.go +++ b/handlers/profile.go @@ -5,7 +5,7 @@ import ( "projectreshoot/view/page" ) -func HandleProfile() http.Handler { +func HandleProfilePage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { page.Profile().Render(r.Context(), w) diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go new file mode 100644 index 0000000..9e188d8 --- /dev/null +++ b/handlers/reauthenticatate.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "database/sql" + "net/http" + + "projectreshoot/config" + "projectreshoot/contexts" + "projectreshoot/cookies" + "projectreshoot/jwt" + "projectreshoot/view/component/form" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +// Get the tokens from the request +func getTokens( + config *config.Config, + conn *sql.DB, + r *http.Request, +) (*jwt.AccessToken, *jwt.RefreshToken, error) { + // get the existing tokens from the cookies + atStr, rtStr := cookies.GetTokenStrings(r) + aT, err := jwt.ParseAccessToken(config, conn, atStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") + } + rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + if err != nil { + return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") + } + return aT, rT, nil +} + +// Revoke the given token pair +func revokeTokenPair( + conn *sql.DB, + aT *jwt.AccessToken, + rT *jwt.RefreshToken, +) error { + err := jwt.RevokeToken(conn, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + err = jwt.RevokeToken(conn, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +// Issue new tokens for the user, invalidating the old ones +func refreshTokens( + config *config.Config, + conn *sql.DB, + w http.ResponseWriter, + r *http.Request, +) error { + aT, rT, err := getTokens(config, conn, r) + if err != nil { + return errors.Wrap(err, "getTokens") + } + rememberMe := map[string]bool{ + "session": false, + "exp": true, + }[aT.TTL] + // issue new tokens for the user + user := contexts.GetUser(r.Context()) + err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe) + if err != nil { + return errors.Wrap(err, "cookies.SetTokenCookies") + } + err = revokeTokenPair(conn, aT, rT) + if err != nil { + return errors.Wrap(err, "revokeTokenPair") + } + + return nil +} + +// Validate the provided password +func validatePassword( + r *http.Request, +) error { + r.ParseForm() + password := r.FormValue("password") + user := contexts.GetUser(r.Context()) + err := user.CheckPassword(password) + if err != nil { + return errors.Wrap(err, "user.CheckPassword") + } + return nil +} + +// Handle request to reauthenticate (i.e. make token fresh again) +func HandleReauthenticate( + logger *zerolog.Logger, + config *config.Config, + conn *sql.DB, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + err := validatePassword(r) + if err != nil { + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, conn, w, r) + if err != nil { + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + ) +} diff --git a/handlers/static.go b/handlers/static.go index 768e6e1..bc198dd 100644 --- a/handlers/static.go +++ b/handlers/static.go @@ -42,10 +42,10 @@ func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { // Handles requests for static files, without allowing access to the // directory viewer and returning 404 if an exact file is not found -func HandleStatic() http.Handler { +func HandleStatic(staticFS *http.FileSystem) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - nfs := justFilesFilesystem{http.Dir("static")} + nfs := justFilesFilesystem{*staticFS} fs := http.FileServer(nfs) fs.ServeHTTP(w, r) }, diff --git a/main.go b/main.go index 1bcf240..84ed2fd 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "io" + "io/fs" "net" "net/http" "os" @@ -22,6 +23,26 @@ import ( "github.com/pkg/errors" ) +//go:embed static/* +var embeddedStatic embed.FS + +// Gets the static files +func getStaticFiles() (http.FileSystem, error) { + if _, err := os.Stat("static"); err == nil { + // Use actual filesystem in development + fmt.Println("Using filesystem for static files") + return http.Dir("static"), nil + } else { + // Use embedded filesystem in production + fmt.Println("Using embedded static files") + subFS, err := fs.Sub(embeddedStatic, "static") + if err != nil { + return nil, errors.Wrap(err, "fs.Sub") + } + return http.FS(subFS), nil + } +} + // Initializes and runs the server func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) @@ -62,7 +83,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } defer conn.Close() - srv := server.NewServer(config, logger, conn) + staticFS, err := getStaticFiles() + if err != nil { + return errors.Wrap(err, "getStaticFiles") + } + + srv := server.NewServer(config, logger, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -101,9 +127,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return nil } -//go:embed static/* -var static embed.FS - // Start of runtime. Parse commandline arguments & flags, Initializes context // and starts the server func main() { diff --git a/middleware/authentication.go b/middleware/authentication.go index 38cabbe..23f40f0 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -3,6 +3,7 @@ package middleware import ( "database/sql" "net/http" + "time" "projectreshoot/config" "projectreshoot/contexts" @@ -52,7 +53,7 @@ func getAuthenticatedUser( conn *sql.DB, w http.ResponseWriter, r *http.Request, -) (*db.User, error) { +) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token @@ -69,14 +70,22 @@ func getAuthenticatedUser( return nil, errors.Wrap(err, "refreshAuthTokens") } // New token pair sent, return the authorized user - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: time.Now().Unix(), + } + return &authUser, nil } // Access token valid user, err := aT.GetUser(conn) if err != nil { return nil, errors.Wrap(err, "rT.GetUser") } - return user, nil + authUser := contexts.AuthenticatedUser{ + User: user, + Fresh: aT.Fresh, + } + return &authUser, nil } // Attempt to authenticate the user and add their account details diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 8465ca0..a5143c8 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -8,8 +8,6 @@ import ( "testing" "projectreshoot/contexts" - "projectreshoot/db" - "projectreshoot/jwt" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -45,27 +43,7 @@ func TestAuthenticationMiddleware(t *testing.T) { server := httptest.NewServer(authHandler) defer server.Close() - // Setup the user and tokens to test with - user, err := db.GetUserFromID(conn, 1) - require.NoError(t, err) - - // Good tokens - atStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false) - require.NoError(t, err) - rtStr, _, err := jwt.GenerateRefreshToken(cfg, user, false) - require.NoError(t, err) - - // Create a token and revoke it for testing - expStr, _, err := jwt.GenerateAccessToken(cfg, user, false, false) - require.NoError(t, err) - expT, err := jwt.ParseAccessToken(cfg, conn, expStr) - require.NoError(t, err) - err = jwt.RevokeToken(conn, expT) - require.NoError(t, err) - - // Make sure it actually got revoked - expT, err = jwt.ParseAccessToken(cfg, conn, expStr) - require.Error(t, err) + tokens := getTokens() tests := []struct { name string @@ -75,29 +53,48 @@ func TestAuthenticationMiddleware(t *testing.T) { expectedCode int }{ { - name: "Valid Access Token", + name: "Valid Access Token (Fresh)", id: 1, - accessToken: atStr, + accessToken: tokens["accessFresh"], refreshToken: "", expectedCode: http.StatusOK, }, + { + name: "Valid Access Token (Unfresh)", + id: 1, + accessToken: tokens["accessUnfresh"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusOK, + }, { name: "Valid Refresh Token (Triggers Refresh)", id: 1, - accessToken: expStr, - refreshToken: rtStr, + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshValid"], expectedCode: http.StatusOK, }, { - name: "Refresh token revoked (after refresh)", - accessToken: expStr, - refreshToken: rtStr, + name: "Both tokens expired", + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusUnauthorized, + }, + { + name: "Access token revoked", + accessToken: tokens["accessRevoked"], + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + { + name: "Refresh token revoked", + accessToken: "", + refreshToken: tokens["refreshRevoked"], expectedCode: http.StatusUnauthorized, }, { name: "Invalid Tokens", - accessToken: expStr, - refreshToken: expStr, + accessToken: tokens["invalid"], + refreshToken: tokens["invalid"], expectedCode: http.StatusUnauthorized, }, { @@ -130,3 +127,18 @@ func TestAuthenticationMiddleware(t *testing.T) { }) } } + +// get the tokens to test with +func getTokens() map[string]string { + tokens := map[string]string{ + "accessFresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzIyMTAsImZyZXNoIjo0ODk1NjcyMjEwLCJpYXQiOjE3Mzk2NzIyMTAsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImE4Njk2YWM4LTg3OWMtNDdkNC1iZWM2LTRlY2Y4MTRiZThiZiIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.6nAquDY0JBLPdaJ9q_sMpKj1ISG4Vt2U05J57aoPue8", + "accessUnfresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJmcmVzaCI6MTczOTY3NTY3MSwiaWF0IjoxNzM5Njc1NjcxLCJpc3MiOiIxMjcuMC4wLjEiLCJqdGkiOiJjOGNhZmFjNy0yODkzLTQzNzMtOTI4ZS03MGUwODJkYmM2MGIiLCJzY29wZSI6ImFjY2VzcyIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.plWQVFwHlhXUYI5utS7ny1JfXjJSFrigkq-PnTHD5VY", + "accessExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImZyZXNoIjoxNzM5NjcyMjQ4LCJpYXQiOjE3Mzk2NzIyNDgsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjgxYzA1YzBjLTJhOGItNGQ2MC04Yzc4LWY2ZTQxODYxZDFmNCIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.iI1f17kKTuFDEMEYltJRIwRYgYQ-_nF9Wsn0KR6x77Q", + "refreshValid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImlhdCI6MTczOTY3MTkyMiwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTUxMTY3ZWEtNDA3OS00ZTczLTkzZDQtNTgwZDMzODRjZDU4Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.tvtqQ8Z4WrYWHHb0MaEPdsU2FT2KLRE1zHOv3ipoFyc", + "refreshExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImlhdCI6MTczOTY3MjI0OCwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTg5YTc5MTYtZGEzYi00YmJhLWI3ZDMtOWI1N2ViNjRhMmU0Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.rH_fytC7Duxo598xacu820pQKF9ELbG8674h_bK_c4I", + "accessRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImZyZXNoIjoxNzM5NjcxOTIyLCJpYXQiOjE3Mzk2NzE5MjIsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjBhNmIzMzhlLTkzMGEtNDNmZS04ZjcwLTFhNmRhZWQyNTZmYSIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.mZLuCp9amcm2_CqYvbHPlk86nfiuy_Or8TlntUCw4Qs", + "refreshRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJpYXQiOjE3Mzk2NzU2NzEsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImI3ZmE1MWRjLTg1MzItNDJlMS04NzU2LTVkMjViZmIyMDAzYSIsInNjb3BlIjoicmVmcmVzaCIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.5Q9yDZN5FubfCWHclUUZEkJPOUHcOEpVpgcUK-ameHo", + "invalid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0ODUxNDA5ODQsImlhdCI6MTQ4NTEzNzM4NCwiaXNzIjoiYWNtZS5jb20iLCJzdWIiOiIyOWFjMGMxOC0wYjRhLTQyY2YtODJmYy0wM2Q1NzAzMThhMWQiLCJhcHBsaWNhdGlvbklkIjoiNzkxMDM3MzQtOTdhYi00ZDFhLWFmMzctZTAwNmQwNWQyOTUyIiwicm9sZXMiOltdfQ.Mp0Pcwsz5VECK11Kf2ZZNF_SMKu5CgBeLN9ZOP04kZo", + } + return tokens +} diff --git a/middleware/excluded.go b/middleware/excluded.go deleted file mode 100644 index cc31749..0000000 --- a/middleware/excluded.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import ( - "net/http" - "strings" -) - -var excludedFiles = map[string]bool{ - "/static/css/output.css": true, -} - -// Checks is path requested if for an excluded file and returns the file -// instead of passing the request onto the next middleware -func ExcludedFiles(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if excludedFiles[r.URL.Path] { - filePath := strings.TrimPrefix(r.URL.Path, "/") - http.ServeFile(w, r, filePath) - } else { - next.ServeHTTP(w, r) - } - }, - ) -} diff --git a/middleware/favicon.go b/middleware/favicon.go deleted file mode 100644 index 41385fa..0000000 --- a/middleware/favicon.go +++ /dev/null @@ -1,17 +0,0 @@ -package middleware - -import ( - "net/http" -) - -func Favicon(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/favicon.ico" { - http.ServeFile(w, r, "static/favicon.ico") - } else { - next.ServeHTTP(w, r) - } - }, - ) -} diff --git a/middleware/pageprotection.go b/middleware/pageprotection.go index 64ef4da..f5537b2 100644 --- a/middleware/pageprotection.go +++ b/middleware/pageprotection.go @@ -11,6 +11,7 @@ func RequiresLogin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) if user == nil { + w.WriteHeader(http.StatusUnauthorized) page.Error( "401", "Unauthorized", diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go new file mode 100644 index 0000000..de79975 --- /dev/null +++ b/middleware/pageprotection_test.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "projectreshoot/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPageLoginRequired(t *testing.T) { + // Basic setup + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + require.NotNil(t, conn) + defer tests.DeleteTestDB() + + // Handler to check outcome of Authentication middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Add the middleware and create the server + loginRequiredHandler := RequiresLogin(testHandler) + authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + server := httptest.NewServer(authHandler) + defer server.Close() + + tokens := getTokens() + + tests := []struct { + name string + accessToken string + refreshToken string + expectedCode int + }{ + { + name: "Valid Login", + accessToken: tokens["accessFresh"], + refreshToken: "", + expectedCode: http.StatusOK, + }, + { + name: "Expired login", + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusUnauthorized, + }, + { + name: "No login", + accessToken: "", + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + // Add cookies if provided + if tt.accessToken != "" { + req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken}) + } + if tt.refreshToken != "" { + req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken}) + } + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, resp.StatusCode) + }) + } +} diff --git a/middleware/reauthentication.go b/middleware/reauthentication.go new file mode 100644 index 0000000..41fad65 --- /dev/null +++ b/middleware/reauthentication.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "net/http" + "projectreshoot/contexts" + "time" +) + +func RequiresFresh( + next http.Handler, +) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + isFresh := time.Now().Before(time.Unix(user.Fresh, 0)) + if !isFresh { + w.WriteHeader(444) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go new file mode 100644 index 0000000..63017cb --- /dev/null +++ b/middleware/reauthentication_test.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "projectreshoot/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestActionReauthRequired(t *testing.T) { + // Basic setup + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + require.NotNil(t, conn) + defer tests.DeleteTestDB() + + // Handler to check outcome of Authentication middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Add the middleware and create the server + reauthRequiredHandler := RequiresFresh(testHandler) + loginRequiredHandler := RequiresLogin(reauthRequiredHandler) + authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + server := httptest.NewServer(authHandler) + defer server.Close() + + tokens := getTokens() + + tests := []struct { + name string + accessToken string + refreshToken string + expectedCode int + }{ + { + name: "Fresh Login", + accessToken: tokens["accessFresh"], + refreshToken: "", + expectedCode: http.StatusOK, + }, + { + name: "Unfresh Login", + accessToken: tokens["accessUnfresh"], + refreshToken: "", + expectedCode: 444, + }, + { + name: "Expired login", + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusUnauthorized, + }, + { + name: "No login", + accessToken: "", + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + // Add cookies if provided + if tt.accessToken != "" { + req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken}) + } + if tt.refreshToken != "" { + req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken}) + } + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, resp.StatusCode) + }) + } +} diff --git a/schema.sql b/schema.sql index 80f9970..986d312 100644 --- a/schema.sql +++ b/schema.sql @@ -8,7 +8,8 @@ CREATE TABLE IF NOT EXISTS "users" ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT DEFAULT "", - created_at INTEGER DEFAULT (unixepoch()) + created_at INTEGER DEFAULT (unixepoch()), + bio TEXT DEFAULT "" ) STRICT; CREATE TRIGGER cleanup_expired_tokens AFTER INSERT ON jwtblacklist diff --git a/server/routes.go b/server/routes.go index 19c87f7..a92885f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,12 +18,13 @@ func addRoutes( logger *zerolog.Logger, config *config.Config, conn *sql.DB, + staticFS *http.FileSystem, ) { // Health check mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {}) // Static files - mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic())) + mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic(staticFS))) // Index page and unhandled catchall (404) mux.Handle("GET /", handlers.HandleRoot()) @@ -60,9 +61,41 @@ func addRoutes( // Logout mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn)) + // Reauthentication request + mux.Handle("POST /reauthenticate", + middleware.RequiresLogin( + handlers.HandleReauthenticate(logger, config, conn), + )) + // Profile page mux.Handle("GET /profile", middleware.RequiresLogin( - handlers.HandleProfile(), + handlers.HandleProfilePage(), + )) + + // Account page + mux.Handle("GET /account", + middleware.RequiresLogin( + handlers.HandleAccountPage(), + )) + mux.Handle("POST /account-select-page", + middleware.RequiresLogin( + handlers.HandleAccountSubpage(), + )) + mux.Handle("POST /change-username", + middleware.RequiresLogin( + middleware.RequiresFresh( + handlers.HandleChangeUsername(logger, conn), + ), + )) + mux.Handle("POST /change-bio", + middleware.RequiresLogin( + handlers.HandleChangeBio(logger, conn), + )) + mux.Handle("POST /change-password", + middleware.RequiresLogin( + middleware.RequiresFresh( + handlers.HandleChangePassword(logger, conn), + ), )) } diff --git a/server/server.go b/server/server.go index e460546..648082f 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ func NewServer( config *config.Config, logger *zerolog.Logger, conn *sql.DB, + staticFS *http.FileSystem, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -22,6 +23,7 @@ func NewServer( logger, config, conn, + staticFS, ) var handler http.Handler = mux // Add middleware here, must be added in reverse order of execution @@ -29,10 +31,6 @@ func NewServer( handler = middleware.Logging(logger, handler) handler = middleware.Authentication(logger, config, conn, handler) - // Serve the favicon and exluded files before any middleware is added - handler = middleware.ExcludedFiles(handler) - handler = middleware.Favicon(handler) - // Gzip handler = middleware.Gzip(handler, config.GZIP) diff --git a/testdata.sql b/testdata.sql index 6b05c97..be1ee04 100644 --- a/testdata.sql +++ b/testdata.sql @@ -1 +1,3 @@ -INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274); +INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274, 'bio'); +INSERT INTO jwtblacklist VALUES('0a6b338e-930a-43fe-8f70-1a6daed256fa', 33299675344); +INSERT INTO jwtblacklist VALUES('b7fa51dc-8532-42e1-8756-5d25bfb2003a', 33299675344); diff --git a/view/component/account/changebio.templ b/view/component/account/changebio.templ new file mode 100644 index 0000000..cb85013 --- /dev/null +++ b/view/component/account/changebio.templ @@ -0,0 +1,117 @@ +package account + +import "projectreshoot/contexts" + +templ ChangeBio(err string, bio string) { + {{ + user := contexts.GetUser(ctx) + if bio == "" { + bio = user.Bio + } + }} +
+ +
+
+ +
+ + +
+
+
+ + +
+
+

+
+} diff --git a/view/component/account/changepassword.templ b/view/component/account/changepassword.templ new file mode 100644 index 0000000..6c368d3 --- /dev/null +++ b/view/component/account/changepassword.templ @@ -0,0 +1,141 @@ +package account + +templ ChangePassword(err string) { +
+ +
+
+ + +
+ +
+
+
+ + +
+ +
+
+
+ + +
+
+

+
+} diff --git a/view/component/account/changeusername.templ b/view/component/account/changeusername.templ new file mode 100644 index 0000000..25cc151 --- /dev/null +++ b/view/component/account/changeusername.templ @@ -0,0 +1,108 @@ +package account + +import "projectreshoot/contexts" + +templ ChangeUsername(err string, username string) { + {{ + user := contexts.GetUser(ctx) + if username == "" { + username = user.Username + } + }} +
+ +
+
+ + +
+ +
+
+
+ + +
+
+

+
+} diff --git a/view/component/account/container.templ b/view/component/account/container.templ new file mode 100644 index 0000000..f065949 --- /dev/null +++ b/view/component/account/container.templ @@ -0,0 +1,26 @@ +package account + +templ AccountContainer(subpage string) { +
+ @SelectMenu(subpage) +
+
+ { subpage } +
+ switch subpage { + case "General": + @AccountGeneral() + case "Security": + @AccountSecurity() + } +
+
+} diff --git a/view/component/account/general.templ b/view/component/account/general.templ new file mode 100644 index 0000000..c8946c9 --- /dev/null +++ b/view/component/account/general.templ @@ -0,0 +1,8 @@ +package account + +templ AccountGeneral() { +
+ @ChangeUsername("", "") + @ChangeBio("", "") +
+} diff --git a/view/component/account/security.templ b/view/component/account/security.templ new file mode 100644 index 0000000..e61ead8 --- /dev/null +++ b/view/component/account/security.templ @@ -0,0 +1,7 @@ +package account + +templ AccountSecurity() { +
+ @ChangePassword("") +
+} diff --git a/view/component/account/selectmenu.templ b/view/component/account/selectmenu.templ new file mode 100644 index 0000000..a4221ce --- /dev/null +++ b/view/component/account/selectmenu.templ @@ -0,0 +1,91 @@ +package account + +import "fmt" + +type MenuItem struct { + name string + href string +} + +func getMenuItems() []MenuItem { + return []MenuItem{ + { + name: "General", + href: "general", + }, + { + name: "Security", + href: "security", + }, + { + name: "Preferences", + href: "preferences", + }, + } +} + +templ SelectMenu(activePage string) { + {{ + menuItems := getMenuItems() + page := fmt.Sprintf("{page:'%s'}", activePage) + }} +
+
+
+ +
+
+
    + for _, item := range menuItems { + {{ + activebind := fmt.Sprintf("page === '%s' && 'bg-mantle'", item.name) + }} +
  • + +
  • + } +
+
+
+
+} diff --git a/view/component/form/confirmpass.templ b/view/component/form/confirmpass.templ new file mode 100644 index 0000000..a19e9a6 --- /dev/null +++ b/view/component/form/confirmpass.templ @@ -0,0 +1,90 @@ +package form + +templ ConfirmPassword(err string) { +
+ +
+
+
+ +
+ +
+
+

+
+ + +
+
+} diff --git a/view/component/form/loginform.templ b/view/component/form/loginform.templ index 3380320..d6fa4a8 100644 --- a/view/component/form/loginform.templ +++ b/view/component/form/loginform.templ @@ -1,31 +1,34 @@ package form -import "fmt" - // Login Form. If loginError is not an empty string, it will display the // contents of loginError to the user. // If loginError is "Username or password incorrect" it will also show // error icons on the username and password field templ LoginForm(loginError string) { - {{ - errCreds := "false" - if loginError == "Username or password incorrect" { - errCreds = "true" - } - xdata := fmt.Sprintf( - "{credentialError: %s, errorMessage: '%s'}", - errCreds, - loginError, - ) - }} + {{ credErr := "Username or password incorrect" }}
+
@@ -44,6 +47,7 @@ templ LoginForm(loginError string) { disabled:pointer-events-none" required aria-describedby="username-error" + @input="resetErr()" />
+
-