Updated all code to use SafeConn and SafeTX
This commit is contained in:
1
Makefile
1
Makefile
@@ -20,7 +20,6 @@ tester:
|
||||
go run . --port 3232 --test --loglevel trace
|
||||
|
||||
test:
|
||||
rm -f **/.projectreshoot-test-database.db && \
|
||||
go mod tidy && \
|
||||
templ generate && \
|
||||
go generate && \
|
||||
|
||||
20
db/user.go
20
db/user.go
@@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -16,16 +16,16 @@ type User struct {
|
||||
}
|
||||
|
||||
// Uses bcrypt to set the users Password_hash from the given password
|
||||
func (user *User) SetPassword(conn *sql.DB, password string) error {
|
||||
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
user.Password_hash = string(hashedPassword)
|
||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||
_, err = conn.Exec(query, user.Password_hash, user.ID)
|
||||
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error {
|
||||
}
|
||||
|
||||
// Change the user's username
|
||||
func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error {
|
||||
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
|
||||
query := `UPDATE users SET username = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newUsername, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newUsername, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's bio
|
||||
func (user *User) ChangeBio(conn *sql.DB, newBio string) error {
|
||||
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
|
||||
query := `UPDATE users SET bio = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newBio, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newBio, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,43 +9,28 @@ import (
|
||||
)
|
||||
|
||||
// Creates a new user in the database and returns a pointer
|
||||
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
|
||||
func CreateNewUser(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
username string,
|
||||
password string,
|
||||
) (*User, error) {
|
||||
query := `INSERT INTO users (username) VALUES (?)`
|
||||
_, err := conn.Exec(query, username)
|
||||
_, err := tx.Exec(ctx, query, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
user, err := GetUserFromUsername(conn, username)
|
||||
user, err := GetUserFromUsername(ctx, tx, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "GetUserFromUsername")
|
||||
}
|
||||
err = user.SetPassword(conn, password)
|
||||
err = user.SetPassword(ctx, tx, password)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.SetPassword")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
created_at,
|
||||
bio
|
||||
FROM users
|
||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||
column,
|
||||
)
|
||||
rows, err := conn.Query(query, value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Query")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(
|
||||
ctx context.Context,
|
||||
@@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
||||
|
||||
// Queries the database for a user matching the given username.
|
||||
// Query is case insensitive
|
||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
rows, err := oldfetchUserData(conn, "username", username)
|
||||
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "username", username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||
}
|
||||
|
||||
// Checks if the given username is unique. Returns true if not taken
|
||||
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
|
||||
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
|
||||
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
||||
rows, err := conn.Query(query, username)
|
||||
rows, err := tx.Query(ctx, query, username)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Query")
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
taken := rows.Next()
|
||||
|
||||
9
go.mod
9
go.mod
@@ -7,18 +7,25 @@ require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
modernc.org/sqlite v1.35.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.61.13 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.8.2 // indirect
|
||||
)
|
||||
|
||||
42
go.sum
42
go.sum
@@ -3,11 +3,15 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
@@ -19,12 +23,14 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
@@ -32,12 +38,44 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
|
||||
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
|
||||
modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo=
|
||||
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
|
||||
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
|
||||
modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw=
|
||||
modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8=
|
||||
modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI=
|
||||
modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw=
|
||||
modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
@@ -43,32 +43,39 @@ func HandleAccountSubpage() http.Handler {
|
||||
// Handles a request to change the users username
|
||||
func HandleChangeUsername(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
newUsername := r.FormValue("username")
|
||||
|
||||
unique, err := db.CheckUsernameUnique(conn, newUsername)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !unique {
|
||||
account.ChangeUsername("Username is taken", newUsername).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.ChangeUsername(conn, newUsername)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
newUsername := r.FormValue("username")
|
||||
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !unique {
|
||||
tx.Rollback()
|
||||
account.ChangeUsername("Username is taken", newUsername).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.ChangeUsername(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -76,30 +83,41 @@ func HandleChangeUsername(
|
||||
// Handles a request to change the users bio
|
||||
func HandleChangeBio(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
newBio := r.FormValue("bio")
|
||||
leng := len([]rune(newBio))
|
||||
if leng > 128 {
|
||||
account.ChangeBio("Bio limited to 128 characters", newBio).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err := user.ChangeBio(conn, newBio)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Error updating bio")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
newBio := r.FormValue("bio")
|
||||
leng := len([]rune(newBio))
|
||||
if leng > 128 {
|
||||
tx.Rollback()
|
||||
account.ChangeBio("Bio limited to 128 characters", newBio).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err := user.ChangeBio(ctx, tx, newBio)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating bio")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
|
||||
func validateChangePassword(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (string, error) {
|
||||
r.ParseForm()
|
||||
formPassword := r.FormValue("password")
|
||||
formConfirmPassword := r.FormValue("confirm-password")
|
||||
@@ -115,23 +133,30 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
|
||||
// Handles a request to change the users password
|
||||
func HandleChangePassword(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
newPass, err := validateChangePassword(conn, r)
|
||||
if err != nil {
|
||||
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.SetPassword(conn, newPass)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Error updating password")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
newPass, err := validateChangePassword(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.SetPassword(ctx, tx, newPass)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating password")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
@@ -16,10 +16,14 @@ import (
|
||||
|
||||
// Validates the username matches a user in the database and the password
|
||||
// is correct. Returns the corresponding user
|
||||
func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
func validateLogin(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*db.User, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
user, err := db.GetUserFromUsername(conn, formUsername)
|
||||
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
||||
}
|
||||
@@ -47,31 +51,38 @@ func checkRememberMe(r *http.Request) bool {
|
||||
func HandleLoginRequest(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateLogin(conn, r)
|
||||
if err != nil {
|
||||
if err.Error() != "Username or password incorrect" {
|
||||
logger.Warn().Caller().Err(err).Msg("Login request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
form.LoginForm(err.Error()).Render(r.Context(), w)
|
||||
}
|
||||
return
|
||||
}
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateLogin(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username or password incorrect" {
|
||||
logger.Warn().Caller().Err(err).Msg("Login request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
form.LoginForm(err.Error()).Render(r.Context(), w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
}
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
return
|
||||
}
|
||||
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
tx.Commit()
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
@@ -14,11 +14,15 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
func validateRegistration(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*db.User, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
formConfirmPassword := r.FormValue("confirm-password")
|
||||
unique, err := db.CheckUsernameUnique(conn, formUsername)
|
||||
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
|
||||
}
|
||||
@@ -31,7 +35,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
if len(formPassword) > 72 {
|
||||
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
||||
}
|
||||
user, err := db.CreateNewUser(conn, formUsername, formPassword)
|
||||
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CreateNewUser")
|
||||
}
|
||||
@@ -42,33 +46,40 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
func HandleRegisterRequest(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateRegistration(conn, r)
|
||||
if err != nil {
|
||||
if err.Error() != "Username is taken" &&
|
||||
err.Error() != "Passwords do not match" &&
|
||||
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
||||
logger.Warn().Caller().Err(err).Msg("Registration request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
form.RegisterForm(err.Error()).Render(r.Context(), w)
|
||||
}
|
||||
return
|
||||
}
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateRegistration(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username is taken" &&
|
||||
err.Error() != "Passwords do not match" &&
|
||||
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
||||
logger.Warn().Caller().Err(err).Msg("Registration request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
form.RegisterForm(err.Error()).Render(r.Context(), w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
}
|
||||
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
9
main.go
9
main.go
@@ -77,11 +77,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "logging.GetLogger")
|
||||
}
|
||||
|
||||
oldconn, err := db.OldConnectToDatabase(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
defer oldconn.Close()
|
||||
conn, err := db.ConnectToDatabase(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
@@ -93,7 +88,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
srv := server.NewServer(config, logger, oldconn, conn, &staticFS)
|
||||
srv := server.NewServer(config, logger, conn, &staticFS)
|
||||
httpServer := &http.Server{
|
||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||
Handler: srv,
|
||||
@@ -104,7 +99,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
|
||||
// Runs function for testing in dev if --test flag true
|
||||
if args["test"] == "true" {
|
||||
test(config, logger, oldconn, httpServer)
|
||||
test(config, logger, conn, httpServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -107,7 +107,6 @@ func Authentication(
|
||||
}
|
||||
handlers.WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
tx, err := conn.Begin(ctx)
|
||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -16,13 +17,14 @@ import (
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -38,7 +40,7 @@ func TestAuthenticationMiddleware(t *testing.T) {
|
||||
})
|
||||
|
||||
// Add the middleware and create the server
|
||||
authHandler := Authentication(logger, cfg, conn, testHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, testHandler)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -13,13 +14,14 @@ import (
|
||||
|
||||
func TestPageLoginRequired(t *testing.T) {
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -28,7 +30,7 @@ func TestPageLoginRequired(t *testing.T) {
|
||||
|
||||
// Add the middleware and create the server
|
||||
loginRequiredHandler := RequiresLogin(testHandler)
|
||||
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
@@ -5,21 +5,23 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestActionReauthRequired(t *testing.T) {
|
||||
func TestReauthRequired(t *testing.T) {
|
||||
// Basic setup
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn)
|
||||
defer sconn.Close()
|
||||
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
logger := tests.DebugLogger(t)
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -29,7 +31,7 @@ func TestActionReauthRequired(t *testing.T) {
|
||||
// Add the middleware and create the server
|
||||
reauthRequiredHandler := RequiresFresh(testHandler)
|
||||
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
@@ -18,7 +17,6 @@ func addRoutes(
|
||||
mux *http.ServeMux,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
oldconn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
) {
|
||||
@@ -44,7 +42,7 @@ func addRoutes(
|
||||
handlers.HandleLoginRequest(
|
||||
config,
|
||||
logger,
|
||||
oldconn,
|
||||
conn,
|
||||
)))
|
||||
|
||||
// Register page and handlers
|
||||
@@ -57,7 +55,7 @@ func addRoutes(
|
||||
handlers.HandleRegisterRequest(
|
||||
config,
|
||||
logger,
|
||||
oldconn,
|
||||
conn,
|
||||
)))
|
||||
|
||||
// Logout
|
||||
@@ -87,17 +85,17 @@ func addRoutes(
|
||||
mux.Handle("POST /change-username",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangeUsername(logger, oldconn),
|
||||
handlers.HandleChangeUsername(logger, conn),
|
||||
),
|
||||
))
|
||||
mux.Handle("POST /change-bio",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleChangeBio(logger, oldconn),
|
||||
handlers.HandleChangeBio(logger, conn),
|
||||
))
|
||||
mux.Handle("POST /change-password",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangePassword(logger, oldconn),
|
||||
handlers.HandleChangePassword(logger, conn),
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
func NewServer(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
oldconn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
) http.Handler {
|
||||
@@ -24,7 +22,6 @@ func NewServer(
|
||||
mux,
|
||||
logger,
|
||||
config,
|
||||
oldconn,
|
||||
conn,
|
||||
staticFS,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func test(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
srv *http.Server,
|
||||
) {
|
||||
}
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func findSQLFile(filename string) (string, error) {
|
||||
@@ -33,17 +31,11 @@ func findSQLFile(filename string) (string, error) {
|
||||
}
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
// Make sure to call DeleteTestDB when finished to cleanup
|
||||
func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
|
||||
dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
conn := db.MakeSafe(dbfile)
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Begin")
|
||||
}
|
||||
// Setup the test database
|
||||
schemaPath, err := findSQLFile("schema.sql")
|
||||
if err != nil {
|
||||
@@ -56,9 +48,8 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
|
||||
}
|
||||
schemaSQL := string(sqlBytes)
|
||||
|
||||
_, err = tx.Exec(ctx, schemaSQL)
|
||||
_, err = conn.Exec(schemaSQL)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
// Load the test data
|
||||
@@ -72,24 +63,9 @@ func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
|
||||
}
|
||||
dataSQL := string(sqlBytes)
|
||||
|
||||
_, err = tx.Exec(ctx, dataSQL)
|
||||
_, err = conn.Exec(dataSQL)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
tx.Commit()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Deletes the test database from disk
|
||||
func DeleteTestDB() error {
|
||||
fileName := ".projectreshoot-test-database.db"
|
||||
|
||||
// Attempt to remove the file
|
||||
err := os.Remove(fileName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "os.Remove")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user