Updated all code to use SafeConn and SafeTX

This commit is contained in:
2025-02-17 21:39:12 +11:00
parent 6faf168a6d
commit a8d112fdd5
17 changed files with 265 additions and 218 deletions

View File

@@ -20,7 +20,6 @@ tester:
go run . --port 3232 --test --loglevel trace
test:
rm -f **/.projectreshoot-test-database.db && \
go mod tidy && \
templ generate && \
go generate && \

View File

@@ -1,7 +1,7 @@
package db
import (
"database/sql"
"context"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
@@ -16,16 +16,16 @@ type User struct {
}
// Uses bcrypt to set the users Password_hash from the given password
func (user *User) SetPassword(conn *sql.DB, password string) error {
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
}
user.Password_hash = string(hashedPassword)
query := `UPDATE users SET password_hash = ? WHERE id = ?`
_, err = conn.Exec(query, user.Password_hash, user.ID)
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
@@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error {
}
// Change the user's username
func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error {
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
query := `UPDATE users SET username = ? WHERE id = ?`
_, err := conn.Exec(query, newUsername, user.ID)
_, err := tx.Exec(ctx, query, newUsername, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Change the user's bio
func (user *User) ChangeBio(conn *sql.DB, newBio string) error {
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
query := `UPDATE users SET bio = ? WHERE id = ?`
_, err := conn.Exec(query, newBio, user.ID)
_, err := tx.Exec(ctx, query, newBio, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}

View File

@@ -9,43 +9,28 @@ import (
)
// Creates a new user in the database and returns a pointer
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
func CreateNewUser(
ctx context.Context,
tx *SafeTX,
username string,
password string,
) (*User, error) {
query := `INSERT INTO users (username) VALUES (?)`
_, err := conn.Exec(query, username)
_, err := tx.Exec(ctx, query, username)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
return nil, errors.Wrap(err, "tx.Exec")
}
user, err := GetUserFromUsername(conn, username)
user, err := GetUserFromUsername(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "GetUserFromUsername")
}
err = user.SetPassword(conn, password)
err = user.SetPassword(ctx, tx, password)
if err != nil {
return nil, errors.Wrap(err, "user.SetPassword")
}
return user, nil
}
// Fetches data from the users table using "WHERE column = 'value'"
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
username,
password_hash,
created_at,
bio
FROM users
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column,
)
rows, err := conn.Query(query, value)
if err != nil {
return nil, errors.Wrap(err, "conn.Query")
}
return rows, nil
}
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(
ctx context.Context,
@@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
rows, err := oldfetchUserData(conn, "username", username)
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
rows, err := fetchUserData(ctx, tx, "username", username)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}
@@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
}
// Checks if the given username is unique. Returns true if not taken
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
rows, err := conn.Query(query, username)
rows, err := tx.Query(ctx, query, username)
if err != nil {
return false, errors.Wrap(err, "conn.Query")
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
taken := rows.Next()

9
go.mod
View File

@@ -7,18 +7,25 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.24
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.33.0
modernc.org/sqlite v1.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
golang.org/x/sys v0.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.8.2 // indirect
)

42
go.sum
View File

@@ -3,11 +3,15 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
@@ -19,12 +23,14 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
@@ -32,12 +38,44 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw=
modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8=
modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI=
modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw=
modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -1,7 +1,7 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"projectreshoot/contexts"
@@ -43,63 +43,81 @@ func HandleAccountSubpage() http.Handler {
// Handles a request to change the users username
func HandleChangeUsername(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
r.ParseForm()
newUsername := r.FormValue("username")
unique, err := db.CheckUsernameUnique(conn, newUsername)
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusInternalServerError)
return
}
if !unique {
tx.Rollback()
account.ChangeUsername("Username is taken", newUsername).
Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err = user.ChangeUsername(conn, newUsername)
err = user.ChangeUsername(ctx, tx, newUsername)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)
},
)
}
// Handles a request to change the users bio
func HandleChangeBio(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
r.ParseForm()
newBio := r.FormValue("bio")
leng := len([]rune(newBio))
if leng > 128 {
tx.Rollback()
account.ChangeBio("Bio limited to 128 characters", newBio).
Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err := user.ChangeBio(conn, newBio)
err := user.ChangeBio(ctx, tx, newBio)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating bio")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)
},
)
}
func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
func validateChangePassword(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (string, error) {
r.ParseForm()
formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password")
@@ -115,23 +133,30 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
// Handles a request to change the users password
func HandleChangePassword(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
newPass, err := validateChangePassword(conn, r)
WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
newPass, err := validateChangePassword(ctx, tx, r)
if err != nil {
tx.Rollback()
account.ChangePassword(err.Error()).Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err = user.SetPassword(conn, newPass)
err = user.SetPassword(ctx, tx, newPass)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating password")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)
},
)
}

View File

@@ -1,7 +1,7 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"projectreshoot/config"
@@ -16,10 +16,14 @@ import (
// Validates the username matches a user in the database and the password
// is correct. Returns the corresponding user
func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) {
func validateLogin(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
user, err := db.GetUserFromUsername(conn, formUsername)
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromUsername")
}
@@ -47,13 +51,16 @@ func checkRememberMe(r *http.Request) bool {
func HandleLoginRequest(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
r.ParseForm()
user, err := validateLogin(conn, r)
user, err := validateLogin(ctx, tx, r)
if err != nil {
tx.Rollback()
if err.Error() != "Username or password incorrect" {
logger.Warn().Caller().Err(err).Msg("Login request failed")
w.WriteHeader(http.StatusInternalServerError)
@@ -66,12 +73,16 @@ func HandleLoginRequest(
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
tx.Rollback()
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
return
}
tx.Commit()
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
})
},
)
}

View File

@@ -1,7 +1,7 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"projectreshoot/config"
@@ -14,11 +14,15 @@ import (
"github.com/rs/zerolog"
)
func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
func validateRegistration(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password")
unique, err := db.CheckUsernameUnique(conn, formUsername)
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
if err != nil {
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
}
@@ -31,7 +35,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
if len(formPassword) > 72 {
return nil, errors.New("Password exceeds maximum length of 72 bytes")
}
user, err := db.CreateNewUser(conn, formUsername, formPassword)
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
if err != nil {
return nil, errors.Wrap(err, "db.CreateNewUser")
}
@@ -42,13 +46,16 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
func HandleRegisterRequest(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
r.ParseForm()
user, err := validateRegistration(conn, r)
user, err := validateRegistration(ctx, tx, r)
if err != nil {
tx.Rollback()
if err.Error() != "Username is taken" &&
err.Error() != "Passwords do not match" &&
err.Error() != "Password exceeds maximum length of 72 bytes" {
@@ -63,14 +70,18 @@ func HandleRegisterRequest(
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
tx.Rollback()
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
return
}
tx.Commit()
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
},
)
},
)
}
// Handles a request to view the login page. Will attempt to set "pagefrom"

View File

@@ -77,11 +77,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "logging.GetLogger")
}
oldconn, err := db.OldConnectToDatabase(config.DBName)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
defer oldconn.Close()
conn, err := db.ConnectToDatabase(config.DBName)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
@@ -93,7 +88,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "getStaticFiles")
}
srv := server.NewServer(config, logger, oldconn, conn, &staticFS)
srv := server.NewServer(config, logger, conn, &staticFS)
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv,
@@ -104,7 +99,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Runs function for testing in dev if --test flag true
if args["test"] == "true" {
test(config, logger, oldconn, httpServer)
test(config, logger, conn, httpServer)
return nil
}

View File

@@ -107,7 +107,6 @@ func Authentication(
}
handlers.WithTransaction(w, r, logger, conn,
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
tx, err := conn.Begin(ctx)
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
if err != nil {
tx.Rollback()

View File

@@ -8,6 +8,7 @@ import (
"testing"
"projectreshoot/contexts"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
@@ -16,13 +17,14 @@ import (
func TestAuthenticationMiddleware(t *testing.T) {
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -38,7 +40,7 @@ func TestAuthenticationMiddleware(t *testing.T) {
})
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, conn, testHandler)
authHandler := Authentication(logger, cfg, sconn, testHandler)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -5,6 +5,7 @@ import (
"net/http/httptest"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
@@ -13,13 +14,14 @@ import (
func TestPageLoginRequired(t *testing.T) {
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -28,7 +30,7 @@ func TestPageLoginRequired(t *testing.T) {
// Add the middleware and create the server
loginRequiredHandler := RequiresLogin(testHandler)
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -5,21 +5,23 @@ import (
"net/http/httptest"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestActionReauthRequired(t *testing.T) {
func TestReauthRequired(t *testing.T) {
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := db.MakeSafe(conn)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
logger := tests.DebugLogger(t)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -29,7 +31,7 @@ func TestActionReauthRequired(t *testing.T) {
// Add the middleware and create the server
reauthRequiredHandler := RequiresFresh(testHandler)
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -1,7 +1,6 @@
package server
import (
"database/sql"
"net/http"
"projectreshoot/config"
@@ -18,7 +17,6 @@ func addRoutes(
mux *http.ServeMux,
logger *zerolog.Logger,
config *config.Config,
oldconn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem,
) {
@@ -44,7 +42,7 @@ func addRoutes(
handlers.HandleLoginRequest(
config,
logger,
oldconn,
conn,
)))
// Register page and handlers
@@ -57,7 +55,7 @@ func addRoutes(
handlers.HandleRegisterRequest(
config,
logger,
oldconn,
conn,
)))
// Logout
@@ -87,17 +85,17 @@ func addRoutes(
mux.Handle("POST /change-username",
middleware.RequiresLogin(
middleware.RequiresFresh(
handlers.HandleChangeUsername(logger, oldconn),
handlers.HandleChangeUsername(logger, conn),
),
))
mux.Handle("POST /change-bio",
middleware.RequiresLogin(
handlers.HandleChangeBio(logger, oldconn),
handlers.HandleChangeBio(logger, conn),
))
mux.Handle("POST /change-password",
middleware.RequiresLogin(
middleware.RequiresFresh(
handlers.HandleChangePassword(logger, oldconn),
handlers.HandleChangePassword(logger, conn),
),
))
}

View File

@@ -1,7 +1,6 @@
package server
import (
"database/sql"
"net/http"
"projectreshoot/config"
@@ -15,7 +14,6 @@ import (
func NewServer(
config *config.Config,
logger *zerolog.Logger,
oldconn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem,
) http.Handler {
@@ -24,7 +22,6 @@ func NewServer(
mux,
logger,
config,
oldconn,
conn,
staticFS,
)

View File

@@ -1,10 +1,10 @@
package main
import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/db"
"github.com/rs/zerolog"
)
@@ -18,7 +18,7 @@ import (
func test(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
srv *http.Server,
) {
}

View File

@@ -1,16 +1,14 @@
package tests
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"projectreshoot/db"
"github.com/pkg/errors"
_ "github.com/mattn/go-sqlite3"
_ "modernc.org/sqlite"
)
func findSQLFile(filename string) (string, error) {
@@ -33,17 +31,11 @@ func findSQLFile(filename string) (string, error) {
}
// SetupTestDB initializes a test SQLite database with mock data
// Make sure to call DeleteTestDB when finished to cleanup
func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
func SetupTestDB() (*sql.DB, error) {
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
conn := db.MakeSafe(dbfile)
tx, err := conn.Begin(ctx)
if err != nil {
return nil, errors.Wrap(err, "conn.Begin")
}
// Setup the test database
schemaPath, err := findSQLFile("schema.sql")
if err != nil {
@@ -56,9 +48,8 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
}
schemaSQL := string(sqlBytes)
_, err = tx.Exec(ctx, schemaSQL)
_, err = conn.Exec(schemaSQL)
if err != nil {
tx.Rollback()
return nil, errors.Wrap(err, "tx.Exec")
}
// Load the test data
@@ -72,24 +63,9 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
}
dataSQL := string(sqlBytes)
_, err = tx.Exec(ctx, dataSQL)
_, err = conn.Exec(dataSQL)
if err != nil {
tx.Rollback()
return nil, errors.Wrap(err, "tx.Exec")
}
tx.Commit()
return conn, nil
}
// Deletes the test database from disk
func DeleteTestDB() error {
fileName := ".projectreshoot-test-database.db"
// Attempt to remove the file
err := os.Remove(fileName)
if err != nil {
return errors.Wrap(err, "os.Remove")
}
return nil
}