diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100644 index 0000000..316c82a --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,11 @@ +#!/bin/sh +protected_branches=("master" "staging") +current_branch=$(git rev-parse --abbrev-ref HEAD) + +for branch in "${protected_branches[@]}"; do + if [ "$current_branch" = "$branch" ]; then + echo "Direct pushes to '$branch' are not allowed. Use a pull request instead." + exit 1 + fi +done +exit 0 diff --git a/.github/workflows/deploy_production.yaml b/.github/workflows/deploy_production.yaml index aafd99f..058d43d 100644 --- a/.github/workflows/deploy_production.yaml +++ b/.github/workflows/deploy_production.yaml @@ -33,11 +33,15 @@ jobs: - name: Build the binary run: make build SUFFIX=-production-$GITHUB_SHA + - name: Build the migration binary + run: make migrate SUFFIX=-production-$GITHUB_SHA + - name: Deploy to Server env: USER: deploy HOST: projectreshoot.com DIR: /home/deploy/releases/production + MIG_DIR: /home/deploy/migration-bin DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }} run: | mkdir -p ~/.ssh @@ -49,7 +53,13 @@ jobs: echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR - scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR - ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_production.sh $GITHUB_SHA + ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR + scp -i ~/.ssh/id_ed25519 prmigrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR + + scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA production diff --git a/.github/workflows/deploy_staging.yaml b/.github/workflows/deploy_staging.yaml index aab9831..539097e 100644 --- a/.github/workflows/deploy_staging.yaml +++ b/.github/workflows/deploy_staging.yaml @@ -33,11 +33,15 @@ jobs: - name: Build the binary run: make build SUFFIX=-staging-$GITHUB_SHA + - name: Build the migration binary + run: make migrate SUFFIX=-staging-$GITHUB_SHA + - name: Deploy to Server env: USER: deploy HOST: projectreshoot.com DIR: /home/deploy/releases/staging + MIG_DIR: /home/deploy/migration-bin DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }} run: | mkdir -p ~/.ssh @@ -49,7 +53,13 @@ jobs: echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR - scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR - ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA + ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR + scp -i ~/.ssh/id_ed25519 prmigrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR + + scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR + scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR + + ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging diff --git a/.gitignore b/.gitignore index 37839ec..9e3f92b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,9 @@ query.sql *.db .logs/ +server.log tmp/ +prmigrate projectreshoot static/css/output.css view/**/*_templ.go diff --git a/Makefile b/Makefile index 72a5804..acabebf 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ # Makefile .PHONY: build +.PHONY: migrate BINARY_NAME=projectreshoot @@ -17,14 +18,20 @@ dev: tester: go mod tidy && \ - go run . --port 3232 --test --loglevel trace + go run . --port 3232 --tester --loglevel trace test: - rm -f **/.projectreshoot-test-database.db && \ go mod tidy && \ templ generate && \ go generate && \ + go test . + go test ./db go test ./middleware clean: go clean + +migrate: + go mod tidy && \ + go generate && \ + go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate diff --git a/config/config.go b/config/config.go index 25654d6..0b9bc90 100644 --- a/config/config.go +++ b/config/config.go @@ -21,7 +21,8 @@ type Config struct { ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds WriteTimeout time.Duration // Timeout for writing requests in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds - DBName string // Filename of the db (doesnt include file extension) + DBName string // Filename of the db - hardcoded and doubles as DB version + DBLockTimeout time.Duration // Timeout for acquiring database lock SecretKey string // Secret key for signing tokens AccessTokenExpiry int64 // Access token expiry in minutes RefreshTokenExpiry int64 // Refresh token expiry in minutes @@ -33,10 +34,7 @@ type Config struct { // Load the application configuration and get a pointer to the Config object func GetConfig(args map[string]string) (*Config, error) { - err := godotenv.Load(".env") - if err != nil { - fmt.Println(err) - } + godotenv.Load(".env") var ( host string port string @@ -89,7 +87,8 @@ func GetConfig(args map[string]string) (*Config, error) { ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2), WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), - DBName: GetEnvDefault("DB_NAME", "projectreshoot"), + DBName: "00001", + DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60), SecretKey: os.Getenv("SECRET_KEY"), AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day @@ -99,7 +98,7 @@ func GetConfig(args map[string]string) (*Config, error) { LogDir: GetEnvDefault("LOG_DIR", ""), } - if config.SecretKey == "" { + if config.SecretKey == "" && args["dbver"] != "true" { return nil, errors.New("Envar not set: SECRET_KEY") } diff --git a/db/connection.go b/db/connection.go index d1415c5..a156590 100644 --- a/db/connection.go +++ b/db/connection.go @@ -3,19 +3,56 @@ package db import ( "database/sql" "fmt" + "strconv" "github.com/pkg/errors" + "github.com/rs/zerolog" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) -// Returns a database connection handle for the Turso DB -func ConnectToDatabase(dbName string) (*sql.DB, error) { +// Returns a database connection handle for the DB +func ConnectToDatabase( + dbName string, + logger *zerolog.Logger, +) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) - db, err := sql.Open("sqlite3", file) - + db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } - return db, nil + version, err := strconv.Atoi(dbName) + if err != nil { + return nil, errors.Wrap(err, "strconv.Atoi") + } + err = checkDBVersion(db, version) + if err != nil { + return nil, errors.Wrap(err, "checkDBVersion") + } + conn := MakeSafe(db, logger) + return conn, nil +} + +// Check the database version +func checkDBVersion(db *sql.DB, expectVer int) error { + query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1 + ORDER BY version_id DESC LIMIT 1` + rows, err := db.Query(query) + if err != nil { + return errors.Wrap(err, "checkDBVersion") + } + defer rows.Close() + if rows.Next() { + var version int + err = rows.Scan(&version) + if err != nil { + return errors.Wrap(err, "rows.Scan") + } + if version != expectVer { + return errors.New("Version mismatch") + } + } else { + return errors.New("No version found") + } + return nil } diff --git a/db/safeconn.go b/db/safeconn.go new file mode 100644 index 0000000..4d66f70 --- /dev/null +++ b/db/safeconn.go @@ -0,0 +1,129 @@ +package db + +import ( + "context" + "database/sql" + "time" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +type SafeConn struct { + db *sql.DB + readLockCount uint32 + globalLockStatus uint32 + globalLockRequested uint32 + logger *zerolog.Logger +} + +// Make the provided db handle safe and attach a logger to it +func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { + return &SafeConn{db: db, logger: logger} +} + +// Attempts to acquire a global lock on the database connection +func (conn *SafeConn) acquireGlobalLock() bool { + if conn.readLockCount > 0 || conn.globalLockStatus == 1 { + return false + } + conn.globalLockStatus = 1 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock acquired") + return true +} + +// Releases a global lock on the database connection +func (conn *SafeConn) releaseGlobalLock() { + conn.globalLockStatus = 0 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock released") +} + +// Acquire a read lock on the connection. Multiple read locks can be acquired +// at the same time +func (conn *SafeConn) acquireReadLock() bool { + if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { + return false + } + conn.readLockCount += 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock acquired") + return true +} + +// Release a read lock. Decrements read lock count by 1 +func (conn *SafeConn) releaseReadLock() { + conn.readLockCount -= 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock released") +} + +// Starts a new transaction based on the current context. Will cancel if +// the context is closed/cancelled/done +func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { + lockAcquired := make(chan struct{}) + lockCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + select { + case <-lockCtx.Done(): + return + default: + if conn.acquireReadLock() { + close(lockAcquired) + } + } + }() + + select { + case <-lockAcquired: + tx, err := conn.db.BeginTx(ctx, nil) + if err != nil { + conn.releaseReadLock() + return nil, err + } + return &SafeTX{tx: tx, sc: conn}, nil + case <-ctx.Done(): + cancel() + return nil, errors.New("Transaction time out due to database lock") + } +} + +// Acquire a global lock, preventing all transactions +func (conn *SafeConn) Pause(timeoutAfter time.Duration) { + conn.logger.Info().Msg("Attempting to acquire global database lock") + conn.globalLockRequested = 1 + defer func() { conn.globalLockRequested = 0 }() + timeout := time.After(timeoutAfter) + attempt := 0 + for { + if conn.acquireGlobalLock() { + conn.logger.Info().Msg("Global database lock acquired") + return + } + select { + case <-timeout: + conn.logger.Info().Msg("Timeout: Global database lock abandoned") + return + case <-time.After(100 * time.Millisecond): + attempt++ + } + } +} + +// Release the global lock +func (conn *SafeConn) Resume() { + conn.releaseGlobalLock() + conn.logger.Info().Msg("Global database lock released") +} + +// Close the database connection +func (conn *SafeConn) Close() error { + conn.logger.Debug().Msg("Acquiring global lock for connection close") + conn.acquireGlobalLock() + defer conn.releaseGlobalLock() + conn.logger.Debug().Msg("Closing database connection") + return conn.db.Close() +} diff --git a/db/safeconntx_test.go b/db/safeconntx_test.go new file mode 100644 index 0000000..b1816d7 --- /dev/null +++ b/db/safeconntx_test.go @@ -0,0 +1,143 @@ +package db + +import ( + "context" + "projectreshoot/tests" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSafeConn(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) + require.NoError(t, err) + sconn := MakeSafe(conn, logger) + defer sconn.Close() + + t.Run("Global lock waits for read locks to finish", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + var requested sync.WaitGroup + var engaged sync.WaitGroup + requested.Add(1) + engaged.Add(1) + go func() { + requested.Done() + sconn.Pause(5 * time.Second) + engaged.Done() + }() + requested.Wait() + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(1), sconn.globalLockRequested) + tx.Commit() + engaged.Wait() + assert.Equal(t, uint32(1), sconn.globalLockStatus) + assert.Equal(t, uint32(0), sconn.globalLockRequested) + sconn.Resume() + }) + t.Run("Lock abandons after timeout", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + sconn.Pause(250 * time.Millisecond) + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(0), sconn.globalLockRequested) + tx.Commit() + }) + t.Run("Pause blocks transactions and resume allows", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + var requested sync.WaitGroup + var engaged sync.WaitGroup + requested.Add(1) + engaged.Add(1) + go func() { + requested.Done() + sconn.Pause(5 * time.Second) + engaged.Done() + }() + requested.Wait() + assert.Equal(t, uint32(0), sconn.globalLockStatus) + assert.Equal(t, uint32(1), sconn.globalLockRequested) + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + _, err = sconn.Begin(ctx) + require.Error(t, err) + tx.Commit() + engaged.Wait() + _, err = sconn.Begin(ctx) + require.Error(t, err) + sconn.Resume() + tx, err = sconn.Begin(t.Context()) + require.NoError(t, err) + tx.Commit() + }) +} +func TestSafeTX(t *testing.T) { + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) + require.NoError(t, err) + conn, err := tests.SetupTestDB(ver) + require.NoError(t, err) + sconn := MakeSafe(conn, logger) + defer sconn.Close() + + t.Run("Commit releases lock", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + assert.Equal(t, uint32(1), sconn.readLockCount) + tx.Commit() + assert.Equal(t, uint32(0), sconn.readLockCount) + }) + t.Run("Rollback releases lock", func(t *testing.T) { + tx, err := sconn.Begin(t.Context()) + require.NoError(t, err) + assert.Equal(t, uint32(1), sconn.readLockCount) + tx.Rollback() + assert.Equal(t, uint32(0), sconn.readLockCount) + }) + t.Run("Multiple TX can gain read lock", func(t *testing.T) { + tx1, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx2, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx3, err := sconn.Begin(t.Context()) + require.NoError(t, err) + tx1.Commit() + tx2.Commit() + tx3.Commit() + }) + t.Run("Lock acquiring times out after timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + sconn.acquireGlobalLock() + defer sconn.releaseGlobalLock() + _, err := sconn.Begin(ctx) + require.Error(t, err) + }) + t.Run("Lock acquires if lock released", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond) + defer cancel() + sconn.acquireGlobalLock() + var wg sync.WaitGroup + wg.Add(1) + go func() { + tx, err := sconn.Begin(ctx) + require.NoError(t, err) + tx.Commit() + wg.Done() + }() + sconn.releaseGlobalLock() + wg.Wait() + }) +} diff --git a/db/safetx.go b/db/safetx.go new file mode 100644 index 0000000..5a3f06e --- /dev/null +++ b/db/safetx.go @@ -0,0 +1,61 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" +) + +// Extends sql.Tx for use with SafeConn +type SafeTX struct { + tx *sql.Tx + sc *SafeConn +} + +// Query the database inside the transaction +func (stx *SafeTX) Query( + ctx context.Context, + query string, + args ...interface{}, +) (*sql.Rows, error) { + if stx.tx == nil { + return nil, errors.New("Cannot query without a transaction") + } + return stx.tx.QueryContext(ctx, query, args...) +} + +// Exec a statement on the database inside the transaction +func (stx *SafeTX) Exec( + ctx context.Context, + query string, + args ...interface{}, +) (sql.Result, error) { + if stx.tx == nil { + return nil, errors.New("Cannot exec without a transaction") + } + return stx.tx.ExecContext(ctx, query, args...) +} + +// Commit the current transaction and release the read lock +func (stx *SafeTX) Commit() error { + if stx.tx == nil { + return errors.New("Cannot commit without a transaction") + } + err := stx.tx.Commit() + stx.tx = nil + + stx.sc.releaseReadLock() + return err +} + +// Abort the current transaction, releasing the read lock +func (stx *SafeTX) Rollback() error { + if stx.tx == nil { + return errors.New("Cannot rollback without a transaction") + } + err := stx.tx.Rollback() + stx.tx = nil + stx.sc.releaseReadLock() + return err +} diff --git a/db/user.go b/db/user.go index fe62349..a2daa26 100644 --- a/db/user.go +++ b/db/user.go @@ -1,7 +1,7 @@ package db import ( - "database/sql" + "context" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" @@ -16,16 +16,16 @@ type User struct { } // Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { +func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return errors.Wrap(err, "bcrypt.GenerateFromPassword") } user.Password_hash = string(hashedPassword) query := `UPDATE users SET password_hash = ? WHERE id = ?` - _, err = conn.Exec(query, user.Password_hash, user.ID) + _, err = tx.Exec(ctx, query, user.Password_hash, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } @@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error { } // Change the user's username -func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { +func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error { query := `UPDATE users SET username = ? WHERE id = ?` - _, err := conn.Exec(query, newUsername, user.ID) + _, err := tx.Exec(ctx, query, newUsername, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Change the user's bio -func (user *User) ChangeBio(conn *sql.DB, newBio string) error { +func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error { query := `UPDATE users SET bio = ? WHERE id = ?` - _, err := conn.Exec(query, newBio, user.ID) + _, err := tx.Exec(ctx, query, newBio, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } diff --git a/db/user_functions.go b/db/user_functions.go index 3c3623e..6dd76d4 100644 --- a/db/user_functions.go +++ b/db/user_functions.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" @@ -8,17 +9,22 @@ import ( ) // Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { +func CreateNewUser( + ctx context.Context, + tx *SafeTX, + username string, + password string, +) (*User, error) { query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) + _, err := tx.Exec(ctx, query, username) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + return nil, errors.Wrap(err, "tx.Exec") } - user, err := GetUserFromUsername(conn, username) + user, err := GetUserFromUsername(ctx, tx, username) if err != nil { return nil, errors.Wrap(err, "GetUserFromUsername") } - err = user.SetPassword(conn, password) + err = user.SetPassword(ctx, tx, password) if err != nil { return nil, errors.Wrap(err, "user.SetPassword") } @@ -26,7 +32,12 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error } // Fetches data from the users table using "WHERE column = 'value'" -func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { +func fetchUserData( + ctx context.Context, + tx *SafeTX, + column string, + value interface{}, +) (*sql.Rows, error) { query := fmt.Sprintf( `SELECT id, @@ -38,36 +49,36 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e WHERE %s = ? COLLATE NOCASE LIMIT 1`, column, ) - rows, err := conn.Query(query, value) + rows, err := tx.Query(ctx, query, value) if err != nil { - return nil, errors.Wrap(err, "conn.Query") + return nil, errors.Wrap(err, "tx.Query") } return rows, nil } -// Scan the next row into the provided user pointer. Calls rows.Next() and -// assumes only row in the result. Providing a rows object with more than 1 -// row may result in undefined behaviour. +// Calls rows.Next() and scans the row into the provided user pointer. +// Will error if no row available func scanUserRow(user *User, rows *sql.Rows) error { - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - &user.Bio, - ) - if err != nil { - return errors.Wrap(err, "rows.Scan") - } + if !rows.Next() { + return errors.New("User not found") + } + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + &user.Bio, + ) + if err != nil { + return errors.Wrap(err, "rows.Scan") } return nil } // Queries the database for a user matching the given username. // Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - rows, err := fetchUserData(conn, "username", username) +func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) { + rows, err := fetchUserData(ctx, tx, "username", username) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -81,8 +92,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { } // Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - rows, err := fetchUserData(conn, "id", id) +func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { + rows, err := fetchUserData(ctx, tx, "id", id) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -96,11 +107,11 @@ func GetUserFromID(conn *sql.DB, id int) (*User, error) { } // Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { +func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) { query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) + rows, err := tx.Query(ctx, query, username) if err != nil { - return false, errors.Wrap(err, "conn.Query") + return false, errors.Wrap(err, "tx.Query") } defer rows.Close() taken := rows.Next() diff --git a/deploy/caddy/Caddyfile b/deploy/caddy/Caddyfile index 7f29089..b0f4573 100644 --- a/deploy/caddy/Caddyfile +++ b/deploy/caddy/Caddyfile @@ -1,12 +1,58 @@ projectreshoot.com { - reverse_proxy localhost:3000 localhost:3001 localhost:3002 { - health_uri /healthz - fail_duration 30s - } + rate_limit { + zone auth { + match { + method POST + path /login /register + } + key {remote_host} + events 4 + window 1m + } + zone client { + key {remote_host} + events 100 + window 1m + } + } + reverse_proxy localhost:3000 localhost:3001 localhost:3002 { + transport http { + max_conns_per_host 10 + } + health_uri /healthz + fail_duration 30s + } + log { + output file /var/log/caddy/access.log + } } + staging.projectreshoot.com { - reverse_proxy localhost:3005 localhost:3006 localhost:3007 { - health_uri /healthz - fail_duration 30s - } + rate_limit { + zone auth { + match { + method POST + path /login /register + } + key {remote_host} + events 4 + window 1m + } + zone client { + key {remote_host} + events 100 + window 1m + } + } + reverse_proxy localhost:3005 localhost:3006 localhost:3007 { + transport http { + max_conns_per_host 10 + } + health_uri /healthz + fail_duration 30s + } + log { + output file /var/log/caddy/access-staging.log + } } + diff --git a/deploy/db/backup.sh b/deploy/db/backup.sh new file mode 100755 index 0000000..206c96c --- /dev/null +++ b/deploy/db/backup.sh @@ -0,0 +1,112 @@ +#!/bin/bash + +# Exit on error +set -e + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi + +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +ACTIVE_DIR="/home/deploy/$ENVR" +DATA_DIR="/home/deploy/data/$ENVR" +BACKUP_DIR="/home/deploy/data/backups/$ENVR" +if [[ "$ENVR" == "production" ]]; then + SERVICE_NAME="projectreshoot" + declare -a PORTS=("3000" "3001" "3002") +else + SERVICE_NAME="$ENVR.projectreshoot" + declare -a PORTS=("3005" "3006" "3007") +fi + +# Send SIGUSR2 to release maintenance mode +release_maintenance() { + echo "Releasing maintenance mode..." + for PORT in "${PORTS[@]}"; do + sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service" + done +} + +shopt -s nullglob +DB_FILES=("$ACTIVE_DIR"/*.db) +DB_COUNT=${#DB_FILES[@]} + +if [[ $DB_COUNT -gt 1 ]]; then + echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required." + exit 1 +elif [[ $DB_COUNT -eq 0 ]]; then + echo "Error: No .db file found in $ACTIVE_DIR." + exit 1 +fi + +# Extract the filename without extension +DB_FILE="${DB_FILES[0]}" +DB_VER=$(basename "$DB_FILE" .db) + +# Send SIGUSR1 to trigger maintenance mode only for active services +declare -a ACTIVE_PORTS=() +for PORT in "${PORTS[@]}"; do + if systemctl is-active --quiet "$SERVICE_NAME@$PORT.service"; then + sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service" + ACTIVE_PORTS+=("$PORT") + fi +done +trap release_maintenance EXIT + +# Function to check logs for success or failure +check_logs() { + local port="$1" + local service="$SERVICE_NAME@$port.service" + + echo "Waiting for $service to enter maintenance mode..." + + # Check the last few lines first in case the message already appeared + if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then + echo "$service successfully entered maintenance mode." + return 0 + elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then + echo "Error: $service failed to enter maintenance mode." + return 1 + fi + + # If not found, continuously watch logs until we get a success or failure message + sudo journalctl -u "$service" -f --no-pager | while read -r line; do + if echo "$line" | grep -q "Global database lock acquired"; then + echo "$service successfully entered maintenance mode." + pkill -P $$ journalctl # Kill journalctl process once we have success + return 0 + elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then + echo "Error: $service failed to enter maintenance mode." + pkill -P $$ journalctl # Kill journalctl process on failure + return 1 + fi + done +} + +# Check logs for each service +for PORT in "${ACTIVE_PORTS[@]}"; do + check_logs "$PORT" +done + +# Get current datetime in YYYY-MM-DD-HHMM format +TIMESTAMP=$(date +"%Y-%m-%d-%H%M") + +# Define source and destination paths +SOURCE_DB="$DATA_DIR/$DB_VER.db" +BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db" + +# Copy the database file +if [[ -f "$SOURCE_DB" ]]; then + cp "$SOURCE_DB" "$BACKUP_DB" + echo "Backup created: $BACKUP_DB" +else + echo "Error: Source database file $SOURCE_DB not found." + exit 1 +fi + + diff --git a/deploy/db/migrate.sh b/deploy/db/migrate.sh new file mode 100755 index 0000000..dd2f03b --- /dev/null +++ b/deploy/db/migrate.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +if [[ -z "$2" ]]; then + echo "Usage: $0 " + exit 1 +fi +TGT_VER="$2" +re='^[0-9]+$' +if ! [[ $TGT_VER =~ $re ]] ; then + echo "Error: version not a number" >&2 + exit 1 +fi +if [ -z "$3" ]; then + echo "Usage: $0 " + exit 1 +fi +COMMIT_HASH=$3 +MIGRATION_BIN="/home/deploy/migration-bin" +BACKUP_OUTPUT=$(/bin/bash ${MIGRATION_BIN}/backup.sh "$ENVR" 2>&1) +echo "$BACKUP_OUTPUT" +if [[ $? -ne 0 ]]; then + exit 1 +fi +BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*') +if [[ -z "$BACKUP_FILE" ]]; then + echo "Error: backup failed" + exit 1 +fi + +FILE_NAME=${BACKUP_FILE##*/} +CUR_VER=${FILE_NAME%%-*} +if [[ $((+$TGT_VER)) == $((+$CUR_VER)) ]]; then + echo "Version same, skipping migration" + exit 0 +fi +if [[ $((+$TGT_VER)) > $((+$CUR_VER)) ]]; then + CMD="up-to" +fi +if [[ $((+$TGT_VER)) < $((+$CUR_VER)) ]]; then + CMD="down-to" +fi +TIMESTAMP=$(date +"%Y-%m-%d-%H%M") + +ACTIVE_DIR="/home/deploy/$ENVR" +DATA_DIR="/home/deploy/data/$ENVR" +BACKUP_DIR="/home/deploy/data/backups/$ENVR" +UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db" +UPDATED_COPY="$DATA_DIR/${TGT_VER}.db" +UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db" + +cp $BACKUP_FILE $UPDATED_BACKUP +failed_cleanup() { + rm $UPDATED_BACKUP +} +trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT + +echo "Migration in progress from $CUR_VER to $TGT_VER" +${MIGRATION_BIN}/prmigrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER +if [ $? -ne 0 ]; then + echo "Migration failed" + exit 1 +fi +echo "Migration completed" + +cp $UPDATED_BACKUP $UPDATED_COPY +ln -s $UPDATED_COPY $UPDATED_LINK +echo "Upgraded database linked and ready for deploy" +exit 0 diff --git a/deploy/db/migrationcleanup.sh b/deploy/db/migrationcleanup.sh new file mode 100755 index 0000000..9de3c5a --- /dev/null +++ b/deploy/db/migrationcleanup.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Exit on error +set -e + +if [[ -z "$1" ]]; then + echo "Usage: $0 " + exit 1 +fi + +ENVR="$1" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi +if [[ -z "$2" ]]; then + echo "Usage: $0 " + exit 1 +fi +TGT_VER="$2" +re='^[0-9]+$' +if ! [[ $TGT_VER =~ $re ]] ; then + echo "Error: version not a number" >&2 + exit 1 +fi +ACTIVE_DIR="/home/deploy/$ENVR" +find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} + diff --git a/deploy/deploy_production.sh b/deploy/deploy.sh similarity index 72% rename from deploy/deploy_production.sh rename to deploy/deploy.sh index bc47915..dd198e6 100644 --- a/deploy/deploy_production.sh +++ b/deploy/deploy.sh @@ -8,19 +8,37 @@ if [ -z "$1" ]; then echo "Usage: $0 " exit 1 fi - COMMIT_HASH=$1 -RELEASES_DIR="/home/deploy/releases/production" -DEPLOY_BIN="/home/deploy/production/projectreshoot" -SERVICE_NAME="projectreshoot" -BINARY_NAME="projectreshoot-production-${COMMIT_HASH}" +ENVR="$2" +if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then + echo "Error: environment must be 'production' or 'staging'." + exit 1 +fi + +RELEASES_DIR="/home/deploy/releases/$ENVR" +DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot" +MIGRATION_BIN="/home/deploy/migration-bin" +BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}" declare -a PORTS=("3000" "3001" "3002") +if [[ "$ENVR" == "production" ]]; then + SERVICE_NAME="projectreshoot" + declare -a PORTS=("3000" "3001" "3002") +else + SERVICE_NAME="$ENVR.projectreshoot" + declare -a PORTS=("3005" "3006" "3007") +fi # Check if the binary exists if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}" exit 1 fi +DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*') +${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH +if [[ $? -ne 0 ]]; then + echo "Migration failed" + exit 1 +fi # Keep a reference to the previous binary from the symlink if [ -L "${DEPLOY_BIN}" ]; then @@ -92,3 +110,4 @@ for port in "${PORTS[@]}"; do done echo "Deployment completed successfully." +${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER diff --git a/deploy/deploy_staging.sh b/deploy/deploy_staging.sh deleted file mode 100644 index 3ada4c5..0000000 --- a/deploy/deploy_staging.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/bin/bash - -# Exit on error -set -e - -# Check if commit hash is passed as an argument -if [ -z "$1" ]; then - echo "Usage: $0 " - exit 1 -fi - -COMMIT_HASH=$1 -RELEASES_DIR="/home/deploy/releases/staging" -DEPLOY_BIN="/home/deploy/staging/projectreshoot" -SERVICE_NAME="staging.projectreshoot" -BINARY_NAME="projectreshoot-staging-${COMMIT_HASH}" -declare -a PORTS=("3005" "3006" "3007") - -# Check if the binary exists -if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then - echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}" - exit 1 -fi - -# Keep a reference to the previous binary from the symlink -if [ -L "${DEPLOY_BIN}" ]; then - PREVIOUS=$(readlink -f $DEPLOY_BIN) - echo "Current binary is ${PREVIOUS}, saved for rollback." -else - echo "No symbolic link found, no previous binary to backup." - PREVIOUS="" -fi - -rollback_deployment() { - if [ -n "$PREVIOUS" ]; then - echo "Rolling back to previous binary: ${PREVIOUS}" - ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}" - else - echo "No previous binary to roll back to." - fi - - # wait to restart the services - sleep 10 - - # Restart all services with the previous binary - for port in "${PORTS[@]}"; do - SERVICE="${SERVICE_NAME}@${port}.service" - echo "Restarting $SERVICE..." - sudo systemctl restart $SERVICE - done - - echo "Rollback completed." -} - -# Copy the binary to the deployment directory -echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..." -ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}" - -WAIT_TIME=5 -restart_service() { - local port=$1 - local SERVICE="${SERVICE_NAME}@${port}.service" - echo "Restarting ${SERVICE}..." - - # Restart the service - if ! sudo systemctl restart "$SERVICE"; then - echo "Error: Failed to restart ${SERVICE}. Rolling back deployment." - - # Call the rollback function - rollback_deployment - exit 1 - fi - - # Wait a few seconds to allow the service to fully start - echo "Waiting for ${SERVICE} to fully start..." - sleep $WAIT_TIME - - # Check the status of the service - if ! systemctl is-active --quiet "${SERVICE}"; then - echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment." - - # Call the rollback function - rollback_deployment - exit 1 - fi - - echo "${SERVICE}.service restarted successfully." -} - -for port in "${PORTS[@]}"; do - restart_service $port -done - -echo "Deployment completed successfully." diff --git a/go.mod b/go.mod index 97dfb00..f15482c 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,37 @@ module projectreshoot -go 1.23.5 +go 1.24.0 require ( github.com/a-h/templ v0.3.833 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mattn/go-sqlite3 v1.14.24 github.com/pkg/errors v0.9.1 + github.com/pressly/goose/v3 v3.24.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 + modernc.org/sqlite v1.35.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.61.13 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.8.2 // indirect ) diff --git a/go.sum b/go.sum index 328d060..b4bb05c 100644 --- a/go.sum +++ b/go.sum @@ -3,15 +3,24 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -19,25 +28,68 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY= +github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= +modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo= +modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw= +modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8= +modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI= +modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw= +modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/handlers/account.go b/handler/account.go similarity index 65% rename from handlers/account.go rename to handler/account.go index 365a283..5236e1e 100644 --- a/handlers/account.go +++ b/handler/account.go @@ -1,8 +1,9 @@ -package handlers +package handler import ( - "database/sql" + "context" "net/http" + "time" "projectreshoot/contexts" "projectreshoot/cookies" @@ -15,7 +16,7 @@ import ( ) // Renders the account page on the 'General' subpage -func HandleAccountPage() http.Handler { +func AccountPage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("subpage") @@ -29,7 +30,7 @@ func HandleAccountPage() http.Handler { } // Handles a request to change the subpage for the Accou/accountnt page -func HandleAccountSubpage() http.Handler { +func AccountSubpage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { r.ParseForm() @@ -41,65 +42,95 @@ func HandleAccountSubpage() http.Handler { } // Handles a request to change the users username -func HandleChangeUsername( +func ChangeUsername( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusServiceUnavailable) + return + } r.ParseForm() newUsername := r.FormValue("username") - - unique, err := db.CheckUsernameUnique(conn, newUsername) + unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) if err != nil { + tx.Rollback() logger.Error().Err(err).Msg("Error updating username") w.WriteHeader(http.StatusInternalServerError) return } if !unique { + tx.Rollback() account.ChangeUsername("Username is taken", newUsername). Render(r.Context(), w) return } user := contexts.GetUser(r.Context()) - err = user.ChangeUsername(conn, newUsername) + err = user.ChangeUsername(ctx, tx, newUsername) if err != nil { + tx.Rollback() logger.Error().Err(err).Msg("Error updating username") w.WriteHeader(http.StatusInternalServerError) return } + tx.Commit() w.Header().Set("HX-Refresh", "true") }, ) } // Handles a request to change the users bio -func HandleChangeBio( +func ChangeBio( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusServiceUnavailable) + return + } r.ParseForm() newBio := r.FormValue("bio") leng := len([]rune(newBio)) if leng > 128 { + tx.Rollback() account.ChangeBio("Bio limited to 128 characters", newBio). Render(r.Context(), w) return } user := contexts.GetUser(r.Context()) - err := user.ChangeBio(conn, newBio) + err = user.ChangeBio(ctx, tx, newBio) if err != nil { + tx.Rollback() logger.Error().Err(err).Msg("Error updating bio") w.WriteHeader(http.StatusInternalServerError) return } + tx.Commit() w.Header().Set("HX-Refresh", "true") }, ) } -func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { +func validateChangePassword( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (string, error) { r.ParseForm() formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") @@ -113,24 +144,37 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { } // Handles a request to change the users password -func HandleChangePassword( +func ChangePassword( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - newPass, err := validateChangePassword(conn, r) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) if err != nil { + logger.Warn().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + newPass, err := validateChangePassword(ctx, tx, r) + if err != nil { + tx.Rollback() account.ChangePassword(err.Error()).Render(r.Context(), w) return } user := contexts.GetUser(r.Context()) - err = user.SetPassword(conn, newPass) + err = user.SetPassword(ctx, tx, newPass) if err != nil { + tx.Rollback() logger.Error().Err(err).Msg("Error updating password") w.WriteHeader(http.StatusInternalServerError) return } + tx.Commit() w.Header().Set("HX-Refresh", "true") }, ) diff --git a/handler/errorpage.go b/handler/errorpage.go new file mode 100644 index 0000000..19aa760 --- /dev/null +++ b/handler/errorpage.go @@ -0,0 +1,24 @@ +package handler + +import ( + "net/http" + "projectreshoot/view/page" +) + +func ErrorPage( + errorCode int, + w http.ResponseWriter, + r *http.Request, +) { + message := map[int]string{ + 401: "You need to login to view this page.", + 403: "You do not have permission to view this page.", + 404: "The page or resource you have requested does not exist.", + 500: `An error occured on the server. Please try again, and if this + continues to happen contact an administrator.`, + 503: "The server is currently down for maintenance and should be back soon. =)", + } + w.WriteHeader(http.StatusUnauthorized) + page.Error(errorCode, http.StatusText(errorCode), message[errorCode]). + Render(r.Context(), w) +} diff --git a/handlers/index.go b/handler/index.go similarity index 62% rename from handlers/index.go rename to handler/index.go index 974b5be..f7e09bf 100644 --- a/handlers/index.go +++ b/handler/index.go @@ -1,4 +1,4 @@ -package handlers +package handler import ( "net/http" @@ -8,15 +8,11 @@ import ( // Handles responses to the / path. Also serves a 404 Page for paths that // don't have explicit handlers -func HandleRoot() http.Handler { +func Root() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - page.Error( - "404", - "Page not found", - "The page or resource you have requested does not exist", - ).Render(r.Context(), w) + ErrorPage(http.StatusNotFound, w, r) return } page.Index().Render(r.Context(), w) diff --git a/handlers/login.go b/handler/login.go similarity index 76% rename from handlers/login.go rename to handler/login.go index 7af3901..f366447 100644 --- a/handlers/login.go +++ b/handler/login.go @@ -1,8 +1,9 @@ -package handlers +package handler import ( - "database/sql" + "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" @@ -16,10 +17,14 @@ import ( // Validates the username matches a user in the database and the password // is correct. Returns the corresponding user -func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateLogin( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") - user, err := db.GetUserFromUsername(conn, formUsername) + user, err := db.GetUserFromUsername(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromUsername") } @@ -44,16 +49,27 @@ func checkRememberMe(r *http.Request) bool { // Handles an attempted login request. On success will return a HTMX redirect // and on fail will return the login form again, passing the error to the // template for user feedback -func HandleLoginRequest( +func LoginRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateLogin(conn, r) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) if err != nil { + logger.Warn().Err(err).Msg("Failed to set token cookies") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + user, err := validateLogin(ctx, tx, r) + if err != nil { + tx.Rollback() if err.Error() != "Username or password incorrect" { logger.Warn().Caller().Err(err).Msg("Login request failed") w.WriteHeader(http.StatusInternalServerError) @@ -66,10 +82,13 @@ func HandleLoginRequest( rememberMe := checkRememberMe(r) err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) if err != nil { + tx.Rollback() w.WriteHeader(http.StatusInternalServerError) logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return } + tx.Commit() pageFrom := cookies.CheckPageFrom(w, r) w.Header().Set("HX-Redirect", pageFrom) }, @@ -78,7 +97,7 @@ func HandleLoginRequest( // Handles a request to view the login page. Will attempt to set "pagefrom" // cookie so a successful login can redirect the user to the page they came -func HandleLoginPage(trustedHost string) http.Handler { +func LoginPage(trustedHost string) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { cookies.SetPageFrom(w, r, trustedHost) diff --git a/handler/logout.go b/handler/logout.go new file mode 100644 index 0000000..7a14572 --- /dev/null +++ b/handler/logout.go @@ -0,0 +1,113 @@ +package handler + +import ( + "context" + "net/http" + "strings" + "time" + + "projectreshoot/config" + "projectreshoot/cookies" + "projectreshoot/db" + "projectreshoot/jwt" + + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +func revokeAccess( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + atStr string, +) error { + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseAccessToken") + } + err = jwt.RevokeToken(ctx, tx, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +func revokeRefresh( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + rtStr string, +) error { + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseRefreshToken") + } + err = jwt.RevokeToken(ctx, tx, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +// Retrieve and revoke the user's tokens +func revokeTokens( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) error { + // get the tokens from the cookies + atStr, rtStr := cookies.GetTokenStrings(r) + // revoke the refresh token first as the access token expires quicker + // only matters if there is an error revoking the tokens + err := revokeRefresh(config, ctx, tx, rtStr) + if err != nil { + return errors.Wrap(err, "revokeRefresh") + } + err = revokeAccess(config, ctx, tx, atStr) + if err != nil { + return errors.Wrap(err, "revokeAccess") + } + return nil +} + +// Handle a logout request +func Logout( + config *config.Config, + logger *zerolog.Logger, + conn *db.SafeConn, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + err = revokeTokens(config, ctx, tx, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + w.Header().Set("HX-Redirect", "/login") + }, + ) +} diff --git a/handlers/page.go b/handler/page.go similarity index 94% rename from handlers/page.go rename to handler/page.go index 223ff78..36d38eb 100644 --- a/handlers/page.go +++ b/handler/page.go @@ -1,4 +1,4 @@ -package handlers +package handler import ( "net/http" diff --git a/handlers/profile.go b/handler/profile.go similarity index 75% rename from handlers/profile.go rename to handler/profile.go index 91f381f..51763ea 100644 --- a/handlers/profile.go +++ b/handler/profile.go @@ -1,11 +1,11 @@ -package handlers +package handler import ( "net/http" "projectreshoot/view/page" ) -func HandleProfilePage() http.Handler { +func ProfilePage() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { page.Profile().Render(r.Context(), w) diff --git a/handlers/reauthenticatate.go b/handler/reauthenticatate.go similarity index 71% rename from handlers/reauthenticatate.go rename to handler/reauthenticatate.go index 9e188d8..6bf7317 100644 --- a/handlers/reauthenticatate.go +++ b/handler/reauthenticatate.go @@ -1,12 +1,14 @@ -package handlers +package handler import ( - "database/sql" + "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" + "projectreshoot/db" "projectreshoot/jwt" "projectreshoot/view/component/form" @@ -17,16 +19,17 @@ import ( // Get the tokens from the request func getTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, r *http.Request, ) (*jwt.AccessToken, *jwt.RefreshToken, error) { // get the existing tokens from the cookies atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") } @@ -35,15 +38,16 @@ func getTokens( // Revoke the given token pair func revokeTokenPair( - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, aT *jwt.AccessToken, rT *jwt.RefreshToken, ) error { - err := jwt.RevokeToken(conn, aT) + err := jwt.RevokeToken(ctx, tx, aT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } - err = jwt.RevokeToken(conn, rT) + err = jwt.RevokeToken(ctx, tx, rT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } @@ -53,11 +57,12 @@ func revokeTokenPair( // Issue new tokens for the user, invalidating the old ones func refreshTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) error { - aT, rT, err := getTokens(config, conn, r) + aT, rT, err := getTokens(config, ctx, tx, r) if err != nil { return errors.Wrap(err, "getTokens") } @@ -71,7 +76,7 @@ func refreshTokens( if err != nil { return errors.Wrap(err, "cookies.SetTokenCookies") } - err = revokeTokenPair(conn, aT, rT) + err = revokeTokenPair(ctx, tx, aT, rT) if err != nil { return errors.Wrap(err, "revokeTokenPair") } @@ -94,25 +99,38 @@ func validatePassword( } // Handle request to reauthenticate (i.e. make token fresh again) -func HandleReauthenticate( +func Reauthenticate( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - err := validatePassword(r) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) if err != nil { + logger.Warn().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + err = validatePassword(r) + if err != nil { + tx.Rollback() w.WriteHeader(445) form.ConfirmPassword("Incorrect password").Render(r.Context(), w) return } - err = refreshTokens(config, conn, w, r) + err = refreshTokens(config, ctx, tx, w, r) if err != nil { + tx.Rollback() logger.Error().Err(err).Msg("Failed to refresh user tokens") w.WriteHeader(http.StatusInternalServerError) return } + tx.Commit() w.WriteHeader(http.StatusOK) }, ) diff --git a/handlers/register.go b/handler/register.go similarity index 72% rename from handlers/register.go rename to handler/register.go index 895ab67..2599ef5 100644 --- a/handlers/register.go +++ b/handler/register.go @@ -1,8 +1,9 @@ -package handlers +package handler import ( - "database/sql" + "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" @@ -14,11 +15,15 @@ import ( "github.com/rs/zerolog" ) -func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateRegistration( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") - unique, err := db.CheckUsernameUnique(conn, formUsername) + unique, err := db.CheckUsernameUnique(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.CheckUsernameUnique") } @@ -31,7 +36,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { if len(formPassword) > 72 { return nil, errors.New("Password exceeds maximum length of 72 bytes") } - user, err := db.CreateNewUser(conn, formUsername, formPassword) + user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword) if err != nil { return nil, errors.Wrap(err, "db.CreateNewUser") } @@ -39,16 +44,27 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { return user, nil } -func HandleRegisterRequest( +func RegisterRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateRegistration(conn, r) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) if err != nil { + logger.Warn().Err(err).Msg("Failed to set token cookies") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + user, err := validateRegistration(ctx, tx, r) + if err != nil { + tx.Rollback() if err.Error() != "Username is taken" && err.Error() != "Passwords do not match" && err.Error() != "Password exceeds maximum length of 72 bytes" { @@ -63,10 +79,12 @@ func HandleRegisterRequest( rememberMe := checkRememberMe(r) err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) if err != nil { + tx.Rollback() w.WriteHeader(http.StatusInternalServerError) logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return } - + tx.Commit() pageFrom := cookies.CheckPageFrom(w, r) w.Header().Set("HX-Redirect", pageFrom) }, @@ -75,7 +93,7 @@ func HandleRegisterRequest( // Handles a request to view the login page. Will attempt to set "pagefrom" // cookie so a successful login can redirect the user to the page they came -func HandleRegisterPage(trustedHost string) http.Handler { +func RegisterPage(trustedHost string) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { cookies.SetPageFrom(w, r, trustedHost) diff --git a/handlers/static.go b/handler/static.go similarity index 93% rename from handlers/static.go rename to handler/static.go index bc198dd..8b3c542 100644 --- a/handlers/static.go +++ b/handler/static.go @@ -1,4 +1,4 @@ -package handlers +package handler import ( "net/http" @@ -42,7 +42,7 @@ func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { // Handles requests for static files, without allowing access to the // directory viewer and returning 404 if an exact file is not found -func HandleStatic(staticFS *http.FileSystem) http.Handler { +func StaticFS(staticFS *http.FileSystem) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { nfs := justFilesFilesystem{*staticFS} diff --git a/handler/withtransaction.go b/handler/withtransaction.go new file mode 100644 index 0000000..37b709e --- /dev/null +++ b/handler/withtransaction.go @@ -0,0 +1,37 @@ +package handler + +import ( + "context" + "net/http" + "time" + + "projectreshoot/db" + + "github.com/rs/zerolog" +) + +func removeme( + w http.ResponseWriter, + r *http.Request, + logger *zerolog.Logger, + conn *db.SafeConn, + handler func( + ctx context.Context, + tx *db.SafeTX, + w http.ResponseWriter, + r *http.Request, + ), + onfail func(err error), +) { + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + onfail(err) + return + } + + handler(ctx, tx, w, r) +} diff --git a/handlers/logout.go b/handlers/logout.go deleted file mode 100644 index c8999a2..0000000 --- a/handlers/logout.go +++ /dev/null @@ -1,62 +0,0 @@ -package handlers - -import ( - "database/sql" - "net/http" - "projectreshoot/config" - "projectreshoot/cookies" - "projectreshoot/jwt" - - "github.com/pkg/errors" - "github.com/rs/zerolog" -) - -// Retrieve and revoke the user's tokens -func revokeTokens( - config *config.Config, - conn *sql.DB, - r *http.Request, -) error { - // get the tokens from the cookies - atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseAccessToken") - } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseRefreshToken") - } - // revoke the refresh token first as the access token expires quicker - // only matters if there is an error revoking the tokens - err = jwt.RevokeToken(conn, rT) - if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") - } - err = jwt.RevokeToken(conn, aT) - if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") - } - return nil -} - -// Handle a logout request -func HandleLogout( - config *config.Config, - logger *zerolog.Logger, - conn *sql.DB, -) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - err := revokeTokens(config, conn, r) - if err != nil { - logger.Error().Err(err).Msg("Error occured on user logout") - w.WriteHeader(http.StatusInternalServerError) - return - } - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - w.Header().Set("HX-Redirect", "/login") - }, - ) -} diff --git a/jwt/parse.go b/jwt/parse.go index 741cc59..0446e85 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -1,11 +1,12 @@ package jwt import ( - "database/sql" + "context" "fmt" "time" "projectreshoot/config" + "projectreshoot/db" "github.com/golang-jwt/jwt" "github.com/google/uuid" @@ -17,7 +18,8 @@ import ( // has the correct scope. func ParseAccessToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*AccessToken, error) { if tokenString == "" { @@ -74,7 +76,7 @@ func ParseAccessToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } @@ -89,7 +91,8 @@ func ParseAccessToken( // has the correct scope. func ParseRefreshToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*RefreshToken, error) { if tokenString == "" { @@ -141,7 +144,7 @@ func ParseRefreshToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } diff --git a/jwt/revoke.go b/jwt/revoke.go index ed2ec63..016f33e 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -1,32 +1,33 @@ package jwt import ( - "database/sql" + "context" + "projectreshoot/db" "github.com/pkg/errors" ) // Revoke a token by adding it to the database -func RevokeToken(conn *sql.DB, t Token) error { +func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error { jti := t.GetJTI() exp := t.GetEXP() query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` - _, err := conn.Exec(query, jti, exp) + _, err := tx.Exec(ctx, query, jti, exp) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Check if a token has been revoked. Returns true if not revoked. -func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { +func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) { jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` - rows, err := conn.Query(query, jti) - defer rows.Close() + rows, err := tx.Query(ctx, query, jti) if err != nil { - return false, errors.Wrap(err, "conn.Exec") + return false, errors.Wrap(err, "tx.Query") } + defer rows.Close() revoked := rows.Next() return !revoked, nil } diff --git a/jwt/tokens.go b/jwt/tokens.go index d76e952..ae5d97a 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,7 +1,7 @@ package jwt import ( - "database/sql" + "context" "projectreshoot/db" "github.com/google/uuid" @@ -12,7 +12,7 @@ type Token interface { GetJTI() uuid.UUID GetEXP() int64 GetScope() string - GetUser(conn *sql.DB) (*db.User, error) + GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) } // Access token @@ -38,15 +38,15 @@ type RefreshToken struct { Scope string // Should be "refresh" } -func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, a.SUB) +func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, a.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } return user, nil } -func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, r.SUB) +func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, r.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } diff --git a/main.go b/main.go index 84ed2fd..257de55 100644 --- a/main.go +++ b/main.go @@ -13,28 +13,32 @@ import ( "os/signal" "strconv" "sync" + "sync/atomic" + "syscall" "time" "projectreshoot/config" "projectreshoot/db" "projectreshoot/logging" "projectreshoot/server" + "projectreshoot/tests" "github.com/pkg/errors" + "github.com/rs/zerolog" ) //go:embed static/* var embeddedStatic embed.FS // Gets the static files -func getStaticFiles() (http.FileSystem, error) { +func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { if _, err := os.Stat("static"); err == nil { // Use actual filesystem in development - fmt.Println("Using filesystem for static files") + logger.Debug().Msg("Using filesystem for static files") return http.Dir("static"), nil } else { // Use embedded filesystem in production - fmt.Println("Using embedded static files") + logger.Debug().Msg("Using embedded static files") subFS, err := fs.Sub(embeddedStatic, "static") if err != nil { return nil, errors.Wrap(err, "fs.Sub") @@ -43,6 +47,44 @@ func getStaticFiles() (http.FileSystem, error) { } } +var maint uint32 // atomic: 1 if in maintenance mode + +// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode +func handleMaintSignals( + conn *db.SafeConn, + srv *http.Server, + logger *zerolog.Logger, + config *config.Config, +) { + logger.Debug().Msg("Starting signal listener") + ch := make(chan os.Signal, 1) + srv.RegisterOnShutdown(func() { + logger.Debug().Msg("Shutting down signal listener") + close(ch) + }) + go func() { + for sig := range ch { + switch sig { + case syscall.SIGUSR1: + if atomic.LoadUint32(&maint) != 1 { + atomic.StoreUint32(&maint, 1) + logger.Info().Msg("Signal received: Starting maintenance") + logger.Info().Msg("Attempting to acquire database lock") + conn.Pause(config.DBLockTimeout * time.Second) + } + case syscall.SIGUSR2: + if atomic.LoadUint32(&maint) != 0 { + logger.Info().Msg("Signal received: Maintenance over") + logger.Info().Msg("Releasing database lock") + conn.Resume() + atomic.StoreUint32(&maint, 0) + } + } + } + }() + signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2) +} + // Initializes and runs the server func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) @@ -53,6 +95,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "server.GetConfig") } + // Return the version of the database required + if args["dbver"] == "true" { + fmt.Fprintf(w, "Database version: %s\n", config.DBName) + return nil + } + var logfile *os.File = nil if config.LogOutput == "both" || config.LogOutput == "file" { logfile, err = logging.GetLogFile(config.LogDir) @@ -77,18 +125,36 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } - conn, err := db.ConnectToDatabase(config.DBName) - if err != nil { - return errors.Wrap(err, "db.ConnectToDatabase") + logger.Debug().Msg("Config loaded and logger started") + logger.Debug().Msg("Connecting to database") + var conn *db.SafeConn + if args["test"] == "true" { + logger.Debug().Msg("Server in test mode, using test database") + ver, err := strconv.ParseInt(config.DBName, 10, 0) + if err != nil { + return errors.Wrap(err, "strconv.ParseInt") + } + testconn, err := tests.SetupTestDB(ver) + if err != nil { + return errors.Wrap(err, "tests.SetupTestDB") + } + conn = db.MakeSafe(testconn, logger) + } else { + conn, err = db.ConnectToDatabase(config.DBName, logger) + if err != nil { + return errors.Wrap(err, "db.ConnectToDatabase") + } } defer conn.Close() - staticFS, err := getStaticFiles() + logger.Debug().Msg("Getting static files") + staticFS, err := getStaticFiles(logger) if err != nil { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, conn, &staticFS) + logger.Debug().Msg("Setting up HTTP server") + srv := server.NewServer(config, logger, conn, &staticFS, &maint) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -98,18 +164,25 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Runs function for testing in dev if --test flag true - if args["test"] == "true" { + if args["tester"] == "true" { + logger.Debug().Msg("Running tester function") test(config, logger, conn, httpServer) return nil } + // Setups a channel to listen for os.Signal + handleMaintSignals(conn, httpServer, logger, config) + + // Runs the http server + logger.Debug().Msg("Starting up the HTTP server") go func() { - fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) + logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests") if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err) + logger.Error().Err(err).Msg("Error listening and serving") } }() + // Handles graceful shutdown var wg sync.WaitGroup wg.Add(1) go func() { @@ -119,11 +192,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) defer cancel() if err := httpServer.Shutdown(shutdownCtx); err != nil { - fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err) + logger.Error().Err(err).Msg("Error shutting down server") } }() wg.Wait() - fmt.Fprintln(w, "Shutting down") + logger.Info().Msg("Shutting down") return nil } @@ -133,7 +206,9 @@ func main() { // Parse commandline args host := flag.String("host", "", "Override host to listen on") port := flag.String("port", "", "Override port to listen on") - test := flag.Bool("test", false, "Run test function instead of main program") + test := flag.Bool("test", false, "Run server in test mode") + tester := flag.Bool("tester", false, "Run tester function instead of main program") + dbver := flag.Bool("dbver", false, "Get the version of the database required") loglevel := flag.String("loglevel", "", "Set log level") logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)") flag.Parse() @@ -143,6 +218,8 @@ func main() { "host": *host, "port": *port, "test": strconv.FormatBool(*test), + "tester": strconv.FormatBool(*tester), + "dbver": strconv.FormatBool(*dbver), "loglevel": *loglevel, "logoutput": *logoutput, } diff --git a/main_test.go b/main_test.go index ccb1d48..bc71d15 100644 --- a/main_test.go +++ b/main_test.go @@ -1,26 +1,102 @@ package main import ( + "bytes" "context" "fmt" "net/http" "os" + "strings" + "syscall" "testing" "time" + + "github.com/stretchr/testify/require" ) func Test_main(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) - args := map[string]string{} - go run(ctx, os.Stdout, args) + args := map[string]string{"test": "true"} + var stdout bytes.Buffer + os.Setenv("SECRET_KEY", ".") + os.Setenv("HOST", "127.0.0.1") + os.Setenv("PORT", "3232") + runSrvErr := make(chan error) + go func() { + if err := run(ctx, &stdout, args); err != nil { + runSrvErr <- err + return + } + }() - // wait for the server to become available - waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz") + go func() { + err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz") + if err != nil { + runSrvErr <- err + return + } + runSrvErr <- nil + }() + select { + case err := <-runSrvErr: + if err != nil { + t.Fatalf("Error starting test server: %s", err) + return + } + t.Log("Test server started") + } - // do tests - fmt.Println("Tests starting") + t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) { + done := make(chan bool) + go func() { + expected := "Global database lock acquired" + for { + if strings.Contains(stdout.String(), expected) { + done <- true + return + } + time.Sleep(100 * time.Millisecond) + } + }() + + proc, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + proc.Signal(syscall.SIGUSR1) + + select { + case <-done: + t.Log("found") + case <-time.After(250 * time.Millisecond): + t.Errorf("Not found") + } + }) + + t.Run("SIGUSR2 releases database global lock", func(t *testing.T) { + done := make(chan bool) + go func() { + expected := "Global database lock released" + for { + if strings.Contains(stdout.String(), expected) { + done <- true + return + } + time.Sleep(100 * time.Millisecond) + } + }() + + proc, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + proc.Signal(syscall.SIGUSR2) + + select { + case <-done: + t.Log("found") + case <-time.After(250 * time.Millisecond): + t.Errorf("Not found") + } + }) } func waitForReady( @@ -44,6 +120,7 @@ func waitForReady( resp, err := client.Do(req) if err != nil { fmt.Printf("Error making request: %s\n", err.Error()) + time.Sleep(250 * time.Millisecond) continue } if resp.StatusCode == http.StatusOK { diff --git a/middleware/authentication.go b/middleware/authentication.go index 23f40f0..6c192ea 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -1,14 +1,16 @@ package middleware import ( - "database/sql" + "context" "net/http" + "sync/atomic" "time" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" "projectreshoot/db" + "projectreshoot/handler" "projectreshoot/jwt" "github.com/pkg/errors" @@ -18,14 +20,15 @@ import ( // Attempt to use a valid refresh token to generate a new token pair func refreshAuthTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, req *http.Request, ref *jwt.RefreshToken, ) (*db.User, error) { - user, err := ref.GetUser(conn) + user, err := ref.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "ref.GetUser") } rememberMe := map[string]bool{ @@ -39,7 +42,7 @@ func refreshAuthTokens( return nil, errors.Wrap(err, "cookies.SetTokenCookies") } // New tokens sent, revoke the used refresh token - err = jwt.RevokeToken(conn, ref) + err = jwt.RevokeToken(ctx, tx, ref) if err != nil { return nil, errors.Wrap(err, "jwt.RevokeToken") } @@ -50,22 +53,23 @@ func refreshAuthTokens( // Check the cookies for token strings and attempt to authenticate them func getAuthenticatedUser( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { // Access token invalid, attempt to parse refresh token - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, errors.Wrap(err, "jwt.ParseRefreshToken") } // Refresh token valid, attempt to get a new token pair - user, err := refreshAuthTokens(config, conn, w, r, rT) + user, err := refreshAuthTokens(config, ctx, tx, w, r, rT) if err != nil { return nil, errors.Wrap(err, "refreshAuthTokens") } @@ -77,9 +81,9 @@ func getAuthenticatedUser( return &authUser, nil } // Access token valid - user, err := aT.GetUser(conn) + user, err := aT.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "aT.GetUser") } authUser := contexts.AuthenticatedUser{ User: user, @@ -93,12 +97,34 @@ func getAuthenticatedUser( func Authentication( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, next http.Handler, + maint *uint32, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, err := getAuthenticatedUser(config, conn, w, r) + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return + } + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + if atomic.LoadUint32(maint) == 1 { + cancel() + } + + // Start the transaction + tx, err := conn.Begin(ctx) if err != nil { + // Failed to start transaction, skip auth + logger.Warn().Err(err). + Msg("Skipping Auth - unable to start a transaction") + handler.ErrorPage(http.StatusServiceUnavailable, w, r) + return + } + user, err := getAuthenticatedUser(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() // User auth failed, delete the cookies to avoid repeat requests cookies.DeleteCookie(w, "access", "/") cookies.DeleteCookie(w, "refresh", "/") @@ -106,9 +132,12 @@ func Authentication( Str("remote_addr", r.RemoteAddr). Err(err). Msg("Failed to authenticate user") + next.ServeHTTP(w, r) + return } - ctx := contexts.SetUser(r.Context(), user) - newReq := r.WithContext(ctx) + tx.Commit() + uctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(uctx) next.ServeHTTP(w, newReq) }) } diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index a5143c8..bb3dceb 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -5,9 +5,11 @@ import ( "net/http" "net/http/httptest" "strconv" + "sync/atomic" "testing" "projectreshoot/contexts" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -15,14 +17,15 @@ import ( ) func TestAuthenticationMiddleware(t *testing.T) { - // Basic setup cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + conn, err := tests.SetupTestDB(ver) + require.NoError(t, err) + sconn := db.MakeSafe(conn, logger) + defer sconn.Close() // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -36,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) { w.Write([]byte(strconv.Itoa(user.ID))) } }) - + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server - authHandler := Authentication(logger, cfg, conn, testHandler) + authHandler := Authentication(logger, cfg, sconn, testHandler, &maint) require.NoError(t, err) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/logging.go b/middleware/logging.go index abae797..ae6616b 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "projectreshoot/contexts" + "projectreshoot/handler" "time" "github.com/rs/zerolog" @@ -23,9 +24,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) { // Middleware to add logs to console with details of the request func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return + } start, err := contexts.GetStartTime(r.Context()) if err != nil { - // Handle failure here. internal server error maybe + handler.ErrorPage(http.StatusInternalServerError, w, r) return } wrapped := &wrappedWriter{ @@ -38,7 +44,7 @@ func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { Str("method", r.Method). Str("resource", r.URL.Path). Dur("time_elapsed", time.Since(start)). - Str("remote_addr", r.RemoteAddr). + Str("remote_addr", r.Header.Get("X-Forwarded-For")). Msg("Served") }) } diff --git a/middleware/pageprotection.go b/middleware/pageprotection.go index f5537b2..7f104c0 100644 --- a/middleware/pageprotection.go +++ b/middleware/pageprotection.go @@ -3,20 +3,15 @@ package middleware import ( "net/http" "projectreshoot/contexts" - "projectreshoot/view/page" + "projectreshoot/handler" ) // Checks if the user is set in the context and shows 401 page if not logged in -func RequiresLogin(next http.Handler) http.Handler { +func LoginReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) if user == nil { - w.WriteHeader(http.StatusUnauthorized) - page.Error( - "401", - "Unauthorized", - "Please login to view this page", - ).Render(r.Context(), w) + handler.ErrorPage(http.StatusUnauthorized, w, r) return } next.ServeHTTP(w, r) @@ -25,7 +20,7 @@ func RequiresLogin(next http.Handler) http.Handler { // Checks if the user is set in the context and redirects them to profile if // they are logged in -func RequiresLogout(next http.Handler) http.Handler { +func LogoutReq(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) if user != nil { diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index de79975..93414f8 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -3,8 +3,11 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" + "sync/atomic" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -12,23 +15,26 @@ import ( ) func TestPageLoginRequired(t *testing.T) { - // Basic setup cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + conn, err := tests.SetupTestDB(ver) + require.NoError(t, err) + sconn := db.MakeSafe(conn, logger) + defer sconn.Close() // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server - loginRequiredHandler := RequiresLogin(testHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + loginRequiredHandler := LoginReq(testHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/reauthentication.go b/middleware/reauthentication.go index 41fad65..b1fdb93 100644 --- a/middleware/reauthentication.go +++ b/middleware/reauthentication.go @@ -6,7 +6,7 @@ import ( "time" ) -func RequiresFresh( +func FreshReq( next http.Handler, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 63017cb..646a557 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -3,33 +3,39 @@ package middleware import ( "net/http" "net/http/httptest" + "strconv" + "sync/atomic" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestActionReauthRequired(t *testing.T) { - // Basic setup +func TestReauthRequired(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + ver, err := strconv.ParseInt(cfg.DBName, 10, 0) require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + conn, err := tests.SetupTestDB(ver) + require.NoError(t, err) + sconn := db.MakeSafe(conn, logger) + defer sconn.Close() // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server - reauthRequiredHandler := RequiresFresh(testHandler) - loginRequiredHandler := RequiresLogin(reauthRequiredHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + reauthRequiredHandler := FreshReq(testHandler) + loginRequiredHandler := LoginReq(reauthRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 0000000..dd8a99e --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log" + "os" + "strconv" + + "github.com/pressly/goose/v3" + + _ "modernc.org/sqlite" +) + +//go:embed migrations +var migrationsFS embed.FS + +func main() { + if len(os.Args) != 4 { + fmt.Println("Usage: prmigrate up-to|down-to ") + os.Exit(1) + } + + filePath := os.Args[1] + direction := os.Args[2] + versionStr := os.Args[3] + + version, err := strconv.Atoi(versionStr) + if err != nil { + log.Fatalf("Invalid version number: %v", err) + } + if _, err := os.Stat(filePath); os.IsNotExist(err) { + log.Fatalf("Database file does not exist: %v", filePath) + } + db, err := sql.Open("sqlite", filePath) + if err != nil { + log.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + migrations, err := fs.Sub(migrationsFS, "migrations") + if err != nil { + log.Fatalf("Failed to get migrations from embedded filesystem") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations) + if err != nil { + log.Fatalf("Failed to create migration provider: %v", err) + } + + ctx := context.Background() + + switch direction { + case "up-to": + _, err = provider.UpTo(ctx, int64(version)) + case "down-to": + _, err = provider.DownTo(ctx, int64(version)) + default: + log.Fatalf("Invalid direction: use 'up-to' or 'down-to'") + } + + if err != nil { + log.Fatalf("Migration failed: %v", err) + } + + fmt.Println("Migration successful!") +} diff --git a/schema.sql b/migrate/migrations/00001_init.sql similarity index 62% rename from schema.sql rename to migrate/migrations/00001_init.sql index 986d312..ccdf336 100644 --- a/schema.sql +++ b/migrate/migrations/00001_init.sql @@ -1,5 +1,6 @@ +-- +goose Up +-- +goose StatementBegin PRAGMA foreign_keys=ON; -BEGIN TRANSACTION; CREATE TABLE IF NOT EXISTS jwtblacklist ( jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'), exp INTEGER NOT NULL @@ -11,9 +12,16 @@ CREATE TABLE IF NOT EXISTS "users" ( created_at INTEGER DEFAULT (unixepoch()), bio TEXT DEFAULT "" ) STRICT; -CREATE TRIGGER cleanup_expired_tokens +CREATE TRIGGER IF NOT EXISTS cleanup_expired_tokens AFTER INSERT ON jwtblacklist BEGIN DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now'); END; -COMMIT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TRIGGER IF EXISTS cleanup_expired_tokens; +DROP TABLE IF EXISTS jwtblacklist; +DROP TABLE IF EXISTS users; +-- +goose StatementEnd diff --git a/server/routes.go b/server/routes.go index a92885f..8462e72 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,11 +1,11 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" - "projectreshoot/handlers" + "projectreshoot/db" + "projectreshoot/handler" "projectreshoot/middleware" "projectreshoot/view/page" @@ -17,85 +17,47 @@ func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, ) { + route := mux.Handle + loggedIn := middleware.LoginReq + loggedOut := middleware.LogoutReq + fresh := middleware.FreshReq + // Health check mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {}) // Static files - mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic(staticFS))) + route("GET /static/", http.StripPrefix("/static/", handler.StaticFS(staticFS))) // Index page and unhandled catchall (404) - mux.Handle("GET /", handlers.HandleRoot()) + route("GET /", handler.Root()) // Static content, unprotected pages - mux.Handle("GET /about", handlers.HandlePage(page.About())) + route("GET /about", handler.HandlePage(page.About())) // Login page and handlers - mux.Handle("GET /login", - middleware.RequiresLogout( - handlers.HandleLoginPage(config.TrustedHost), - )) - mux.Handle("POST /login", - middleware.RequiresLogout( - handlers.HandleLoginRequest( - config, - logger, - conn, - ))) + route("GET /login", loggedOut(handler.LoginPage(config.TrustedHost))) + route("POST /login", loggedOut(handler.LoginRequest(config, logger, conn))) // Register page and handlers - mux.Handle("GET /register", - middleware.RequiresLogout( - handlers.HandleRegisterPage(config.TrustedHost), - )) - mux.Handle("POST /register", - middleware.RequiresLogout( - handlers.HandleRegisterRequest( - config, - logger, - conn, - ))) + route("GET /register", loggedOut(handler.RegisterPage(config.TrustedHost))) + route("POST /register", loggedOut(handler.RegisterRequest(config, logger, conn))) // Logout - mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn)) + route("POST /logout", handler.Logout(config, logger, conn)) // Reauthentication request - mux.Handle("POST /reauthenticate", - middleware.RequiresLogin( - handlers.HandleReauthenticate(logger, config, conn), - )) + route("POST /reauthenticate", loggedIn(handler.Reauthenticate(logger, config, conn))) // Profile page - mux.Handle("GET /profile", - middleware.RequiresLogin( - handlers.HandleProfilePage(), - )) + route("GET /profile", loggedIn(handler.ProfilePage())) // Account page - mux.Handle("GET /account", - middleware.RequiresLogin( - handlers.HandleAccountPage(), - )) - mux.Handle("POST /account-select-page", - middleware.RequiresLogin( - handlers.HandleAccountSubpage(), - )) - mux.Handle("POST /change-username", - middleware.RequiresLogin( - middleware.RequiresFresh( - handlers.HandleChangeUsername(logger, conn), - ), - )) - mux.Handle("POST /change-bio", - middleware.RequiresLogin( - handlers.HandleChangeBio(logger, conn), - )) - mux.Handle("POST /change-password", - middleware.RequiresLogin( - middleware.RequiresFresh( - handlers.HandleChangePassword(logger, conn), - ), - )) + route("GET /account", loggedIn(handler.AccountPage())) + route("POST /account-select-page", loggedIn(handler.AccountSubpage())) + route("POST /change-username", loggedIn(fresh(handler.ChangeUsername(logger, conn)))) + route("POST /change-bio", loggedIn(handler.ChangeBio(logger, conn))) + route("POST /change-password", loggedIn(fresh(handler.ChangePassword(logger, conn)))) } diff --git a/server/server.go b/server/server.go index 648082f..8f189ad 100644 --- a/server/server.go +++ b/server/server.go @@ -1,10 +1,10 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" + "projectreshoot/db" "projectreshoot/middleware" "github.com/rs/zerolog" @@ -14,8 +14,9 @@ import ( func NewServer( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, + maint *uint32, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -29,7 +30,7 @@ func NewServer( // Add middleware here, must be added in reverse order of execution // i.e. First in list will get executed last during the request handling handler = middleware.Logging(logger, handler) - handler = middleware.Authentication(logger, config, conn, handler) + handler = middleware.Authentication(logger, config, conn, handler, maint) // Gzip handler = middleware.Gzip(handler, config.GZIP) diff --git a/setup-hooks.sh b/setup-hooks.sh new file mode 100644 index 0000000..fbf4509 --- /dev/null +++ b/setup-hooks.sh @@ -0,0 +1,14 @@ +#!/bin/sh +HOOKS_DIR=".githooks" +GIT_HOOKS_DIR=".git/hooks" + +mkdir -p "$GIT_HOOKS_DIR" + +for hook in "$HOOKS_DIR"/*; do + hook_name=$(basename "$hook") + cp "$hook" "$GIT_HOOKS_DIR/$hook_name" + chmod +x "$GIT_HOOKS_DIR/$hook_name" +done + +echo "Git hooks installed!" + diff --git a/tester.go b/tester.go index bfd8981..e474d82 100644 --- a/tester.go +++ b/tester.go @@ -1,10 +1,10 @@ package main import ( - "database/sql" "net/http" "projectreshoot/config" + "projectreshoot/db" "github.com/rs/zerolog" ) @@ -18,7 +18,7 @@ import ( func test( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, srv *http.Server, ) { } diff --git a/tests/database.go b/tests/database.go index 7c606c5..0010636 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,64 +1,83 @@ package tests import ( + "context" "database/sql" - "fmt" + "io/fs" "os" "path/filepath" "github.com/pkg/errors" + "github.com/pressly/goose/v3" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) -func findSQLFile(filename string) (string, error) { +func findMigrations() (*fs.FS, error) { + dir, err := os.Getwd() + if err != nil { + return nil, err + } + + for { + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations")) + return &migrationsdir, nil + } + + parent := filepath.Dir(dir) + if parent == dir { // Reached root + return nil, errors.New("Unable to locate migrations directory") + } + dir = parent + } +} + +func findTestData() (string, error) { dir, err := os.Getwd() if err != nil { return "", err } for { - if _, err := os.Stat(filepath.Join(dir, filename)); err == nil { - return filepath.Join(dir, filename), nil + if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil { + return filepath.Join(dir, "tests", "testdata.sql"), nil } parent := filepath.Dir(dir) if parent == dir { // Reached root - return "", errors.New(fmt.Sprintf("Unable to locate %s", filename)) + return "", errors.New("Unable to locate test data") } dir = parent } } -// SetupTestDB initializes a test SQLite database with mock data -// Make sure to call DeleteTestDB when finished to cleanup -func SetupTestDB() (*sql.DB, error) { - conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") +func SetupTestDB(version int64) (*sql.DB, error) { + conn, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, errors.Wrap(err, "sql.Open") } - // Setup the test database - schemaPath, err := findSQLFile("schema.sql") + + migrations, err := findMigrations() if err != nil { - return nil, errors.Wrap(err, "findSchema") + return nil, errors.Wrap(err, "findMigrations") + } + provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations) + if err != nil { + return nil, errors.Wrap(err, "goose.NewProvider") + } + ctx := context.Background() + if _, err := provider.UpTo(ctx, version); err != nil { + return nil, errors.Wrap(err, "provider.UpTo") } - sqlBytes, err := os.ReadFile(schemaPath) - if err != nil { - return nil, errors.Wrap(err, "os.ReadFile") - } - schemaSQL := string(sqlBytes) - - _, err = conn.Exec(schemaSQL) - if err != nil { - return nil, errors.Wrap(err, "conn.Exec") - } + // NOTE: ================================================== // Load the test data - dataPath, err := findSQLFile("testdata.sql") + dataPath, err := findTestData() if err != nil { return nil, errors.Wrap(err, "findSchema") } - sqlBytes, err = os.ReadFile(dataPath) + sqlBytes, err := os.ReadFile(dataPath) if err != nil { return nil, errors.Wrap(err, "os.ReadFile") } @@ -66,20 +85,7 @@ func SetupTestDB() (*sql.DB, error) { _, err = conn.Exec(dataSQL) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + return nil, errors.Wrap(err, "tx.Exec") } return conn, nil } - -// Deletes the test database from disk -func DeleteTestDB() error { - fileName := ".projectreshoot-test-database.db" - - // Attempt to remove the file - err := os.Remove(fileName) - if err != nil { - return errors.Wrap(err, "os.Remove") - } - - return nil -} diff --git a/tests/logger.go b/tests/logger.go index d8a0dd9..c3a9118 100644 --- a/tests/logger.go +++ b/tests/logger.go @@ -24,6 +24,10 @@ func NilLogger() *zerolog.Logger { // Return a logger that makes use of the T.Log method to enable debugging tests func DebugLogger(t *testing.T) *zerolog.Logger { - logger := zerolog.New(&TLogWriter{t: t}) + logger := zerolog.New(GetTLogWriter(t)) return &logger } + +func GetTLogWriter(t *testing.T) *TLogWriter { + return &TLogWriter{t: t} +} diff --git a/testdata.sql b/tests/testdata.sql similarity index 100% rename from testdata.sql rename to tests/testdata.sql diff --git a/view/component/form/loginform.templ b/view/component/form/loginform.templ index d6fa4a8..4407288 100644 --- a/view/component/form/loginform.templ +++ b/view/component/form/loginform.templ @@ -33,7 +33,7 @@ templ LoginForm(loginError string) {
diff --git a/view/component/form/registerform.templ b/view/component/form/registerform.templ index 7344b00..4775e2b 100644 --- a/view/component/form/registerform.templ +++ b/view/component/form/registerform.templ @@ -38,7 +38,7 @@ templ RegisterForm(registerError string) { >
diff --git a/view/component/popup/errorPopup.templ b/view/component/popup/error500Popup.templ similarity index 95% rename from view/component/popup/errorPopup.templ rename to view/component/popup/error500Popup.templ index f809230..45573e0 100644 --- a/view/component/popup/errorPopup.templ +++ b/view/component/popup/error500Popup.templ @@ -1,9 +1,9 @@ package popup -templ ErrorPopup() { +templ Error500Popup() {
+ +
+} diff --git a/view/layout/global.templ b/view/layout/global.templ index 17f58eb..2469f00 100644 --- a/view/layout/global.templ +++ b/view/layout/global.templ @@ -41,11 +41,12 @@ templ Global() {