From 049698100f04a618b7bd3fd87e17053238e72522 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 00:24:14 +1100 Subject: [PATCH 01/47] Added hook to prevent pushing to staging and master --- .githooks/pre-push | 11 +++++++++++ setup-hooks.sh | 14 ++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 .githooks/pre-push create mode 100644 setup-hooks.sh diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100644 index 0000000..316c82a --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,11 @@ +#!/bin/sh +protected_branches=("master" "staging") +current_branch=$(git rev-parse --abbrev-ref HEAD) + +for branch in "${protected_branches[@]}"; do + if [ "$current_branch" = "$branch" ]; then + echo "Direct pushes to '$branch' are not allowed. Use a pull request instead." + exit 1 + fi +done +exit 0 diff --git a/setup-hooks.sh b/setup-hooks.sh new file mode 100644 index 0000000..fbf4509 --- /dev/null +++ b/setup-hooks.sh @@ -0,0 +1,14 @@ +#!/bin/sh +HOOKS_DIR=".githooks" +GIT_HOOKS_DIR=".git/hooks" + +mkdir -p "$GIT_HOOKS_DIR" + +for hook in "$HOOKS_DIR"/*; do + hook_name=$(basename "$hook") + cp "$hook" "$GIT_HOOKS_DIR/$hook_name" + chmod +x "$GIT_HOOKS_DIR/$hook_name" +done + +echo "Git hooks installed!" + From 4e5a5cb33e570b4fbb302e02e60791da63f1bbef Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 00:27:51 +1100 Subject: [PATCH 02/47] Removed test message from index page --- view/page/index.templ | 1 - 1 file changed, 1 deletion(-) diff --git a/view/page/index.templ b/view/page/index.templ index 4fa1163..e4e78a6 100644 --- a/view/page/index.templ +++ b/view/page/index.templ @@ -8,7 +8,6 @@ templ Index() {
Project Reshoot
A better way to discover and rate films
-
If you're seeing this text, you're my favourite :)
} } From 1edff49425b61916e9d1be8d81c2fc3456de7e93 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 11:44:55 +1100 Subject: [PATCH 03/47] Made new safe connection and transaction structs to enable db locking --- db/connection.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/db/connection.go b/db/connection.go index d1415c5..1057dc9 100644 --- a/db/connection.go +++ b/db/connection.go @@ -1,14 +1,117 @@ package db import ( + "context" "database/sql" "fmt" + "sync" + "time" "github.com/pkg/errors" _ "github.com/mattn/go-sqlite3" ) +// Wraps the database handle, providing a mutex to safely manage transactions +type SafeConn struct { + db *sql.DB + mux sync.RWMutex +} + +// Extends sql.Tx for use with SafeConn +type SafeTX struct { + tx *sql.Tx + sc *SafeConn +} + +// Starts a new transaction, waiting up to 10 seconds if the database is locked +func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + lockAcquired := make(chan struct{}) + go func() { + conn.mux.RLock() + close(lockAcquired) + }() + + select { + case <-lockAcquired: + tx, err := conn.db.BeginTx(ctx, nil) + if err != nil { + conn.mux.RUnlock() + return nil, err + } + return &SafeTX{tx: tx, sc: conn}, nil + case <-ctx.Done(): + return nil, errors.New("Transaction time out due to database lock") + } +} + +// Query the database inside the transaction +func (stx *SafeTX) Query( + ctx context.Context, + query string, + args ...interface{}, +) (*sql.Rows, error) { + if stx.tx == nil { + return nil, errors.New("Cannot query without a transaction") + } + return stx.tx.QueryContext(ctx, query, args...) +} + +// Exec a statement on the database inside the transaction +func (stx *SafeTX) Exec( + ctx context.Context, + query string, + args ...interface{}, +) (sql.Result, error) { + if stx.tx == nil { + return nil, errors.New("Cannot exec without a transaction") + } + return stx.tx.ExecContext(ctx, query, args...) +} + +// Commit commits the transaction and releases the lock. +func (stx *SafeTX) Commit() error { + if stx.tx == nil { + return errors.New("Cannot commit without a transaction") + } + err := stx.tx.Commit() + stx.tx = nil + + stx.releaseLock() + return err +} + +// Rollback aborts the transaction. +func (stx *SafeTX) Rollback() error { + if stx.tx == nil { + return errors.New("Cannot rollback without a transaction") + } + err := stx.tx.Rollback() + stx.tx = nil + stx.releaseLock() + return err +} + +// Release the read lock for the transaction +func (stx *SafeTX) releaseLock() { + if stx.sc != nil { + stx.sc.mux.RUnlock() + } +} + +// Pause blocks new transactions for a backup. +func (conn *SafeConn) Pause() { + conn.mux.Lock() // Blocks all new transactions +} + +// Resume allows transactions to proceed. +func (conn *SafeConn) Resume() { + conn.mux.Unlock() +} + // Returns a database connection handle for the Turso DB func ConnectToDatabase(dbName string) (*sql.DB, error) { file := fmt.Sprintf("file:%s.db", dbName) From 417daf0028e5e3bb301ba349810c6b49ceba08b6 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 11:53:20 +1100 Subject: [PATCH 04/47] Added close method to SafeConn --- db/connection.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/db/connection.go b/db/connection.go index 1057dc9..421d4e6 100644 --- a/db/connection.go +++ b/db/connection.go @@ -112,13 +112,22 @@ func (conn *SafeConn) Resume() { conn.mux.Unlock() } +// Close the database connection +func (conn *SafeConn) Close() error { + conn.mux.Lock() + defer conn.mux.Unlock() + return conn.db.Close() +} + // Returns a database connection handle for the Turso DB -func ConnectToDatabase(dbName string) (*sql.DB, error) { +func ConnectToDatabase(dbName string) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite3", file) - if err != nil { return nil, errors.Wrap(err, "sql.Open") } - return db, nil + + conn := &SafeConn{db: db} + + return conn, nil } From 2c61cec55c7b24f68cbd1e059c7c362de42a479e Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 18:58:34 +1100 Subject: [PATCH 05/47] Update authentication, reauth, logout to use new transactions --- db/connection.go | 26 ++++++--- db/user_functions.go | 34 +++++++++-- handlers/logout.go | 91 +++++++++++++++++++++-------- handlers/reauthenticatate.go | 58 ++++++++++-------- handlers/withtransaction.go | 47 +++++++++++++++ jwt/parse.go | 13 +++-- jwt/revoke.go | 18 +++--- jwt/tokens.go | 12 ++-- main.go | 9 ++- middleware/authentication.go | 65 +++++++++++++-------- middleware/authentication_test.go | 2 +- middleware/logging.go | 7 ++- middleware/pageprotection_test.go | 2 +- middleware/reauthentication_test.go | 2 +- server/routes.go | 14 +++-- server/server.go | 5 +- tests/database.go | 22 +++++-- 17 files changed, 306 insertions(+), 121 deletions(-) create mode 100644 handlers/withtransaction.go diff --git a/db/connection.go b/db/connection.go index 421d4e6..bcd5895 100644 --- a/db/connection.go +++ b/db/connection.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "sync" - "time" "github.com/pkg/errors" @@ -18,17 +17,19 @@ type SafeConn struct { mux sync.RWMutex } +func MakeSafe(db *sql.DB) *SafeConn { + return &SafeConn{db: db} +} + // Extends sql.Tx for use with SafeConn type SafeTX struct { tx *sql.Tx sc *SafeConn } -// Starts a new transaction, waiting up to 10 seconds if the database is locked +// Starts a new transaction based on the current context. Will cancel if +// the context is closed/cancelled/done func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - lockAcquired := make(chan struct{}) go func() { conn.mux.RLock() @@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error { return conn.db.Close() } -// Returns a database connection handle for the Turso DB +// Returns a database connection handle for the DB +func OldConnectToDatabase(dbName string) (*sql.DB, error) { + file := fmt.Sprintf("file:%s.db", dbName) + db, err := sql.Open("sqlite3", file) + if err != nil { + return nil, errors.Wrap(err, "sql.Open") + } + + return db, nil +} + +// Returns a database connection handle for the DB func ConnectToDatabase(dbName string) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite3", file) @@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) { return nil, errors.Wrap(err, "sql.Open") } - conn := &SafeConn{db: db} + conn := MakeSafe(db) return conn, nil } diff --git a/db/user_functions.go b/db/user_functions.go index 3c3623e..9c1100a 100644 --- a/db/user_functions.go +++ b/db/user_functions.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" @@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error } // Fetches data from the users table using "WHERE column = 'value'" -func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { +func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { query := fmt.Sprintf( `SELECT id, @@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e return rows, nil } +// Fetches data from the users table using "WHERE column = 'value'" +func fetchUserData( + ctx context.Context, + tx *SafeTX, + column string, + value interface{}, +) (*sql.Rows, error) { + query := fmt.Sprintf( + `SELECT + id, + username, + password_hash, + created_at, + bio + FROM users + WHERE %s = ? COLLATE NOCASE LIMIT 1`, + column, + ) + rows, err := tx.Query(ctx, query, value) + if err != nil { + return nil, errors.Wrap(err, "tx.Query") + } + return rows, nil +} + // Scan the next row into the provided user pointer. Calls rows.Next() and // assumes only row in the result. Providing a rows object with more than 1 // row may result in undefined behaviour. @@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error { // Queries the database for a user matching the given username. // Query is case insensitive func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - rows, err := fetchUserData(conn, "username", username) + rows, err := oldfetchUserData(conn, "username", username) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { } // Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - rows, err := fetchUserData(conn, "id", id) +func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { + rows, err := fetchUserData(ctx, tx, "id", id) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } diff --git a/handlers/logout.go b/handlers/logout.go index c8999a2..da78925 100644 --- a/handlers/logout.go +++ b/handlers/logout.go @@ -1,41 +1,79 @@ package handlers import ( - "database/sql" + "context" "net/http" + "strings" + "projectreshoot/config" "projectreshoot/cookies" + "projectreshoot/db" "projectreshoot/jwt" "github.com/pkg/errors" "github.com/rs/zerolog" ) +func revokeAccess( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + atStr string, +) error { + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseAccessToken") + } + err = jwt.RevokeToken(ctx, tx, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +func revokeRefresh( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + rtStr string, +) error { + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseRefreshToken") + } + err = jwt.RevokeToken(ctx, tx, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + // Retrieve and revoke the user's tokens func revokeTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, r *http.Request, ) error { // get the tokens from the cookies atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseAccessToken") - } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseRefreshToken") - } // revoke the refresh token first as the access token expires quicker // only matters if there is an error revoking the tokens - err = jwt.RevokeToken(conn, rT) + err := revokeRefresh(config, ctx, tx, rtStr) if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") + return errors.Wrap(err, "revokeRefresh") } - err = jwt.RevokeToken(conn, aT) + err = revokeAccess(config, ctx, tx, atStr) if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") + return errors.Wrap(err, "revokeAccess") } return nil } @@ -44,19 +82,24 @@ func revokeTokens( func HandleLogout( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - err := revokeTokens(config, conn, r) - if err != nil { - logger.Error().Err(err).Msg("Error occured on user logout") - w.WriteHeader(http.StatusInternalServerError) - return - } - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - w.Header().Set("HX-Redirect", "/login") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + err := revokeTokens(config, ctx, tx, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + w.Header().Set("HX-Redirect", "/login") + }) }, ) } diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go index 9e188d8..87a0928 100644 --- a/handlers/reauthenticatate.go +++ b/handlers/reauthenticatate.go @@ -1,12 +1,13 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" + "projectreshoot/db" "projectreshoot/jwt" "projectreshoot/view/component/form" @@ -17,16 +18,17 @@ import ( // Get the tokens from the request func getTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, r *http.Request, ) (*jwt.AccessToken, *jwt.RefreshToken, error) { // get the existing tokens from the cookies atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") } @@ -35,15 +37,16 @@ func getTokens( // Revoke the given token pair func revokeTokenPair( - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, aT *jwt.AccessToken, rT *jwt.RefreshToken, ) error { - err := jwt.RevokeToken(conn, aT) + err := jwt.RevokeToken(ctx, tx, aT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } - err = jwt.RevokeToken(conn, rT) + err = jwt.RevokeToken(ctx, tx, rT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } @@ -53,11 +56,12 @@ func revokeTokenPair( // Issue new tokens for the user, invalidating the old ones func refreshTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) error { - aT, rT, err := getTokens(config, conn, r) + aT, rT, err := getTokens(config, ctx, tx, r) if err != nil { return errors.Wrap(err, "getTokens") } @@ -71,7 +75,7 @@ func refreshTokens( if err != nil { return errors.Wrap(err, "cookies.SetTokenCookies") } - err = revokeTokenPair(conn, aT, rT) + err = revokeTokenPair(ctx, tx, aT, rT) if err != nil { return errors.Wrap(err, "revokeTokenPair") } @@ -97,23 +101,29 @@ func validatePassword( func HandleReauthenticate( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - err := validatePassword(r) - if err != nil { - w.WriteHeader(445) - form.ConfirmPassword("Incorrect password").Render(r.Context(), w) - return - } - err = refreshTokens(config, conn, w, r) - if err != nil { - logger.Error().Err(err).Msg("Failed to refresh user tokens") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + err := validatePassword(r) + if err != nil { + tx.Rollback() + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.WriteHeader(http.StatusOK) + }) }, ) } diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go new file mode 100644 index 0000000..d81884f --- /dev/null +++ b/handlers/withtransaction.go @@ -0,0 +1,47 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "projectreshoot/db" + "projectreshoot/view/page" + + "github.com/rs/zerolog" +) + +// A helper function to create a transaction with a cancellable context. +func WithTransaction( + w http.ResponseWriter, + r *http.Request, + logger *zerolog.Logger, + conn *db.SafeConn, + handler func( + ctx context.Context, + tx *db.SafeTX, + w http.ResponseWriter, + r *http.Request, + ), +) { + // Create a cancellable context from the request context + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + tx.Rollback() + logger.Warn().Err(err).Msg("Request failed to start a transaction") + w.WriteHeader(http.StatusServiceUnavailable) + page.Error( + "503", + http.StatusText(503), + "This service is currently unavailable. It could be down for maintenance"). + Render(r.Context(), w) + return + } + + // Pass the context and transaction to the handler + handler(ctx, tx, w, r) +} diff --git a/jwt/parse.go b/jwt/parse.go index 741cc59..0446e85 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -1,11 +1,12 @@ package jwt import ( - "database/sql" + "context" "fmt" "time" "projectreshoot/config" + "projectreshoot/db" "github.com/golang-jwt/jwt" "github.com/google/uuid" @@ -17,7 +18,8 @@ import ( // has the correct scope. func ParseAccessToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*AccessToken, error) { if tokenString == "" { @@ -74,7 +76,7 @@ func ParseAccessToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } @@ -89,7 +91,8 @@ func ParseAccessToken( // has the correct scope. func ParseRefreshToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*RefreshToken, error) { if tokenString == "" { @@ -141,7 +144,7 @@ func ParseRefreshToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } diff --git a/jwt/revoke.go b/jwt/revoke.go index ed2ec63..e988a4e 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -1,32 +1,34 @@ package jwt import ( - "database/sql" + "context" + "projectreshoot/db" "github.com/pkg/errors" ) // Revoke a token by adding it to the database -func RevokeToken(conn *sql.DB, t Token) error { +func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error { jti := t.GetJTI() exp := t.GetEXP() query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` - _, err := conn.Exec(query, jti, exp) + _, err := tx.Exec(ctx, query, jti, exp) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Check if a token has been revoked. Returns true if not revoked. -func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { +func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) { jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` - rows, err := conn.Query(query, jti) - defer rows.Close() + rows, err := tx.Query(ctx, query, jti) if err != nil { - return false, errors.Wrap(err, "conn.Exec") + return false, errors.Wrap(err, "tx.Query") } + // NOTE: rows.Close() + defer rows.Close() revoked := rows.Next() return !revoked, nil } diff --git a/jwt/tokens.go b/jwt/tokens.go index d76e952..ae5d97a 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,7 +1,7 @@ package jwt import ( - "database/sql" + "context" "projectreshoot/db" "github.com/google/uuid" @@ -12,7 +12,7 @@ type Token interface { GetJTI() uuid.UUID GetEXP() int64 GetScope() string - GetUser(conn *sql.DB) (*db.User, error) + GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) } // Access token @@ -38,15 +38,15 @@ type RefreshToken struct { Scope string // Should be "refresh" } -func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, a.SUB) +func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, a.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } return user, nil } -func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, r.SUB) +func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, r.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } diff --git a/main.go b/main.go index 84ed2fd..6e3032e 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } + oldconn, err := db.OldConnectToDatabase(config.DBName) + if err != nil { + return errors.Wrap(err, "db.ConnectToDatabase") + } + defer oldconn.Close() conn, err := db.ConnectToDatabase(config.DBName) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") @@ -88,7 +93,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, conn, &staticFS) + srv := server.NewServer(config, logger, oldconn, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -99,7 +104,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Runs function for testing in dev if --test flag true if args["test"] == "true" { - test(config, logger, conn, httpServer) + test(config, logger, oldconn, httpServer) return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 23f40f0..3d31c94 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -1,7 +1,7 @@ package middleware import ( - "database/sql" + "context" "net/http" "time" @@ -9,6 +9,7 @@ import ( "projectreshoot/contexts" "projectreshoot/cookies" "projectreshoot/db" + "projectreshoot/handlers" "projectreshoot/jwt" "github.com/pkg/errors" @@ -18,14 +19,15 @@ import ( // Attempt to use a valid refresh token to generate a new token pair func refreshAuthTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, req *http.Request, ref *jwt.RefreshToken, ) (*db.User, error) { - user, err := ref.GetUser(conn) + user, err := ref.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "ref.GetUser") } rememberMe := map[string]bool{ @@ -39,7 +41,7 @@ func refreshAuthTokens( return nil, errors.Wrap(err, "cookies.SetTokenCookies") } // New tokens sent, revoke the used refresh token - err = jwt.RevokeToken(conn, ref) + err = jwt.RevokeToken(ctx, tx, ref) if err != nil { return nil, errors.Wrap(err, "jwt.RevokeToken") } @@ -50,22 +52,23 @@ func refreshAuthTokens( // Check the cookies for token strings and attempt to authenticate them func getAuthenticatedUser( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { // Access token invalid, attempt to parse refresh token - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, errors.Wrap(err, "jwt.ParseRefreshToken") } // Refresh token valid, attempt to get a new token pair - user, err := refreshAuthTokens(config, conn, w, r, rT) + user, err := refreshAuthTokens(config, ctx, tx, w, r, rT) if err != nil { return nil, errors.Wrap(err, "refreshAuthTokens") } @@ -77,9 +80,9 @@ func getAuthenticatedUser( return &authUser, nil } // Access token valid - user, err := aT.GetUser(conn) + user, err := aT.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "aT.GetUser") } authUser := contexts.AuthenticatedUser{ User: user, @@ -93,22 +96,36 @@ func getAuthenticatedUser( func Authentication( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, next http.Handler, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, err := getAuthenticatedUser(config, conn, w, r) - if err != nil { - // User auth failed, delete the cookies to avoid repeat requests - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - logger.Debug(). - Str("remote_addr", r.RemoteAddr). - Err(err). - Msg("Failed to authenticate user") + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return } - ctx := contexts.SetUser(r.Context(), user) - newReq := r.WithContext(ctx) - next.ServeHTTP(w, newReq) + handlers.WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + tx, err := conn.Begin(ctx) + user, err := getAuthenticatedUser(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + // User auth failed, delete the cookies to avoid repeat requests + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + logger.Debug(). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("Failed to authenticate user") + next.ServeHTTP(w, r) + return + } + tx.Commit() + uctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(uctx) + next.ServeHTTP(w, newReq) + }, + ) }) } diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index a5143c8..172bc03 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -19,7 +19,7 @@ func TestAuthenticationMiddleware(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/middleware/logging.go b/middleware/logging.go index abae797..de42258 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -23,9 +23,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) { // Middleware to add logs to console with details of the request func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return + } start, err := contexts.GetStartTime(r.Context()) if err != nil { - // Handle failure here. internal server error maybe + // TODO: Handle failure here. internal server error maybe return } wrapped := &wrappedWriter{ diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index de79975..03926b7 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -16,7 +16,7 @@ func TestPageLoginRequired(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 63017cb..0f20840 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -16,7 +16,7 @@ func TestActionReauthRequired(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/server/routes.go b/server/routes.go index a92885f..5a9d8c9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -5,6 +5,7 @@ import ( "net/http" "projectreshoot/config" + "projectreshoot/db" "projectreshoot/handlers" "projectreshoot/middleware" "projectreshoot/view/page" @@ -17,7 +18,8 @@ func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + oldconn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, ) { // Health check @@ -42,7 +44,7 @@ func addRoutes( handlers.HandleLoginRequest( config, logger, - conn, + oldconn, ))) // Register page and handlers @@ -55,7 +57,7 @@ func addRoutes( handlers.HandleRegisterRequest( config, logger, - conn, + oldconn, ))) // Logout @@ -85,17 +87,17 @@ func addRoutes( mux.Handle("POST /change-username", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangeUsername(logger, conn), + handlers.HandleChangeUsername(logger, oldconn), ), )) mux.Handle("POST /change-bio", middleware.RequiresLogin( - handlers.HandleChangeBio(logger, conn), + handlers.HandleChangeBio(logger, oldconn), )) mux.Handle("POST /change-password", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangePassword(logger, conn), + handlers.HandleChangePassword(logger, oldconn), ), )) } diff --git a/server/server.go b/server/server.go index 648082f..c64a8cb 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "net/http" "projectreshoot/config" + "projectreshoot/db" "projectreshoot/middleware" "github.com/rs/zerolog" @@ -14,7 +15,8 @@ import ( func NewServer( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + oldconn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, ) http.Handler { mux := http.NewServeMux() @@ -22,6 +24,7 @@ func NewServer( mux, logger, config, + oldconn, conn, staticFS, ) diff --git a/tests/database.go b/tests/database.go index 7c606c5..a7fd26b 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,10 +1,12 @@ package tests import ( + "context" "database/sql" "fmt" "os" "path/filepath" + "projectreshoot/db" "github.com/pkg/errors" @@ -32,11 +34,16 @@ func findSQLFile(filename string) (string, error) { // SetupTestDB initializes a test SQLite database with mock data // Make sure to call DeleteTestDB when finished to cleanup -func SetupTestDB() (*sql.DB, error) { - conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") +func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { + dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") if err != nil { return nil, errors.Wrap(err, "sql.Open") } + conn := db.MakeSafe(dbfile) + tx, err := conn.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "conn.Begin") + } // Setup the test database schemaPath, err := findSQLFile("schema.sql") if err != nil { @@ -49,9 +56,10 @@ func SetupTestDB() (*sql.DB, error) { } schemaSQL := string(sqlBytes) - _, err = conn.Exec(schemaSQL) + _, err = tx.Exec(ctx, schemaSQL) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + tx.Rollback() + return nil, errors.Wrap(err, "tx.Exec") } // Load the test data dataPath, err := findSQLFile("testdata.sql") @@ -64,10 +72,12 @@ func SetupTestDB() (*sql.DB, error) { } dataSQL := string(sqlBytes) - _, err = conn.Exec(dataSQL) + _, err = tx.Exec(ctx, dataSQL) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + tx.Rollback() + return nil, errors.Wrap(err, "tx.Exec") } + tx.Commit() return conn, nil } From 6faf168a6d1ee41e5fcacbbe01617e865888a26e Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 21:38:29 +1100 Subject: [PATCH 06/47] Updated all code to use new SafeConn and SafeTX --- config/config.go | 5 +---- db/connection.go | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/config/config.go b/config/config.go index 25654d6..f4e60d7 100644 --- a/config/config.go +++ b/config/config.go @@ -33,10 +33,7 @@ type Config struct { // Load the application configuration and get a pointer to the Config object func GetConfig(args map[string]string) (*Config, error) { - err := godotenv.Load(".env") - if err != nil { - fmt.Println(err) - } + godotenv.Load(".env") var ( host string port string diff --git a/db/connection.go b/db/connection.go index bcd5895..a7c15a6 100644 --- a/db/connection.go +++ b/db/connection.go @@ -8,7 +8,7 @@ import ( "github.com/pkg/errors" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) // Wraps the database handle, providing a mutex to safely manage transactions @@ -123,7 +123,7 @@ func (conn *SafeConn) Close() error { // Returns a database connection handle for the DB func OldConnectToDatabase(dbName string) (*sql.DB, error) { file := fmt.Sprintf("file:%s.db", dbName) - db, err := sql.Open("sqlite3", file) + db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } @@ -134,7 +134,7 @@ func OldConnectToDatabase(dbName string) (*sql.DB, error) { // Returns a database connection handle for the DB func ConnectToDatabase(dbName string) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) - db, err := sql.Open("sqlite3", file) + db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } From a8d112fdd52eccd94df31bc73102ad9b47fc9314 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 21:39:12 +1100 Subject: [PATCH 07/47] Updated all code to use SafeConn and SafeTX --- Makefile | 1 - db/user.go | 20 ++-- db/user_functions.go | 45 +++------ go.mod | 9 +- go.sum | 42 ++++++++- handlers/account.go | 137 ++++++++++++++++------------ handlers/login.go | 57 +++++++----- handlers/register.go | 65 +++++++------ main.go | 9 +- middleware/authentication.go | 1 - middleware/authentication_test.go | 14 +-- middleware/pageprotection_test.go | 14 +-- middleware/reauthentication_test.go | 16 ++-- server/routes.go | 12 +-- server/server.go | 3 - tester.go | 4 +- tests/database.go | 34 +------ 17 files changed, 265 insertions(+), 218 deletions(-) diff --git a/Makefile b/Makefile index 72a5804..66dd5b8 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,6 @@ tester: go run . --port 3232 --test --loglevel trace test: - rm -f **/.projectreshoot-test-database.db && \ go mod tidy && \ templ generate && \ go generate && \ diff --git a/db/user.go b/db/user.go index fe62349..a2daa26 100644 --- a/db/user.go +++ b/db/user.go @@ -1,7 +1,7 @@ package db import ( - "database/sql" + "context" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" @@ -16,16 +16,16 @@ type User struct { } // Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { +func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return errors.Wrap(err, "bcrypt.GenerateFromPassword") } user.Password_hash = string(hashedPassword) query := `UPDATE users SET password_hash = ? WHERE id = ?` - _, err = conn.Exec(query, user.Password_hash, user.ID) + _, err = tx.Exec(ctx, query, user.Password_hash, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } @@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error { } // Change the user's username -func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { +func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error { query := `UPDATE users SET username = ? WHERE id = ?` - _, err := conn.Exec(query, newUsername, user.ID) + _, err := tx.Exec(ctx, query, newUsername, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Change the user's bio -func (user *User) ChangeBio(conn *sql.DB, newBio string) error { +func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error { query := `UPDATE users SET bio = ? WHERE id = ?` - _, err := conn.Exec(query, newBio, user.ID) + _, err := tx.Exec(ctx, query, newBio, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } diff --git a/db/user_functions.go b/db/user_functions.go index 9c1100a..28d30ad 100644 --- a/db/user_functions.go +++ b/db/user_functions.go @@ -9,43 +9,28 @@ import ( ) // Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { +func CreateNewUser( + ctx context.Context, + tx *SafeTX, + username string, + password string, +) (*User, error) { query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) + _, err := tx.Exec(ctx, query, username) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + return nil, errors.Wrap(err, "tx.Exec") } - user, err := GetUserFromUsername(conn, username) + user, err := GetUserFromUsername(ctx, tx, username) if err != nil { return nil, errors.Wrap(err, "GetUserFromUsername") } - err = user.SetPassword(conn, password) + err = user.SetPassword(ctx, tx, password) if err != nil { return nil, errors.Wrap(err, "user.SetPassword") } return user, nil } -// Fetches data from the users table using "WHERE column = 'value'" -func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { - query := fmt.Sprintf( - `SELECT - id, - username, - password_hash, - created_at, - bio - FROM users - WHERE %s = ? COLLATE NOCASE LIMIT 1`, - column, - ) - rows, err := conn.Query(query, value) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - return rows, nil -} - // Fetches data from the users table using "WHERE column = 'value'" func fetchUserData( ctx context.Context, @@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error { // Queries the database for a user matching the given username. // Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - rows, err := oldfetchUserData(conn, "username", username) +func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) { + rows, err := fetchUserData(ctx, tx, "username", username) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { } // Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { +func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) { query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) + rows, err := tx.Query(ctx, query, username) if err != nil { - return false, errors.Wrap(err, "conn.Query") + return false, errors.Wrap(err, "tx.Query") } defer rows.Close() taken := rows.Next() diff --git a/go.mod b/go.mod index 97dfb00..98e11af 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,25 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mattn/go-sqlite3 v1.14.24 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 + modernc.org/sqlite v1.35.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect golang.org/x/sys v0.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.61.13 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.8.2 // indirect ) diff --git a/go.sum b/go.sum index 328d060..baa5432 100644 --- a/go.sum +++ b/go.sum @@ -3,11 +3,15 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -19,12 +23,14 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -32,12 +38,44 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= +modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo= +modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw= +modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8= +modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI= +modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw= +modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/handlers/account.go b/handlers/account.go index 365a283..dac0a3b 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/contexts" @@ -43,32 +43,39 @@ func HandleAccountSubpage() http.Handler { // Handles a request to change the users username func HandleChangeUsername( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newUsername := r.FormValue("username") - - unique, err := db.CheckUsernameUnique(conn, newUsername) - if err != nil { - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - if !unique { - account.ChangeUsername("Username is taken", newUsername). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.ChangeUsername(conn, newUsername) - if err != nil { - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newUsername := r.FormValue("username") + unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !unique { + tx.Rollback() + account.ChangeUsername("Username is taken", newUsername). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeUsername(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } @@ -76,30 +83,41 @@ func HandleChangeUsername( // Handles a request to change the users bio func HandleChangeBio( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newBio := r.FormValue("bio") - leng := len([]rune(newBio)) - if leng > 128 { - account.ChangeBio("Bio limited to 128 characters", newBio). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err := user.ChangeBio(conn, newBio) - if err != nil { - logger.Error().Err(err).Msg("Error updating bio") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newBio := r.FormValue("bio") + leng := len([]rune(newBio)) + if leng > 128 { + tx.Rollback() + account.ChangeBio("Bio limited to 128 characters", newBio). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err := user.ChangeBio(ctx, tx, newBio) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } -func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { +func validateChangePassword( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (string, error) { r.ParseForm() formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") @@ -115,23 +133,30 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { // Handles a request to change the users password func HandleChangePassword( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - newPass, err := validateChangePassword(conn, r) - if err != nil { - account.ChangePassword(err.Error()).Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.SetPassword(conn, newPass) - if err != nil { - logger.Error().Err(err).Msg("Error updating password") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + newPass, err := validateChangePassword(ctx, tx, r) + if err != nil { + tx.Rollback() + account.ChangePassword(err.Error()).Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.SetPassword(ctx, tx, newPass) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } diff --git a/handlers/login.go b/handlers/login.go index 7af3901..8788b01 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" @@ -16,10 +16,14 @@ import ( // Validates the username matches a user in the database and the password // is correct. Returns the corresponding user -func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateLogin( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") - user, err := db.GetUserFromUsername(conn, formUsername) + user, err := db.GetUserFromUsername(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromUsername") } @@ -47,31 +51,38 @@ func checkRememberMe(r *http.Request) bool { func HandleLoginRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateLogin(conn, r) - if err != nil { - if err.Error() != "Username or password incorrect" { - logger.Warn().Caller().Err(err).Msg("Login request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.LoginForm(err.Error()).Render(r.Context(), w) - } - return - } + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + user, err := validateLogin(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username or password incorrect" { + logger.Warn().Caller().Err(err).Msg("Login request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.LoginForm(err.Error()).Render(r.Context(), w) + } + return + } - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - } + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + }) }, ) } diff --git a/handlers/register.go b/handlers/register.go index 895ab67..605b02c 100644 --- a/handlers/register.go +++ b/handlers/register.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" @@ -14,11 +14,15 @@ import ( "github.com/rs/zerolog" ) -func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateRegistration( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") - unique, err := db.CheckUsernameUnique(conn, formUsername) + unique, err := db.CheckUsernameUnique(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.CheckUsernameUnique") } @@ -31,7 +35,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { if len(formPassword) > 72 { return nil, errors.New("Password exceeds maximum length of 72 bytes") } - user, err := db.CreateNewUser(conn, formUsername, formPassword) + user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword) if err != nil { return nil, errors.Wrap(err, "db.CreateNewUser") } @@ -42,33 +46,40 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { func HandleRegisterRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateRegistration(conn, r) - if err != nil { - if err.Error() != "Username is taken" && - err.Error() != "Passwords do not match" && - err.Error() != "Password exceeds maximum length of 72 bytes" { - logger.Warn().Caller().Err(err).Msg("Registration request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.RegisterForm(err.Error()).Render(r.Context(), w) - } - return - } + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + user, err := validateRegistration(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username is taken" && + err.Error() != "Passwords do not match" && + err.Error() != "Password exceeds maximum length of 72 bytes" { + logger.Warn().Caller().Err(err).Msg("Registration request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.RegisterForm(err.Error()).Render(r.Context(), w) + } + return + } - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - } - - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + }, + ) }, ) } diff --git a/main.go b/main.go index 6e3032e..84ed2fd 100644 --- a/main.go +++ b/main.go @@ -77,11 +77,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } - oldconn, err := db.OldConnectToDatabase(config.DBName) - if err != nil { - return errors.Wrap(err, "db.ConnectToDatabase") - } - defer oldconn.Close() conn, err := db.ConnectToDatabase(config.DBName) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") @@ -93,7 +88,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, oldconn, conn, &staticFS) + srv := server.NewServer(config, logger, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -104,7 +99,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Runs function for testing in dev if --test flag true if args["test"] == "true" { - test(config, logger, oldconn, httpServer) + test(config, logger, conn, httpServer) return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 3d31c94..e444da8 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -107,7 +107,6 @@ func Authentication( } handlers.WithTransaction(w, r, logger, conn, func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - tx, err := conn.Begin(ctx) user, err := getAuthenticatedUser(config, ctx, tx, w, r) if err != nil { tx.Rollback() diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 172bc03..95583af 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -8,6 +8,7 @@ import ( "testing" "projectreshoot/contexts" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -16,13 +17,14 @@ import ( func TestAuthenticationMiddleware(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,7 +40,7 @@ func TestAuthenticationMiddleware(t *testing.T) { }) // Add the middleware and create the server - authHandler := Authentication(logger, cfg, conn, testHandler) + authHandler := Authentication(logger, cfg, sconn, testHandler) require.NoError(t, err) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index 03926b7..80b0a15 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -13,13 +14,14 @@ import ( func TestPageLoginRequired(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -28,7 +30,7 @@ func TestPageLoginRequired(t *testing.T) { // Add the middleware and create the server loginRequiredHandler := RequiresLogin(testHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 0f20840..595e4e7 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -5,21 +5,23 @@ import ( "net/http/httptest" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestActionReauthRequired(t *testing.T) { +func TestReauthRequired(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -29,7 +31,7 @@ func TestActionReauthRequired(t *testing.T) { // Add the middleware and create the server reauthRequiredHandler := RequiresFresh(testHandler) loginRequiredHandler := RequiresLogin(reauthRequiredHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/server/routes.go b/server/routes.go index 5a9d8c9..eac1e17 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,7 +1,6 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" @@ -18,7 +17,6 @@ func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, config *config.Config, - oldconn *sql.DB, conn *db.SafeConn, staticFS *http.FileSystem, ) { @@ -44,7 +42,7 @@ func addRoutes( handlers.HandleLoginRequest( config, logger, - oldconn, + conn, ))) // Register page and handlers @@ -57,7 +55,7 @@ func addRoutes( handlers.HandleRegisterRequest( config, logger, - oldconn, + conn, ))) // Logout @@ -87,17 +85,17 @@ func addRoutes( mux.Handle("POST /change-username", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangeUsername(logger, oldconn), + handlers.HandleChangeUsername(logger, conn), ), )) mux.Handle("POST /change-bio", middleware.RequiresLogin( - handlers.HandleChangeBio(logger, oldconn), + handlers.HandleChangeBio(logger, conn), )) mux.Handle("POST /change-password", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangePassword(logger, oldconn), + handlers.HandleChangePassword(logger, conn), ), )) } diff --git a/server/server.go b/server/server.go index c64a8cb..fa75b0a 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,6 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" @@ -15,7 +14,6 @@ import ( func NewServer( config *config.Config, logger *zerolog.Logger, - oldconn *sql.DB, conn *db.SafeConn, staticFS *http.FileSystem, ) http.Handler { @@ -24,7 +22,6 @@ func NewServer( mux, logger, config, - oldconn, conn, staticFS, ) diff --git a/tester.go b/tester.go index bfd8981..e474d82 100644 --- a/tester.go +++ b/tester.go @@ -1,10 +1,10 @@ package main import ( - "database/sql" "net/http" "projectreshoot/config" + "projectreshoot/db" "github.com/rs/zerolog" ) @@ -18,7 +18,7 @@ import ( func test( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, srv *http.Server, ) { } diff --git a/tests/database.go b/tests/database.go index a7fd26b..549db2b 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,16 +1,14 @@ package tests import ( - "context" "database/sql" "fmt" "os" "path/filepath" - "projectreshoot/db" "github.com/pkg/errors" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) func findSQLFile(filename string) (string, error) { @@ -33,17 +31,11 @@ func findSQLFile(filename string) (string, error) { } // SetupTestDB initializes a test SQLite database with mock data -// Make sure to call DeleteTestDB when finished to cleanup -func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { - dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") +func SetupTestDB() (*sql.DB, error) { + conn, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := db.MakeSafe(dbfile) - tx, err := conn.Begin(ctx) - if err != nil { - return nil, errors.Wrap(err, "conn.Begin") - } // Setup the test database schemaPath, err := findSQLFile("schema.sql") if err != nil { @@ -56,9 +48,8 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { } schemaSQL := string(sqlBytes) - _, err = tx.Exec(ctx, schemaSQL) + _, err = conn.Exec(schemaSQL) if err != nil { - tx.Rollback() return nil, errors.Wrap(err, "tx.Exec") } // Load the test data @@ -72,24 +63,9 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { } dataSQL := string(sqlBytes) - _, err = tx.Exec(ctx, dataSQL) + _, err = conn.Exec(dataSQL) if err != nil { - tx.Rollback() return nil, errors.Wrap(err, "tx.Exec") } - tx.Commit() return conn, nil } - -// Deletes the test database from disk -func DeleteTestDB() error { - fileName := ".projectreshoot-test-database.db" - - // Attempt to remove the file - err := os.Remove(fileName) - if err != nil { - return errors.Wrap(err, "os.Remove") - } - - return nil -} From 19f26d62a3a0f49478964f12aa40f56397961248 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 21:51:09 +1100 Subject: [PATCH 08/47] Removed OldConnectToDatabase --- db/connection.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/db/connection.go b/db/connection.go index a7c15a6..17311d3 100644 --- a/db/connection.go +++ b/db/connection.go @@ -120,17 +120,6 @@ func (conn *SafeConn) Close() error { return conn.db.Close() } -// Returns a database connection handle for the DB -func OldConnectToDatabase(dbName string) (*sql.DB, error) { - file := fmt.Sprintf("file:%s.db", dbName) - db, err := sql.Open("sqlite", file) - if err != nil { - return nil, errors.Wrap(err, "sql.Open") - } - - return db, nil -} - // Returns a database connection handle for the DB func ConnectToDatabase(dbName string) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) From 556c93fc49edddc9ba05f0b64594d2ae4d0bb468 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 23:14:13 +1100 Subject: [PATCH 09/47] Changed to using atomics as mutex was causing deadlock --- db/connection.go | 78 +++++++++++++++++++++++++++---------- handlers/withtransaction.go | 3 +- main.go | 37 ++++++++++++++++++ 3 files changed, 96 insertions(+), 22 deletions(-) diff --git a/db/connection.go b/db/connection.go index 17311d3..432fe1e 100644 --- a/db/connection.go +++ b/db/connection.go @@ -4,17 +4,17 @@ import ( "context" "database/sql" "fmt" - "sync" + "sync/atomic" "github.com/pkg/errors" _ "modernc.org/sqlite" ) -// Wraps the database handle, providing a mutex to safely manage transactions type SafeConn struct { - db *sql.DB - mux sync.RWMutex + db *sql.DB + readLockCount int32 + globalLockStatus int32 } func MakeSafe(db *sql.DB) *SafeConn { @@ -27,24 +27,63 @@ type SafeTX struct { sc *SafeConn } +func (conn *SafeConn) acquireGlobalLock() bool { + if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 { + return false + } + atomic.StoreInt32(&conn.globalLockStatus, 1) + fmt.Println("=====================GLOBAL LOCK ACQUIRED==================") + return true +} + +func (conn *SafeConn) releaseGlobalLock() { + atomic.StoreInt32(&conn.globalLockStatus, 0) + fmt.Println("=====================GLOBAL LOCK RELEASED==================") +} + +func (conn *SafeConn) acquireReadLock() bool { + if atomic.LoadInt32(&conn.globalLockStatus) == 1 { + return false + } + atomic.AddInt32(&conn.readLockCount, 1) + fmt.Println("=====================READ LOCK ACQUIRED==================") + return true +} + +func (conn *SafeConn) releaseReadLock() { + atomic.AddInt32(&conn.readLockCount, -1) + fmt.Println("=====================READ LOCK RELEASED==================") +} + // Starts a new transaction based on the current context. Will cancel if // the context is closed/cancelled/done func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { lockAcquired := make(chan struct{}) + lockCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { - conn.mux.RLock() - close(lockAcquired) + select { + case <-lockCtx.Done(): + fmt.Println("=====================READ LOCK ABANDONED==================") + return + default: + if conn.acquireReadLock() { + close(lockAcquired) // Lock acquired + } + } }() select { case <-lockAcquired: tx, err := conn.db.BeginTx(ctx, nil) if err != nil { - conn.mux.RUnlock() + conn.releaseReadLock() return nil, err } return &SafeTX{tx: tx, sc: conn}, nil case <-ctx.Done(): + cancel() return nil, errors.New("Transaction time out due to database lock") } } @@ -81,7 +120,7 @@ func (stx *SafeTX) Commit() error { err := stx.tx.Commit() stx.tx = nil - stx.releaseLock() + stx.sc.releaseReadLock() return err } @@ -92,31 +131,30 @@ func (stx *SafeTX) Rollback() error { } err := stx.tx.Rollback() stx.tx = nil - stx.releaseLock() + stx.sc.releaseReadLock() return err } -// Release the read lock for the transaction -func (stx *SafeTX) releaseLock() { - if stx.sc != nil { - stx.sc.mux.RUnlock() - } -} - // Pause blocks new transactions for a backup. func (conn *SafeConn) Pause() { - conn.mux.Lock() // Blocks all new transactions + for !conn.acquireGlobalLock() { + // TODO: add a timeout? + } + fmt.Println("Global database lock acquired") } // Resume allows transactions to proceed. func (conn *SafeConn) Resume() { - conn.mux.Unlock() + conn.releaseGlobalLock() + fmt.Println("Global database lock released") } // Close the database connection func (conn *SafeConn) Close() error { - conn.mux.Lock() - defer conn.mux.Unlock() + fmt.Println("=====================DB LOCKING FOR SHUTDOWN==================") + conn.acquireGlobalLock() + defer conn.releaseGlobalLock() + fmt.Println("=====================DB LOCKED FOR SHUTDOWN==================") return conn.db.Close() } diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go index d81884f..db963fc 100644 --- a/handlers/withtransaction.go +++ b/handlers/withtransaction.go @@ -25,13 +25,12 @@ func WithTransaction( ), ) { // Create a cancellable context from the request context - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) defer cancel() // Start the transaction tx, err := conn.Begin(ctx) if err != nil { - tx.Rollback() logger.Warn().Err(err).Msg("Request failed to start a transaction") w.WriteHeader(http.StatusServiceUnavailable) page.Error( diff --git a/main.go b/main.go index 84ed2fd..03406c6 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,8 @@ import ( "os/signal" "strconv" "sync" + "sync/atomic" + "syscall" "time" "projectreshoot/config" @@ -43,6 +45,36 @@ func getStaticFiles() (http.FileSystem, error) { } } +var maint uint32 // atomic: 1 if in maintenance mode + +func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { + ch := make(chan os.Signal, 1) + srv.RegisterOnShutdown(func() { + close(ch) + }) + go func() { + for sig := range ch { + switch sig { + case syscall.SIGUSR1: + if atomic.LoadUint32(&maint) != 1 { + atomic.StoreUint32(&maint, 1) + fmt.Println("Signal received: Starting maintenance") + fmt.Println("Attempting to acquire database lock") + conn.Pause() + } + case syscall.SIGUSR2: + if atomic.LoadUint32(&maint) != 0 { + fmt.Println("Signal received: Maintenance over") + fmt.Println("Releasing database lock") + conn.Resume() + atomic.StoreUint32(&maint, 0) + } + } + } + }() + signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2) +} + // Initializes and runs the server func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) @@ -103,6 +135,10 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return nil } + // Setups a channel to listen for os.Signal + handleMaintSignals(conn, httpServer) + + // Runs the http server go func() { fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -110,6 +146,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } }() + // Handles graceful shutdown var wg sync.WaitGroup wg.Add(1) go func() { From 9ea58b096112754eba800bbf7fbf561030aaae74 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 23:31:09 +1100 Subject: [PATCH 10/47] Made auth middleware turn skip timeout if maintenance mode is on --- handlers/withtransaction.go | 5 +--- main.go | 2 +- middleware/authentication.go | 56 ++++++++++++++++++++++-------------- server/server.go | 3 +- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go index db963fc..a097719 100644 --- a/handlers/withtransaction.go +++ b/handlers/withtransaction.go @@ -11,7 +11,6 @@ import ( "github.com/rs/zerolog" ) -// A helper function to create a transaction with a cancellable context. func WithTransaction( w http.ResponseWriter, r *http.Request, @@ -24,8 +23,7 @@ func WithTransaction( r *http.Request, ), ) { - // Create a cancellable context from the request context - ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) defer cancel() // Start the transaction @@ -41,6 +39,5 @@ func WithTransaction( return } - // Pass the context and transaction to the handler handler(ctx, tx, w, r) } diff --git a/main.go b/main.go index 03406c6..64b8af2 100644 --- a/main.go +++ b/main.go @@ -120,7 +120,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, conn, &staticFS) + srv := server.NewServer(config, logger, conn, &staticFS, &maint) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, diff --git a/middleware/authentication.go b/middleware/authentication.go index e444da8..354a936 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -3,13 +3,13 @@ package middleware import ( "context" "net/http" + "sync/atomic" "time" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" "projectreshoot/db" - "projectreshoot/handlers" "projectreshoot/jwt" "github.com/pkg/errors" @@ -98,6 +98,7 @@ func Authentication( config *config.Config, conn *db.SafeConn, next http.Handler, + maint *uint32, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/static/css/output.css" || @@ -105,26 +106,37 @@ func Authentication( next.ServeHTTP(w, r) return } - handlers.WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - user, err := getAuthenticatedUser(config, ctx, tx, w, r) - if err != nil { - tx.Rollback() - // User auth failed, delete the cookies to avoid repeat requests - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - logger.Debug(). - Str("remote_addr", r.RemoteAddr). - Err(err). - Msg("Failed to authenticate user") - next.ServeHTTP(w, r) - return - } - tx.Commit() - uctx := contexts.SetUser(r.Context(), user) - newReq := r.WithContext(uctx) - next.ServeHTTP(w, newReq) - }, - ) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + if atomic.LoadUint32(maint) == 1 { + cancel() + } + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + // Failed to start transaction, warn the user they cant login right now + logger.Warn().Err(err).Msg("Request failed to start a transaction") + w.WriteHeader(http.StatusServiceUnavailable) + next.ServeHTTP(w, r) + return + } + user, err := getAuthenticatedUser(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + // User auth failed, delete the cookies to avoid repeat requests + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + logger.Debug(). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("Failed to authenticate user") + next.ServeHTTP(w, r) + return + } + tx.Commit() + uctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(uctx) + next.ServeHTTP(w, newReq) }) } diff --git a/server/server.go b/server/server.go index fa75b0a..8f189ad 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ func NewServer( logger *zerolog.Logger, conn *db.SafeConn, staticFS *http.FileSystem, + maint *uint32, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -29,7 +30,7 @@ func NewServer( // Add middleware here, must be added in reverse order of execution // i.e. First in list will get executed last during the request handling handler = middleware.Logging(logger, handler) - handler = middleware.Authentication(logger, config, conn, handler) + handler = middleware.Authentication(logger, config, conn, handler, maint) // Gzip handler = middleware.Gzip(handler, config.GZIP) From b4b57c14cb6c0e31d3387ad4b0ea69f751d5817e Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 17 Feb 2025 23:32:48 +1100 Subject: [PATCH 11/47] Removed debugging stdout prints --- db/connection.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/db/connection.go b/db/connection.go index 432fe1e..199417f 100644 --- a/db/connection.go +++ b/db/connection.go @@ -32,13 +32,11 @@ func (conn *SafeConn) acquireGlobalLock() bool { return false } atomic.StoreInt32(&conn.globalLockStatus, 1) - fmt.Println("=====================GLOBAL LOCK ACQUIRED==================") return true } func (conn *SafeConn) releaseGlobalLock() { atomic.StoreInt32(&conn.globalLockStatus, 0) - fmt.Println("=====================GLOBAL LOCK RELEASED==================") } func (conn *SafeConn) acquireReadLock() bool { @@ -46,13 +44,11 @@ func (conn *SafeConn) acquireReadLock() bool { return false } atomic.AddInt32(&conn.readLockCount, 1) - fmt.Println("=====================READ LOCK ACQUIRED==================") return true } func (conn *SafeConn) releaseReadLock() { atomic.AddInt32(&conn.readLockCount, -1) - fmt.Println("=====================READ LOCK RELEASED==================") } // Starts a new transaction based on the current context. Will cancel if @@ -65,7 +61,6 @@ func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { go func() { select { case <-lockCtx.Done(): - fmt.Println("=====================READ LOCK ABANDONED==================") return default: if conn.acquireReadLock() { @@ -151,10 +146,8 @@ func (conn *SafeConn) Resume() { // Close the database connection func (conn *SafeConn) Close() error { - fmt.Println("=====================DB LOCKING FOR SHUTDOWN==================") conn.acquireGlobalLock() defer conn.releaseGlobalLock() - fmt.Println("=====================DB LOCKED FOR SHUTDOWN==================") return conn.db.Close() } From c2de8d254a8e9222e5a8f1d207cb76624ef424cf Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 10:02:42 +1100 Subject: [PATCH 12/47] Added proper debug logging to safeconn methods --- .gitignore | 1 + db/connection.go | 48 ++++++++++++++++++++--------- go.mod | 2 ++ go.sum | 12 ++++++-- main.go | 33 +++++++++++--------- middleware/authentication.go | 4 +-- middleware/authentication_test.go | 10 +++--- middleware/pageprotection_test.go | 9 ++++-- middleware/reauthentication_test.go | 9 ++++-- 9 files changed, 83 insertions(+), 45 deletions(-) diff --git a/.gitignore b/.gitignore index 37839ec..5b22f92 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ query.sql *.db .logs/ +server.log tmp/ projectreshoot static/css/output.css diff --git a/db/connection.go b/db/connection.go index 199417f..2a88505 100644 --- a/db/connection.go +++ b/db/connection.go @@ -4,21 +4,23 @@ import ( "context" "database/sql" "fmt" - "sync/atomic" + "os" "github.com/pkg/errors" + "github.com/rs/zerolog" _ "modernc.org/sqlite" ) type SafeConn struct { db *sql.DB - readLockCount int32 - globalLockStatus int32 + readLockCount uint32 + globalLockStatus uint32 + logger *zerolog.Logger } -func MakeSafe(db *sql.DB) *SafeConn { - return &SafeConn{db: db} +func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { + return &SafeConn{db: db, logger: logger} } // Extends sql.Tx for use with SafeConn @@ -28,27 +30,35 @@ type SafeTX struct { } func (conn *SafeConn) acquireGlobalLock() bool { - if atomic.LoadInt32(&conn.readLockCount) > 0 || atomic.LoadInt32(&conn.globalLockStatus) == 1 { + if conn.readLockCount > 0 || conn.globalLockStatus == 1 { return false } - atomic.StoreInt32(&conn.globalLockStatus, 1) + conn.globalLockStatus = 1 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock acquired") return true } func (conn *SafeConn) releaseGlobalLock() { - atomic.StoreInt32(&conn.globalLockStatus, 0) + conn.globalLockStatus = 0 + conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus). + Msg("Global lock released") } func (conn *SafeConn) acquireReadLock() bool { - if atomic.LoadInt32(&conn.globalLockStatus) == 1 { + if conn.globalLockStatus == 1 { return false } - atomic.AddInt32(&conn.readLockCount, 1) + conn.readLockCount += 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock acquired") return true } func (conn *SafeConn) releaseReadLock() { - atomic.AddInt32(&conn.readLockCount, -1) + conn.readLockCount -= 1 + conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount). + Msg("Read lock released") } // Starts a new transaction based on the current context. Will cancel if @@ -134,32 +144,40 @@ func (stx *SafeTX) Rollback() error { func (conn *SafeConn) Pause() { for !conn.acquireGlobalLock() { // TODO: add a timeout? + // TODO: failed to acquire lock: print info with readLockCount + // every second, or update it dynamically } - fmt.Println("Global database lock acquired") + // force logger to log to Stdout + log := conn.logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Global database lock acquired") } // Resume allows transactions to proceed. func (conn *SafeConn) Resume() { conn.releaseGlobalLock() - fmt.Println("Global database lock released") + // force logger to log to Stdout + log := conn.logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Global database lock released") } // Close the database connection func (conn *SafeConn) Close() error { + conn.logger.Debug().Msg("Acquiring global lock for connection close") conn.acquireGlobalLock() defer conn.releaseGlobalLock() + conn.logger.Debug().Msg("Closing database connection") return conn.db.Close() } // Returns a database connection handle for the DB -func ConnectToDatabase(dbName string) (*SafeConn, error) { +func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite", file) if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := MakeSafe(db) + conn := MakeSafe(db, logger) return conn, nil } diff --git a/go.mod b/go.mod index 98e11af..0852a78 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,9 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.61.13 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index baa5432..9206c36 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,11 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -42,8 +47,8 @@ golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtD golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -51,8 +56,9 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= diff --git a/main.go b/main.go index 64b8af2..baa7350 100644 --- a/main.go +++ b/main.go @@ -23,20 +23,21 @@ import ( "projectreshoot/server" "github.com/pkg/errors" + "github.com/rs/zerolog" ) //go:embed static/* var embeddedStatic embed.FS // Gets the static files -func getStaticFiles() (http.FileSystem, error) { +func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { if _, err := os.Stat("static"); err == nil { // Use actual filesystem in development - fmt.Println("Using filesystem for static files") + logger.Debug().Msg("Using filesystem for static files") return http.Dir("static"), nil } else { // Use embedded filesystem in production - fmt.Println("Using embedded static files") + logger.Debug().Msg("Using embedded static files") subFS, err := fs.Sub(embeddedStatic, "static") if err != nil { return nil, errors.Wrap(err, "fs.Sub") @@ -47,7 +48,7 @@ func getStaticFiles() (http.FileSystem, error) { var maint uint32 // atomic: 1 if in maintenance mode -func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { +func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) { ch := make(chan os.Signal, 1) srv.RegisterOnShutdown(func() { close(ch) @@ -58,14 +59,16 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server) { case syscall.SIGUSR1: if atomic.LoadUint32(&maint) != 1 { atomic.StoreUint32(&maint, 1) - fmt.Println("Signal received: Starting maintenance") - fmt.Println("Attempting to acquire database lock") + log := logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Signal received: Starting maintenance") + log.Info().Msg("Attempting to acquire database lock") conn.Pause() } case syscall.SIGUSR2: if atomic.LoadUint32(&maint) != 0 { - fmt.Println("Signal received: Maintenance over") - fmt.Println("Releasing database lock") + log := logger.With().Logger().Output(os.Stdout) + log.Info().Msg("Signal received: Maintenance over") + log.Info().Msg("Releasing database lock") conn.Resume() atomic.StoreUint32(&maint, 0) } @@ -109,13 +112,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } - conn, err := db.ConnectToDatabase(config.DBName) + conn, err := db.ConnectToDatabase(config.DBName, logger) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") } defer conn.Close() - staticFS, err := getStaticFiles() + staticFS, err := getStaticFiles(logger) if err != nil { return errors.Wrap(err, "getStaticFiles") } @@ -136,13 +139,13 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Setups a channel to listen for os.Signal - handleMaintSignals(conn, httpServer) + handleMaintSignals(conn, httpServer, logger) // Runs the http server go func() { - fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) + logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests") if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err) + logger.Error().Err(err).Msg("Error listening and serving") } }() @@ -156,11 +159,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second) defer cancel() if err := httpServer.Shutdown(shutdownCtx); err != nil { - fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err) + logger.Error().Err(err).Msg("Error shutting down server") } }() wg.Wait() - fmt.Fprintln(w, "Shutting down") + logger.Info().Msg("Shutting down") return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 354a936..92837e5 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -115,8 +115,8 @@ func Authentication( // Start the transaction tx, err := conn.Begin(ctx) if err != nil { - // Failed to start transaction, warn the user they cant login right now - logger.Warn().Err(err).Msg("Request failed to start a transaction") + // Failed to start transaction, send 503 code to client + logger.Warn().Err(err).Msg("Skipping Auth - unable to start a transaction") w.WriteHeader(http.StatusServiceUnavailable) next.ServeHTTP(w, r) return diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 95583af..6ce807c 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "sync/atomic" "testing" "projectreshoot/contexts" @@ -16,15 +17,15 @@ import ( ) func TestAuthenticationMiddleware(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) { w.Write([]byte(strconv.Itoa(user.ID))) } }) - + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server - authHandler := Authentication(logger, cfg, sconn, testHandler) + authHandler := Authentication(logger, cfg, sconn, testHandler, &maint) require.NoError(t, err) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index 80b0a15..c6efcba 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "sync/atomic" "testing" "projectreshoot/db" @@ -13,24 +14,26 @@ import ( ) func TestPageLoginRequired(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server loginRequiredHandler := RequiresLogin(testHandler) - authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 595e4e7..a1f2083 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "sync/atomic" "testing" "projectreshoot/db" @@ -13,25 +14,27 @@ import ( ) func TestReauthRequired(t *testing.T) { + logger := tests.NilLogger() // Basic setup conn, err := tests.SetupTestDB() require.NoError(t, err) - sconn := db.MakeSafe(conn) + sconn := db.MakeSafe(conn, logger) defer sconn.Close() cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + var maint uint32 + atomic.StoreUint32(&maint, 0) // Add the middleware and create the server reauthRequiredHandler := RequiresFresh(testHandler) loginRequiredHandler := RequiresLogin(reauthRequiredHandler) - authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint) server := httptest.NewServer(authHandler) defer server.Close() From 38d47cdf638b49efd96722370b19f041ccbee398 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 10:37:57 +1100 Subject: [PATCH 13/47] Added check for attempting to acquire global lock --- db/connection.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/db/connection.go b/db/connection.go index 2a88505..9bb39bd 100644 --- a/db/connection.go +++ b/db/connection.go @@ -13,10 +13,11 @@ import ( ) type SafeConn struct { - db *sql.DB - readLockCount uint32 - globalLockStatus uint32 - logger *zerolog.Logger + db *sql.DB + readLockCount uint32 + globalLockStatus uint32 + globalLockRequested uint32 + logger *zerolog.Logger } func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn { @@ -46,7 +47,7 @@ func (conn *SafeConn) releaseGlobalLock() { } func (conn *SafeConn) acquireReadLock() bool { - if conn.globalLockStatus == 1 { + if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 { return false } conn.readLockCount += 1 @@ -142,6 +143,7 @@ func (stx *SafeTX) Rollback() error { // Pause blocks new transactions for a backup. func (conn *SafeConn) Pause() { + conn.globalLockRequested = 1 for !conn.acquireGlobalLock() { // TODO: add a timeout? // TODO: failed to acquire lock: print info with readLockCount @@ -150,6 +152,7 @@ func (conn *SafeConn) Pause() { // force logger to log to Stdout log := conn.logger.With().Logger().Output(os.Stdout) log.Info().Msg("Global database lock acquired") + conn.globalLockRequested = 0 } // Resume allows transactions to proceed. From c0bbe687f90b735ccd2ba5b22e5762f1fef630de Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 19:57:30 +1100 Subject: [PATCH 14/47] Added timeout for acquiring global lock --- config/config.go | 2 ++ db/connection.go | 43 ++++++++++++++++++++++++++----------------- handlers/login.go | 1 + main.go | 11 ++++++++--- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/config/config.go b/config/config.go index f4e60d7..66a17a8 100644 --- a/config/config.go +++ b/config/config.go @@ -22,6 +22,7 @@ type Config struct { WriteTimeout time.Duration // Timeout for writing requests in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds DBName string // Filename of the db (doesnt include file extension) + DBLockTimeout time.Duration // Timeout for acquiring database lock SecretKey string // Secret key for signing tokens AccessTokenExpiry int64 // Access token expiry in minutes RefreshTokenExpiry int64 // Refresh token expiry in minutes @@ -87,6 +88,7 @@ func GetConfig(args map[string]string) (*Config, error) { WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), DBName: GetEnvDefault("DB_NAME", "projectreshoot"), + DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60), SecretKey: os.Getenv("SECRET_KEY"), AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day diff --git a/db/connection.go b/db/connection.go index 9bb39bd..7a19106 100644 --- a/db/connection.go +++ b/db/connection.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "os" + "time" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -75,7 +76,7 @@ func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { return default: if conn.acquireReadLock() { - close(lockAcquired) // Lock acquired + close(lockAcquired) } } }() @@ -118,7 +119,7 @@ func (stx *SafeTX) Exec( return stx.tx.ExecContext(ctx, query, args...) } -// Commit commits the transaction and releases the lock. +// Commit the current transaction and release the read lock func (stx *SafeTX) Commit() error { if stx.tx == nil { return errors.New("Cannot commit without a transaction") @@ -130,7 +131,7 @@ func (stx *SafeTX) Commit() error { return err } -// Rollback aborts the transaction. +// Abort the current transaction, releasing the read lock func (stx *SafeTX) Rollback() error { if stx.tx == nil { return errors.New("Cannot rollback without a transaction") @@ -141,21 +142,31 @@ func (stx *SafeTX) Rollback() error { return err } -// Pause blocks new transactions for a backup. -func (conn *SafeConn) Pause() { - conn.globalLockRequested = 1 - for !conn.acquireGlobalLock() { - // TODO: add a timeout? - // TODO: failed to acquire lock: print info with readLockCount - // every second, or update it dynamically - } - // force logger to log to Stdout +// Acquire a global lock, preventing all transactions +func (conn *SafeConn) Pause(timeoutAfter time.Duration) { + // force logger to log to Stdout so the signalling process can check log := conn.logger.With().Logger().Output(os.Stdout) - log.Info().Msg("Global database lock acquired") - conn.globalLockRequested = 0 + log.Info().Msg("Attempting to acquire global database lock") + conn.globalLockRequested = 1 + defer func() { conn.globalLockRequested = 0 }() + timeout := time.After(timeoutAfter) + attempt := 0 + for { + if conn.acquireGlobalLock() { + log.Info().Msg("Global database lock acquired") + return + } + select { + case <-timeout: + log.Info().Msg("Timeout: Global database lock abandoned") + return + case <-time.After(100 * time.Millisecond): + attempt++ + } + } } -// Resume allows transactions to proceed. +// Release the global lock func (conn *SafeConn) Resume() { conn.releaseGlobalLock() // force logger to log to Stdout @@ -179,8 +190,6 @@ func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := MakeSafe(db, logger) - return conn, nil } diff --git a/handlers/login.go b/handlers/login.go index 8788b01..85f6c1c 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" diff --git a/main.go b/main.go index baa7350..70c1dc9 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,12 @@ func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { var maint uint32 // atomic: 1 if in maintenance mode -func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger) { +func handleMaintSignals( + conn *db.SafeConn, + srv *http.Server, + logger *zerolog.Logger, + config *config.Config, +) { ch := make(chan os.Signal, 1) srv.RegisterOnShutdown(func() { close(ch) @@ -62,7 +67,7 @@ func handleMaintSignals(conn *db.SafeConn, srv *http.Server, logger *zerolog.Log log := logger.With().Logger().Output(os.Stdout) log.Info().Msg("Signal received: Starting maintenance") log.Info().Msg("Attempting to acquire database lock") - conn.Pause() + conn.Pause(config.DBLockTimeout * time.Second) } case syscall.SIGUSR2: if atomic.LoadUint32(&maint) != 0 { @@ -139,7 +144,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Setups a channel to listen for os.Signal - handleMaintSignals(conn, httpServer, logger) + handleMaintSignals(conn, httpServer, logger, config) // Runs the http server go func() { From 0ece08726d840f5d7fdca5538da8a89b0142f8dd Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 20:05:41 +1100 Subject: [PATCH 15/47] Added debug logging to run function --- handlers/login.go | 1 - main.go | 9 +++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/handlers/login.go b/handlers/login.go index 85f6c1c..8788b01 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "projectreshoot/config" "projectreshoot/cookies" diff --git a/main.go b/main.go index 70c1dc9..9fdb6c0 100644 --- a/main.go +++ b/main.go @@ -48,14 +48,17 @@ func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) { var maint uint32 // atomic: 1 if in maintenance mode +// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode func handleMaintSignals( conn *db.SafeConn, srv *http.Server, logger *zerolog.Logger, config *config.Config, ) { + logger.Debug().Msg("Starting signal listener") ch := make(chan os.Signal, 1) srv.RegisterOnShutdown(func() { + logger.Debug().Msg("Shutting down signal listener") close(ch) }) go func() { @@ -117,17 +120,21 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } + logger.Debug().Msg("Config loaded and logger started") + logger.Debug().Msg("Connecting to database") conn, err := db.ConnectToDatabase(config.DBName, logger) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") } defer conn.Close() + logger.Debug().Msg("Getting static files") staticFS, err := getStaticFiles(logger) if err != nil { return errors.Wrap(err, "getStaticFiles") } + logger.Debug().Msg("Setting up HTTP server") srv := server.NewServer(config, logger, conn, &staticFS, &maint) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), @@ -139,6 +146,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Runs function for testing in dev if --test flag true if args["test"] == "true" { + logger.Debug().Msg("Running tester function") test(config, logger, conn, httpServer) return nil } @@ -147,6 +155,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { handleMaintSignals(conn, httpServer, logger, config) // Runs the http server + logger.Debug().Msg("Starting up the HTTP server") go func() { logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests") if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { From 1f7a9e08e6ad3468d22bb12061b957f592fd2322 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 18 Feb 2025 20:51:22 +1100 Subject: [PATCH 16/47] Added error 503 popup --- handlers/account.go | 163 ++++++++++-------- handlers/login.go | 60 ++++--- handlers/logout.go | 36 ++-- handlers/reauthenticatate.go | 46 +++-- handlers/register.go | 65 +++---- handlers/withtransaction.go | 12 +- .../{errorPopup.templ => error500Popup.templ} | 6 +- view/component/popup/error503Popup.templ | 63 +++++++ view/layout/global.templ | 17 +- 9 files changed, 291 insertions(+), 177 deletions(-) rename view/component/popup/{errorPopup.templ => error500Popup.templ} (95%) create mode 100644 view/component/popup/error503Popup.templ diff --git a/handlers/account.go b/handlers/account.go index dac0a3b..f9d60fc 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/contexts" "projectreshoot/cookies" @@ -47,35 +48,41 @@ func HandleChangeUsername( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newUsername := r.FormValue("username") - unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - if !unique { - tx.Rollback() - account.ChangeUsername("Username is taken", newUsername). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.ChangeUsername(ctx, tx, newUsername) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - tx.Commit() - w.Header().Set("HX-Refresh", "true") - }, - ) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + newUsername := r.FormValue("username") + unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !unique { + tx.Rollback() + account.ChangeUsername("Username is taken", newUsername). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeUsername(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") }, ) } @@ -87,29 +94,35 @@ func HandleChangeBio( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newBio := r.FormValue("bio") - leng := len([]rune(newBio)) - if leng > 128 { - tx.Rollback() - account.ChangeBio("Bio limited to 128 characters", newBio). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err := user.ChangeBio(ctx, tx, newBio) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Error updating bio") - w.WriteHeader(http.StatusInternalServerError) - return - } - tx.Commit() - w.Header().Set("HX-Refresh", "true") - }, - ) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + newBio := r.FormValue("bio") + leng := len([]rune(newBio)) + if leng > 128 { + tx.Rollback() + account.ChangeBio("Bio limited to 128 characters", newBio). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeBio(ctx, tx, newBio) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") }, ) } @@ -137,26 +150,32 @@ func HandleChangePassword( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - newPass, err := validateChangePassword(ctx, tx, r) - if err != nil { - tx.Rollback() - account.ChangePassword(err.Error()).Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.SetPassword(ctx, tx, newPass) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Error updating password") - w.WriteHeader(http.StatusInternalServerError) - return - } - tx.Commit() - w.Header().Set("HX-Refresh", "true") - }, - ) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + newPass, err := validateChangePassword(ctx, tx, r) + if err != nil { + tx.Rollback() + account.ChangePassword(err.Error()).Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.SetPassword(ctx, tx, newPass) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") }, ) } diff --git a/handlers/login.go b/handlers/login.go index 8788b01..b646e7e 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" @@ -55,34 +56,41 @@ func HandleLoginRequest( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateLogin(ctx, tx, r) - if err != nil { - tx.Rollback() - if err.Error() != "Username or password incorrect" { - logger.Warn().Caller().Err(err).Msg("Login request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.LoginForm(err.Error()).Render(r.Context(), w) - } - return - } + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - tx.Rollback() - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - return - } + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Failed to set token cookies") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + user, err := validateLogin(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username or password incorrect" { + logger.Warn().Caller().Err(err).Msg("Login request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.LoginForm(err.Error()).Render(r.Context(), w) + } + return + } - tx.Commit() - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) - }) + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } + + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) }, ) } diff --git a/handlers/logout.go b/handlers/logout.go index da78925..b93db43 100644 --- a/handlers/logout.go +++ b/handlers/logout.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "strings" + "time" "projectreshoot/config" "projectreshoot/cookies" @@ -86,20 +87,27 @@ func HandleLogout( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - err := revokeTokens(config, ctx, tx, r) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Error occured on user logout") - w.WriteHeader(http.StatusInternalServerError) - return - } - tx.Commit() - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - w.Header().Set("HX-Redirect", "/login") - }) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + err = revokeTokens(config, ctx, tx, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + w.Header().Set("HX-Redirect", "/login") }, ) } diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go index 87a0928..6adb3f2 100644 --- a/handlers/reauthenticatate.go +++ b/handlers/reauthenticatate.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/contexts" @@ -105,25 +106,32 @@ func HandleReauthenticate( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - err := validatePassword(r) - if err != nil { - tx.Rollback() - w.WriteHeader(445) - form.ConfirmPassword("Incorrect password").Render(r.Context(), w) - return - } - err = refreshTokens(config, ctx, tx, w, r) - if err != nil { - tx.Rollback() - logger.Error().Err(err).Msg("Failed to refresh user tokens") - w.WriteHeader(http.StatusInternalServerError) - return - } - tx.Commit() - w.WriteHeader(http.StatusOK) - }) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + err = validatePassword(r) + if err != nil { + tx.Rollback() + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.WriteHeader(http.StatusOK) }, ) } diff --git a/handlers/register.go b/handlers/register.go index 605b02c..dc4e856 100644 --- a/handlers/register.go +++ b/handlers/register.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "time" "projectreshoot/config" "projectreshoot/cookies" @@ -50,36 +51,42 @@ func HandleRegisterRequest( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - WithTransaction(w, r, logger, conn, - func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateRegistration(ctx, tx, r) - if err != nil { - tx.Rollback() - if err.Error() != "Username is taken" && - err.Error() != "Passwords do not match" && - err.Error() != "Password exceeds maximum length of 72 bytes" { - logger.Warn().Caller().Err(err).Msg("Registration request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.RegisterForm(err.Error()).Render(r.Context(), w) - } - return - } + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - tx.Rollback() - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - return - } - tx.Commit() - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) - }, - ) + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + logger.Warn().Err(err).Msg("Failed to set token cookies") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + r.ParseForm() + user, err := validateRegistration(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username is taken" && + err.Error() != "Passwords do not match" && + err.Error() != "Password exceeds maximum length of 72 bytes" { + logger.Warn().Caller().Err(err).Msg("Registration request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.RegisterForm(err.Error()).Render(r.Context(), w) + } + return + } + + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) }, ) } diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go index a097719..c5e23d0 100644 --- a/handlers/withtransaction.go +++ b/handlers/withtransaction.go @@ -6,12 +6,11 @@ import ( "time" "projectreshoot/db" - "projectreshoot/view/page" "github.com/rs/zerolog" ) -func WithTransaction( +func removeme( w http.ResponseWriter, r *http.Request, logger *zerolog.Logger, @@ -22,6 +21,7 @@ func WithTransaction( w http.ResponseWriter, r *http.Request, ), + onfail func(err error), ) { ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) defer cancel() @@ -29,13 +29,7 @@ func WithTransaction( // Start the transaction tx, err := conn.Begin(ctx) if err != nil { - logger.Warn().Err(err).Msg("Request failed to start a transaction") - w.WriteHeader(http.StatusServiceUnavailable) - page.Error( - "503", - http.StatusText(503), - "This service is currently unavailable. It could be down for maintenance"). - Render(r.Context(), w) + onfail(err) return } diff --git a/view/component/popup/errorPopup.templ b/view/component/popup/error500Popup.templ similarity index 95% rename from view/component/popup/errorPopup.templ rename to view/component/popup/error500Popup.templ index f809230..45573e0 100644 --- a/view/component/popup/errorPopup.templ +++ b/view/component/popup/error500Popup.templ @@ -1,9 +1,9 @@ package popup -templ ErrorPopup() { +templ Error500Popup() {
+ +
+} diff --git a/view/layout/global.templ b/view/layout/global.templ index 17f58eb..2469f00 100644 --- a/view/layout/global.templ +++ b/view/layout/global.templ @@ -41,11 +41,12 @@ templ Global() {