diff --git a/db/connection.go b/db/connection.go index 421d4e6..bcd5895 100644 --- a/db/connection.go +++ b/db/connection.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "sync" - "time" "github.com/pkg/errors" @@ -18,17 +17,19 @@ type SafeConn struct { mux sync.RWMutex } +func MakeSafe(db *sql.DB) *SafeConn { + return &SafeConn{db: db} +} + // Extends sql.Tx for use with SafeConn type SafeTX struct { tx *sql.Tx sc *SafeConn } -// Starts a new transaction, waiting up to 10 seconds if the database is locked +// Starts a new transaction based on the current context. Will cancel if +// the context is closed/cancelled/done func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - lockAcquired := make(chan struct{}) go func() { conn.mux.RLock() @@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error { return conn.db.Close() } -// Returns a database connection handle for the Turso DB +// Returns a database connection handle for the DB +func OldConnectToDatabase(dbName string) (*sql.DB, error) { + file := fmt.Sprintf("file:%s.db", dbName) + db, err := sql.Open("sqlite3", file) + if err != nil { + return nil, errors.Wrap(err, "sql.Open") + } + + return db, nil +} + +// Returns a database connection handle for the DB func ConnectToDatabase(dbName string) (*SafeConn, error) { file := fmt.Sprintf("file:%s.db", dbName) db, err := sql.Open("sqlite3", file) @@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) { return nil, errors.Wrap(err, "sql.Open") } - conn := &SafeConn{db: db} + conn := MakeSafe(db) return conn, nil } diff --git a/db/user_functions.go b/db/user_functions.go index 3c3623e..9c1100a 100644 --- a/db/user_functions.go +++ b/db/user_functions.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" @@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error } // Fetches data from the users table using "WHERE column = 'value'" -func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { +func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { query := fmt.Sprintf( `SELECT id, @@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e return rows, nil } +// Fetches data from the users table using "WHERE column = 'value'" +func fetchUserData( + ctx context.Context, + tx *SafeTX, + column string, + value interface{}, +) (*sql.Rows, error) { + query := fmt.Sprintf( + `SELECT + id, + username, + password_hash, + created_at, + bio + FROM users + WHERE %s = ? COLLATE NOCASE LIMIT 1`, + column, + ) + rows, err := tx.Query(ctx, query, value) + if err != nil { + return nil, errors.Wrap(err, "tx.Query") + } + return rows, nil +} + // Scan the next row into the provided user pointer. Calls rows.Next() and // assumes only row in the result. Providing a rows object with more than 1 // row may result in undefined behaviour. @@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error { // Queries the database for a user matching the given username. // Query is case insensitive func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - rows, err := fetchUserData(conn, "username", username) + rows, err := oldfetchUserData(conn, "username", username) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { } // Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - rows, err := fetchUserData(conn, "id", id) +func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { + rows, err := fetchUserData(ctx, tx, "id", id) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } diff --git a/handlers/logout.go b/handlers/logout.go index c8999a2..da78925 100644 --- a/handlers/logout.go +++ b/handlers/logout.go @@ -1,41 +1,79 @@ package handlers import ( - "database/sql" + "context" "net/http" + "strings" + "projectreshoot/config" "projectreshoot/cookies" + "projectreshoot/db" "projectreshoot/jwt" "github.com/pkg/errors" "github.com/rs/zerolog" ) +func revokeAccess( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + atStr string, +) error { + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseAccessToken") + } + err = jwt.RevokeToken(ctx, tx, aT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + +func revokeRefresh( + config *config.Config, + ctx context.Context, + tx *db.SafeTX, + rtStr string, +) error { + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) + if err != nil { + if strings.Contains(err.Error(), "Token is expired") || + strings.Contains(err.Error(), "Token has been revoked") { + return nil // Token is expired, dont need to revoke it + } + return errors.Wrap(err, "jwt.ParseRefreshToken") + } + err = jwt.RevokeToken(ctx, tx, rT) + if err != nil { + return errors.Wrap(err, "jwt.RevokeToken") + } + return nil +} + // Retrieve and revoke the user's tokens func revokeTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, r *http.Request, ) error { // get the tokens from the cookies atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseAccessToken") - } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) - if err != nil { - return errors.Wrap(err, "jwt.ParseRefreshToken") - } // revoke the refresh token first as the access token expires quicker // only matters if there is an error revoking the tokens - err = jwt.RevokeToken(conn, rT) + err := revokeRefresh(config, ctx, tx, rtStr) if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") + return errors.Wrap(err, "revokeRefresh") } - err = jwt.RevokeToken(conn, aT) + err = revokeAccess(config, ctx, tx, atStr) if err != nil { - return errors.Wrap(err, "jwt.RevokeToken") + return errors.Wrap(err, "revokeAccess") } return nil } @@ -44,19 +82,24 @@ func revokeTokens( func HandleLogout( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - err := revokeTokens(config, conn, r) - if err != nil { - logger.Error().Err(err).Msg("Error occured on user logout") - w.WriteHeader(http.StatusInternalServerError) - return - } - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - w.Header().Set("HX-Redirect", "/login") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + err := revokeTokens(config, ctx, tx, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error occured on user logout") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + w.Header().Set("HX-Redirect", "/login") + }) }, ) } diff --git a/handlers/reauthenticatate.go b/handlers/reauthenticatate.go index 9e188d8..87a0928 100644 --- a/handlers/reauthenticatate.go +++ b/handlers/reauthenticatate.go @@ -1,12 +1,13 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" "projectreshoot/contexts" "projectreshoot/cookies" + "projectreshoot/db" "projectreshoot/jwt" "projectreshoot/view/component/form" @@ -17,16 +18,17 @@ import ( // Get the tokens from the request func getTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, r *http.Request, ) (*jwt.AccessToken, *jwt.RefreshToken, error) { // get the existing tokens from the cookies atStr, rtStr := cookies.GetTokenStrings(r) - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") } - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") } @@ -35,15 +37,16 @@ func getTokens( // Revoke the given token pair func revokeTokenPair( - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, aT *jwt.AccessToken, rT *jwt.RefreshToken, ) error { - err := jwt.RevokeToken(conn, aT) + err := jwt.RevokeToken(ctx, tx, aT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } - err = jwt.RevokeToken(conn, rT) + err = jwt.RevokeToken(ctx, tx, rT) if err != nil { return errors.Wrap(err, "jwt.RevokeToken") } @@ -53,11 +56,12 @@ func revokeTokenPair( // Issue new tokens for the user, invalidating the old ones func refreshTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) error { - aT, rT, err := getTokens(config, conn, r) + aT, rT, err := getTokens(config, ctx, tx, r) if err != nil { return errors.Wrap(err, "getTokens") } @@ -71,7 +75,7 @@ func refreshTokens( if err != nil { return errors.Wrap(err, "cookies.SetTokenCookies") } - err = revokeTokenPair(conn, aT, rT) + err = revokeTokenPair(ctx, tx, aT, rT) if err != nil { return errors.Wrap(err, "revokeTokenPair") } @@ -97,23 +101,29 @@ func validatePassword( func HandleReauthenticate( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - err := validatePassword(r) - if err != nil { - w.WriteHeader(445) - form.ConfirmPassword("Incorrect password").Render(r.Context(), w) - return - } - err = refreshTokens(config, conn, w, r) - if err != nil { - logger.Error().Err(err).Msg("Failed to refresh user tokens") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + err := validatePassword(r) + if err != nil { + tx.Rollback() + w.WriteHeader(445) + form.ConfirmPassword("Incorrect password").Render(r.Context(), w) + return + } + err = refreshTokens(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Failed to refresh user tokens") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.WriteHeader(http.StatusOK) + }) }, ) } diff --git a/handlers/withtransaction.go b/handlers/withtransaction.go new file mode 100644 index 0000000..d81884f --- /dev/null +++ b/handlers/withtransaction.go @@ -0,0 +1,47 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "projectreshoot/db" + "projectreshoot/view/page" + + "github.com/rs/zerolog" +) + +// A helper function to create a transaction with a cancellable context. +func WithTransaction( + w http.ResponseWriter, + r *http.Request, + logger *zerolog.Logger, + conn *db.SafeConn, + handler func( + ctx context.Context, + tx *db.SafeTX, + w http.ResponseWriter, + r *http.Request, + ), +) { + // Create a cancellable context from the request context + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Start the transaction + tx, err := conn.Begin(ctx) + if err != nil { + tx.Rollback() + logger.Warn().Err(err).Msg("Request failed to start a transaction") + w.WriteHeader(http.StatusServiceUnavailable) + page.Error( + "503", + http.StatusText(503), + "This service is currently unavailable. It could be down for maintenance"). + Render(r.Context(), w) + return + } + + // Pass the context and transaction to the handler + handler(ctx, tx, w, r) +} diff --git a/jwt/parse.go b/jwt/parse.go index 741cc59..0446e85 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -1,11 +1,12 @@ package jwt import ( - "database/sql" + "context" "fmt" "time" "projectreshoot/config" + "projectreshoot/db" "github.com/golang-jwt/jwt" "github.com/google/uuid" @@ -17,7 +18,8 @@ import ( // has the correct scope. func ParseAccessToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*AccessToken, error) { if tokenString == "" { @@ -74,7 +76,7 @@ func ParseAccessToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } @@ -89,7 +91,8 @@ func ParseAccessToken( // has the correct scope. func ParseRefreshToken( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, tokenString string, ) (*RefreshToken, error) { if tokenString == "" { @@ -141,7 +144,7 @@ func ParseRefreshToken( Scope: scope, } - valid, err := CheckTokenNotRevoked(conn, token) + valid, err := CheckTokenNotRevoked(ctx, tx, token) if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } diff --git a/jwt/revoke.go b/jwt/revoke.go index ed2ec63..e988a4e 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -1,32 +1,34 @@ package jwt import ( - "database/sql" + "context" + "projectreshoot/db" "github.com/pkg/errors" ) // Revoke a token by adding it to the database -func RevokeToken(conn *sql.DB, t Token) error { +func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error { jti := t.GetJTI() exp := t.GetEXP() query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` - _, err := conn.Exec(query, jti, exp) + _, err := tx.Exec(ctx, query, jti, exp) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Check if a token has been revoked. Returns true if not revoked. -func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { +func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) { jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` - rows, err := conn.Query(query, jti) - defer rows.Close() + rows, err := tx.Query(ctx, query, jti) if err != nil { - return false, errors.Wrap(err, "conn.Exec") + return false, errors.Wrap(err, "tx.Query") } + // NOTE: rows.Close() + defer rows.Close() revoked := rows.Next() return !revoked, nil } diff --git a/jwt/tokens.go b/jwt/tokens.go index d76e952..ae5d97a 100644 --- a/jwt/tokens.go +++ b/jwt/tokens.go @@ -1,7 +1,7 @@ package jwt import ( - "database/sql" + "context" "projectreshoot/db" "github.com/google/uuid" @@ -12,7 +12,7 @@ type Token interface { GetJTI() uuid.UUID GetEXP() int64 GetScope() string - GetUser(conn *sql.DB) (*db.User, error) + GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) } // Access token @@ -38,15 +38,15 @@ type RefreshToken struct { Scope string // Should be "refresh" } -func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, a.SUB) +func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, a.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } return user, nil } -func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { - user, err := db.GetUserFromID(conn, r.SUB) +func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) { + user, err := db.GetUserFromID(ctx, tx, r.SUB) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromID") } diff --git a/main.go b/main.go index 84ed2fd..6e3032e 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } + oldconn, err := db.OldConnectToDatabase(config.DBName) + if err != nil { + return errors.Wrap(err, "db.ConnectToDatabase") + } + defer oldconn.Close() conn, err := db.ConnectToDatabase(config.DBName) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") @@ -88,7 +93,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, conn, &staticFS) + srv := server.NewServer(config, logger, oldconn, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -99,7 +104,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Runs function for testing in dev if --test flag true if args["test"] == "true" { - test(config, logger, conn, httpServer) + test(config, logger, oldconn, httpServer) return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 23f40f0..3d31c94 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -1,7 +1,7 @@ package middleware import ( - "database/sql" + "context" "net/http" "time" @@ -9,6 +9,7 @@ import ( "projectreshoot/contexts" "projectreshoot/cookies" "projectreshoot/db" + "projectreshoot/handlers" "projectreshoot/jwt" "github.com/pkg/errors" @@ -18,14 +19,15 @@ import ( // Attempt to use a valid refresh token to generate a new token pair func refreshAuthTokens( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, req *http.Request, ref *jwt.RefreshToken, ) (*db.User, error) { - user, err := ref.GetUser(conn) + user, err := ref.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "ref.GetUser") } rememberMe := map[string]bool{ @@ -39,7 +41,7 @@ func refreshAuthTokens( return nil, errors.Wrap(err, "cookies.SetTokenCookies") } // New tokens sent, revoke the used refresh token - err = jwt.RevokeToken(conn, ref) + err = jwt.RevokeToken(ctx, tx, ref) if err != nil { return nil, errors.Wrap(err, "jwt.RevokeToken") } @@ -50,22 +52,23 @@ func refreshAuthTokens( // Check the cookies for token strings and attempt to authenticate them func getAuthenticatedUser( config *config.Config, - conn *sql.DB, + ctx context.Context, + tx *db.SafeTX, w http.ResponseWriter, r *http.Request, ) (*contexts.AuthenticatedUser, error) { // Get token strings from cookies atStr, rtStr := cookies.GetTokenStrings(r) // Attempt to parse the access token - aT, err := jwt.ParseAccessToken(config, conn, atStr) + aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) if err != nil { // Access token invalid, attempt to parse refresh token - rT, err := jwt.ParseRefreshToken(config, conn, rtStr) + rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) if err != nil { return nil, errors.Wrap(err, "jwt.ParseRefreshToken") } // Refresh token valid, attempt to get a new token pair - user, err := refreshAuthTokens(config, conn, w, r, rT) + user, err := refreshAuthTokens(config, ctx, tx, w, r, rT) if err != nil { return nil, errors.Wrap(err, "refreshAuthTokens") } @@ -77,9 +80,9 @@ func getAuthenticatedUser( return &authUser, nil } // Access token valid - user, err := aT.GetUser(conn) + user, err := aT.GetUser(ctx, tx) if err != nil { - return nil, errors.Wrap(err, "rT.GetUser") + return nil, errors.Wrap(err, "aT.GetUser") } authUser := contexts.AuthenticatedUser{ User: user, @@ -93,22 +96,36 @@ func getAuthenticatedUser( func Authentication( logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + conn *db.SafeConn, next http.Handler, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, err := getAuthenticatedUser(config, conn, w, r) - if err != nil { - // User auth failed, delete the cookies to avoid repeat requests - cookies.DeleteCookie(w, "access", "/") - cookies.DeleteCookie(w, "refresh", "/") - logger.Debug(). - Str("remote_addr", r.RemoteAddr). - Err(err). - Msg("Failed to authenticate user") + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return } - ctx := contexts.SetUser(r.Context(), user) - newReq := r.WithContext(ctx) - next.ServeHTTP(w, newReq) + handlers.WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + tx, err := conn.Begin(ctx) + user, err := getAuthenticatedUser(config, ctx, tx, w, r) + if err != nil { + tx.Rollback() + // User auth failed, delete the cookies to avoid repeat requests + cookies.DeleteCookie(w, "access", "/") + cookies.DeleteCookie(w, "refresh", "/") + logger.Debug(). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("Failed to authenticate user") + next.ServeHTTP(w, r) + return + } + tx.Commit() + uctx := contexts.SetUser(r.Context(), user) + newReq := r.WithContext(uctx) + next.ServeHTTP(w, newReq) + }, + ) }) } diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index a5143c8..172bc03 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -19,7 +19,7 @@ func TestAuthenticationMiddleware(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/middleware/logging.go b/middleware/logging.go index abae797..de42258 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -23,9 +23,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) { // Middleware to add logs to console with details of the request func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/static/css/output.css" || + r.URL.Path == "/static/favicon.ico" { + next.ServeHTTP(w, r) + return + } start, err := contexts.GetStartTime(r.Context()) if err != nil { - // Handle failure here. internal server error maybe + // TODO: Handle failure here. internal server error maybe return } wrapped := &wrappedWriter{ diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index de79975..03926b7 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -16,7 +16,7 @@ func TestPageLoginRequired(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 63017cb..0f20840 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -16,7 +16,7 @@ func TestActionReauthRequired(t *testing.T) { cfg, err := tests.TestConfig() require.NoError(t, err) logger := tests.NilLogger() - conn, err := tests.SetupTestDB() + conn, err := tests.SetupTestDB(t.Context()) require.NoError(t, err) require.NotNil(t, conn) defer tests.DeleteTestDB() diff --git a/server/routes.go b/server/routes.go index a92885f..5a9d8c9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -5,6 +5,7 @@ import ( "net/http" "projectreshoot/config" + "projectreshoot/db" "projectreshoot/handlers" "projectreshoot/middleware" "projectreshoot/view/page" @@ -17,7 +18,8 @@ func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, config *config.Config, - conn *sql.DB, + oldconn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, ) { // Health check @@ -42,7 +44,7 @@ func addRoutes( handlers.HandleLoginRequest( config, logger, - conn, + oldconn, ))) // Register page and handlers @@ -55,7 +57,7 @@ func addRoutes( handlers.HandleRegisterRequest( config, logger, - conn, + oldconn, ))) // Logout @@ -85,17 +87,17 @@ func addRoutes( mux.Handle("POST /change-username", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangeUsername(logger, conn), + handlers.HandleChangeUsername(logger, oldconn), ), )) mux.Handle("POST /change-bio", middleware.RequiresLogin( - handlers.HandleChangeBio(logger, conn), + handlers.HandleChangeBio(logger, oldconn), )) mux.Handle("POST /change-password", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangePassword(logger, conn), + handlers.HandleChangePassword(logger, oldconn), ), )) } diff --git a/server/server.go b/server/server.go index 648082f..c64a8cb 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "net/http" "projectreshoot/config" + "projectreshoot/db" "projectreshoot/middleware" "github.com/rs/zerolog" @@ -14,7 +15,8 @@ import ( func NewServer( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + oldconn *sql.DB, + conn *db.SafeConn, staticFS *http.FileSystem, ) http.Handler { mux := http.NewServeMux() @@ -22,6 +24,7 @@ func NewServer( mux, logger, config, + oldconn, conn, staticFS, ) diff --git a/tests/database.go b/tests/database.go index 7c606c5..a7fd26b 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,10 +1,12 @@ package tests import ( + "context" "database/sql" "fmt" "os" "path/filepath" + "projectreshoot/db" "github.com/pkg/errors" @@ -32,11 +34,16 @@ func findSQLFile(filename string) (string, error) { // SetupTestDB initializes a test SQLite database with mock data // Make sure to call DeleteTestDB when finished to cleanup -func SetupTestDB() (*sql.DB, error) { - conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") +func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { + dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") if err != nil { return nil, errors.Wrap(err, "sql.Open") } + conn := db.MakeSafe(dbfile) + tx, err := conn.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "conn.Begin") + } // Setup the test database schemaPath, err := findSQLFile("schema.sql") if err != nil { @@ -49,9 +56,10 @@ func SetupTestDB() (*sql.DB, error) { } schemaSQL := string(sqlBytes) - _, err = conn.Exec(schemaSQL) + _, err = tx.Exec(ctx, schemaSQL) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + tx.Rollback() + return nil, errors.Wrap(err, "tx.Exec") } // Load the test data dataPath, err := findSQLFile("testdata.sql") @@ -64,10 +72,12 @@ func SetupTestDB() (*sql.DB, error) { } dataSQL := string(sqlBytes) - _, err = conn.Exec(dataSQL) + _, err = tx.Exec(ctx, dataSQL) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + tx.Rollback() + return nil, errors.Wrap(err, "tx.Exec") } + tx.Commit() return conn, nil }