diff --git a/Makefile b/Makefile index 72a5804..66dd5b8 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,6 @@ tester: go run . --port 3232 --test --loglevel trace test: - rm -f **/.projectreshoot-test-database.db && \ go mod tidy && \ templ generate && \ go generate && \ diff --git a/db/user.go b/db/user.go index fe62349..a2daa26 100644 --- a/db/user.go +++ b/db/user.go @@ -1,7 +1,7 @@ package db import ( - "database/sql" + "context" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" @@ -16,16 +16,16 @@ type User struct { } // Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { +func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return errors.Wrap(err, "bcrypt.GenerateFromPassword") } user.Password_hash = string(hashedPassword) query := `UPDATE users SET password_hash = ? WHERE id = ?` - _, err = conn.Exec(query, user.Password_hash, user.ID) + _, err = tx.Exec(ctx, query, user.Password_hash, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } @@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error { } // Change the user's username -func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { +func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error { query := `UPDATE users SET username = ? WHERE id = ?` - _, err := conn.Exec(query, newUsername, user.ID) + _, err := tx.Exec(ctx, query, newUsername, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } // Change the user's bio -func (user *User) ChangeBio(conn *sql.DB, newBio string) error { +func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error { query := `UPDATE users SET bio = ? WHERE id = ?` - _, err := conn.Exec(query, newBio, user.ID) + _, err := tx.Exec(ctx, query, newBio, user.ID) if err != nil { - return errors.Wrap(err, "conn.Exec") + return errors.Wrap(err, "tx.Exec") } return nil } diff --git a/db/user_functions.go b/db/user_functions.go index 9c1100a..28d30ad 100644 --- a/db/user_functions.go +++ b/db/user_functions.go @@ -9,43 +9,28 @@ import ( ) // Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { +func CreateNewUser( + ctx context.Context, + tx *SafeTX, + username string, + password string, +) (*User, error) { query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) + _, err := tx.Exec(ctx, query, username) if err != nil { - return nil, errors.Wrap(err, "conn.Exec") + return nil, errors.Wrap(err, "tx.Exec") } - user, err := GetUserFromUsername(conn, username) + user, err := GetUserFromUsername(ctx, tx, username) if err != nil { return nil, errors.Wrap(err, "GetUserFromUsername") } - err = user.SetPassword(conn, password) + err = user.SetPassword(ctx, tx, password) if err != nil { return nil, errors.Wrap(err, "user.SetPassword") } return user, nil } -// Fetches data from the users table using "WHERE column = 'value'" -func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { - query := fmt.Sprintf( - `SELECT - id, - username, - password_hash, - created_at, - bio - FROM users - WHERE %s = ? COLLATE NOCASE LIMIT 1`, - column, - ) - rows, err := conn.Query(query, value) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - return rows, nil -} - // Fetches data from the users table using "WHERE column = 'value'" func fetchUserData( ctx context.Context, @@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error { // Queries the database for a user matching the given username. // Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - rows, err := oldfetchUserData(conn, "username", username) +func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) { + rows, err := fetchUserData(ctx, tx, "username", username) if err != nil { return nil, errors.Wrap(err, "fetchUserData") } @@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { } // Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { +func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) { query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) + rows, err := tx.Query(ctx, query, username) if err != nil { - return false, errors.Wrap(err, "conn.Query") + return false, errors.Wrap(err, "tx.Query") } defer rows.Close() taken := rows.Next() diff --git a/go.mod b/go.mod index 97dfb00..98e11af 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,25 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 - github.com/mattn/go-sqlite3 v1.14.24 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.33.0 + modernc.org/sqlite v1.35.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect golang.org/x/sys v0.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.61.13 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.8.2 // indirect ) diff --git a/go.sum b/go.sum index 328d060..baa5432 100644 --- a/go.sum +++ b/go.sum @@ -3,11 +3,15 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -19,12 +23,14 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -32,12 +38,44 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= +modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo= +modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw= +modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8= +modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI= +modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw= +modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/handlers/account.go b/handlers/account.go index 365a283..dac0a3b 100644 --- a/handlers/account.go +++ b/handlers/account.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/contexts" @@ -43,32 +43,39 @@ func HandleAccountSubpage() http.Handler { // Handles a request to change the users username func HandleChangeUsername( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newUsername := r.FormValue("username") - - unique, err := db.CheckUsernameUnique(conn, newUsername) - if err != nil { - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - if !unique { - account.ChangeUsername("Username is taken", newUsername). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.ChangeUsername(conn, newUsername) - if err != nil { - logger.Error().Err(err).Msg("Error updating username") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newUsername := r.FormValue("username") + unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + if !unique { + tx.Rollback() + account.ChangeUsername("Username is taken", newUsername). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.ChangeUsername(ctx, tx, newUsername) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating username") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } @@ -76,30 +83,41 @@ func HandleChangeUsername( // Handles a request to change the users bio func HandleChangeBio( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - newBio := r.FormValue("bio") - leng := len([]rune(newBio)) - if leng > 128 { - account.ChangeBio("Bio limited to 128 characters", newBio). - Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err := user.ChangeBio(conn, newBio) - if err != nil { - logger.Error().Err(err).Msg("Error updating bio") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + newBio := r.FormValue("bio") + leng := len([]rune(newBio)) + if leng > 128 { + tx.Rollback() + account.ChangeBio("Bio limited to 128 characters", newBio). + Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err := user.ChangeBio(ctx, tx, newBio) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating bio") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } -func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { +func validateChangePassword( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (string, error) { r.ParseForm() formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") @@ -115,23 +133,30 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) { // Handles a request to change the users password func HandleChangePassword( logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - newPass, err := validateChangePassword(conn, r) - if err != nil { - account.ChangePassword(err.Error()).Render(r.Context(), w) - return - } - user := contexts.GetUser(r.Context()) - err = user.SetPassword(conn, newPass) - if err != nil { - logger.Error().Err(err).Msg("Error updating password") - w.WriteHeader(http.StatusInternalServerError) - return - } - w.Header().Set("HX-Refresh", "true") + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + newPass, err := validateChangePassword(ctx, tx, r) + if err != nil { + tx.Rollback() + account.ChangePassword(err.Error()).Render(r.Context(), w) + return + } + user := contexts.GetUser(r.Context()) + err = user.SetPassword(ctx, tx, newPass) + if err != nil { + tx.Rollback() + logger.Error().Err(err).Msg("Error updating password") + w.WriteHeader(http.StatusInternalServerError) + return + } + tx.Commit() + w.Header().Set("HX-Refresh", "true") + }, + ) }, ) } diff --git a/handlers/login.go b/handlers/login.go index 7af3901..8788b01 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" @@ -16,10 +16,14 @@ import ( // Validates the username matches a user in the database and the password // is correct. Returns the corresponding user -func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateLogin( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") - user, err := db.GetUserFromUsername(conn, formUsername) + user, err := db.GetUserFromUsername(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.GetUserFromUsername") } @@ -47,31 +51,38 @@ func checkRememberMe(r *http.Request) bool { func HandleLoginRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateLogin(conn, r) - if err != nil { - if err.Error() != "Username or password incorrect" { - logger.Warn().Caller().Err(err).Msg("Login request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.LoginForm(err.Error()).Render(r.Context(), w) - } - return - } + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + user, err := validateLogin(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username or password incorrect" { + logger.Warn().Caller().Err(err).Msg("Login request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.LoginForm(err.Error()).Render(r.Context(), w) + } + return + } - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - } + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + }) }, ) } diff --git a/handlers/register.go b/handlers/register.go index 895ab67..605b02c 100644 --- a/handlers/register.go +++ b/handlers/register.go @@ -1,7 +1,7 @@ package handlers import ( - "database/sql" + "context" "net/http" "projectreshoot/config" @@ -14,11 +14,15 @@ import ( "github.com/rs/zerolog" ) -func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { +func validateRegistration( + ctx context.Context, + tx *db.SafeTX, + r *http.Request, +) (*db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") formConfirmPassword := r.FormValue("confirm-password") - unique, err := db.CheckUsernameUnique(conn, formUsername) + unique, err := db.CheckUsernameUnique(ctx, tx, formUsername) if err != nil { return nil, errors.Wrap(err, "db.CheckUsernameUnique") } @@ -31,7 +35,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { if len(formPassword) > 72 { return nil, errors.New("Password exceeds maximum length of 72 bytes") } - user, err := db.CreateNewUser(conn, formUsername, formPassword) + user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword) if err != nil { return nil, errors.Wrap(err, "db.CreateNewUser") } @@ -42,33 +46,40 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) { func HandleRegisterRequest( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - user, err := validateRegistration(conn, r) - if err != nil { - if err.Error() != "Username is taken" && - err.Error() != "Passwords do not match" && - err.Error() != "Password exceeds maximum length of 72 bytes" { - logger.Warn().Caller().Err(err).Msg("Registration request failed") - w.WriteHeader(http.StatusInternalServerError) - } else { - form.RegisterForm(err.Error()).Render(r.Context(), w) - } - return - } + WithTransaction(w, r, logger, conn, + func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + user, err := validateRegistration(ctx, tx, r) + if err != nil { + tx.Rollback() + if err.Error() != "Username is taken" && + err.Error() != "Passwords do not match" && + err.Error() != "Password exceeds maximum length of 72 bytes" { + logger.Warn().Caller().Err(err).Msg("Registration request failed") + w.WriteHeader(http.StatusInternalServerError) + } else { + form.RegisterForm(err.Error()).Render(r.Context(), w) + } + return + } - rememberMe := checkRememberMe(r) - err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") - } - - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) + rememberMe := checkRememberMe(r) + err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) + if err != nil { + tx.Rollback() + w.WriteHeader(http.StatusInternalServerError) + logger.Warn().Caller().Err(err).Msg("Failed to set token cookies") + return + } + tx.Commit() + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + }, + ) }, ) } diff --git a/main.go b/main.go index 6e3032e..84ed2fd 100644 --- a/main.go +++ b/main.go @@ -77,11 +77,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "logging.GetLogger") } - oldconn, err := db.OldConnectToDatabase(config.DBName) - if err != nil { - return errors.Wrap(err, "db.ConnectToDatabase") - } - defer oldconn.Close() conn, err := db.ConnectToDatabase(config.DBName) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") @@ -93,7 +88,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return errors.Wrap(err, "getStaticFiles") } - srv := server.NewServer(config, logger, oldconn, conn, &staticFS) + srv := server.NewServer(config, logger, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -104,7 +99,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Runs function for testing in dev if --test flag true if args["test"] == "true" { - test(config, logger, oldconn, httpServer) + test(config, logger, conn, httpServer) return nil } diff --git a/middleware/authentication.go b/middleware/authentication.go index 3d31c94..e444da8 100644 --- a/middleware/authentication.go +++ b/middleware/authentication.go @@ -107,7 +107,6 @@ func Authentication( } handlers.WithTransaction(w, r, logger, conn, func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) { - tx, err := conn.Begin(ctx) user, err := getAuthenticatedUser(config, ctx, tx, w, r) if err != nil { tx.Rollback() diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go index 172bc03..95583af 100644 --- a/middleware/authentication_test.go +++ b/middleware/authentication_test.go @@ -8,6 +8,7 @@ import ( "testing" "projectreshoot/contexts" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -16,13 +17,14 @@ import ( func TestAuthenticationMiddleware(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,7 +40,7 @@ func TestAuthenticationMiddleware(t *testing.T) { }) // Add the middleware and create the server - authHandler := Authentication(logger, cfg, conn, testHandler) + authHandler := Authentication(logger, cfg, sconn, testHandler) require.NoError(t, err) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go index 03926b7..80b0a15 100644 --- a/middleware/pageprotection_test.go +++ b/middleware/pageprotection_test.go @@ -5,6 +5,7 @@ import ( "net/http/httptest" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" @@ -13,13 +14,14 @@ import ( func TestPageLoginRequired(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -28,7 +30,7 @@ func TestPageLoginRequired(t *testing.T) { // Add the middleware and create the server loginRequiredHandler := RequiresLogin(testHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go index 0f20840..595e4e7 100644 --- a/middleware/reauthentication_test.go +++ b/middleware/reauthentication_test.go @@ -5,21 +5,23 @@ import ( "net/http/httptest" "testing" + "projectreshoot/db" "projectreshoot/tests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestActionReauthRequired(t *testing.T) { +func TestReauthRequired(t *testing.T) { // Basic setup + conn, err := tests.SetupTestDB() + require.NoError(t, err) + sconn := db.MakeSafe(conn) + defer sconn.Close() + cfg, err := tests.TestConfig() require.NoError(t, err) - logger := tests.NilLogger() - conn, err := tests.SetupTestDB(t.Context()) - require.NoError(t, err) - require.NotNil(t, conn) - defer tests.DeleteTestDB() + logger := tests.DebugLogger(t) // Handler to check outcome of Authentication middleware testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -29,7 +31,7 @@ func TestActionReauthRequired(t *testing.T) { // Add the middleware and create the server reauthRequiredHandler := RequiresFresh(testHandler) loginRequiredHandler := RequiresLogin(reauthRequiredHandler) - authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler) server := httptest.NewServer(authHandler) defer server.Close() diff --git a/server/routes.go b/server/routes.go index 5a9d8c9..eac1e17 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,7 +1,6 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" @@ -18,7 +17,6 @@ func addRoutes( mux *http.ServeMux, logger *zerolog.Logger, config *config.Config, - oldconn *sql.DB, conn *db.SafeConn, staticFS *http.FileSystem, ) { @@ -44,7 +42,7 @@ func addRoutes( handlers.HandleLoginRequest( config, logger, - oldconn, + conn, ))) // Register page and handlers @@ -57,7 +55,7 @@ func addRoutes( handlers.HandleRegisterRequest( config, logger, - oldconn, + conn, ))) // Logout @@ -87,17 +85,17 @@ func addRoutes( mux.Handle("POST /change-username", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangeUsername(logger, oldconn), + handlers.HandleChangeUsername(logger, conn), ), )) mux.Handle("POST /change-bio", middleware.RequiresLogin( - handlers.HandleChangeBio(logger, oldconn), + handlers.HandleChangeBio(logger, conn), )) mux.Handle("POST /change-password", middleware.RequiresLogin( middleware.RequiresFresh( - handlers.HandleChangePassword(logger, oldconn), + handlers.HandleChangePassword(logger, conn), ), )) } diff --git a/server/server.go b/server/server.go index c64a8cb..fa75b0a 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,6 @@ package server import ( - "database/sql" "net/http" "projectreshoot/config" @@ -15,7 +14,6 @@ import ( func NewServer( config *config.Config, logger *zerolog.Logger, - oldconn *sql.DB, conn *db.SafeConn, staticFS *http.FileSystem, ) http.Handler { @@ -24,7 +22,6 @@ func NewServer( mux, logger, config, - oldconn, conn, staticFS, ) diff --git a/tester.go b/tester.go index bfd8981..e474d82 100644 --- a/tester.go +++ b/tester.go @@ -1,10 +1,10 @@ package main import ( - "database/sql" "net/http" "projectreshoot/config" + "projectreshoot/db" "github.com/rs/zerolog" ) @@ -18,7 +18,7 @@ import ( func test( config *config.Config, logger *zerolog.Logger, - conn *sql.DB, + conn *db.SafeConn, srv *http.Server, ) { } diff --git a/tests/database.go b/tests/database.go index a7fd26b..549db2b 100644 --- a/tests/database.go +++ b/tests/database.go @@ -1,16 +1,14 @@ package tests import ( - "context" "database/sql" "fmt" "os" "path/filepath" - "projectreshoot/db" "github.com/pkg/errors" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) func findSQLFile(filename string) (string, error) { @@ -33,17 +31,11 @@ func findSQLFile(filename string) (string, error) { } // SetupTestDB initializes a test SQLite database with mock data -// Make sure to call DeleteTestDB when finished to cleanup -func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { - dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") +func SetupTestDB() (*sql.DB, error) { + conn, err := sql.Open("sqlite", "file::memory:?cache=shared") if err != nil { return nil, errors.Wrap(err, "sql.Open") } - conn := db.MakeSafe(dbfile) - tx, err := conn.Begin(ctx) - if err != nil { - return nil, errors.Wrap(err, "conn.Begin") - } // Setup the test database schemaPath, err := findSQLFile("schema.sql") if err != nil { @@ -56,9 +48,8 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { } schemaSQL := string(sqlBytes) - _, err = tx.Exec(ctx, schemaSQL) + _, err = conn.Exec(schemaSQL) if err != nil { - tx.Rollback() return nil, errors.Wrap(err, "tx.Exec") } // Load the test data @@ -72,24 +63,9 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) { } dataSQL := string(sqlBytes) - _, err = tx.Exec(ctx, dataSQL) + _, err = conn.Exec(dataSQL) if err != nil { - tx.Rollback() return nil, errors.Wrap(err, "tx.Exec") } - tx.Commit() return conn, nil } - -// Deletes the test database from disk -func DeleteTestDB() error { - fileName := ".projectreshoot-test-database.db" - - // Attempt to remove the file - err := os.Remove(fileName) - if err != nil { - return errors.Wrap(err, "os.Remove") - } - - return nil -}