Updated all code to use SafeConn and SafeTX
This commit is contained in:
20
db/user.go
20
db/user.go
@@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -16,16 +16,16 @@ type User struct {
|
||||
}
|
||||
|
||||
// Uses bcrypt to set the users Password_hash from the given password
|
||||
func (user *User) SetPassword(conn *sql.DB, password string) error {
|
||||
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
user.Password_hash = string(hashedPassword)
|
||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||
_, err = conn.Exec(query, user.Password_hash, user.ID)
|
||||
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error {
|
||||
}
|
||||
|
||||
// Change the user's username
|
||||
func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error {
|
||||
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
|
||||
query := `UPDATE users SET username = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newUsername, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newUsername, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's bio
|
||||
func (user *User) ChangeBio(conn *sql.DB, newBio string) error {
|
||||
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
|
||||
query := `UPDATE users SET bio = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newBio, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newBio, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,43 +9,28 @@ import (
|
||||
)
|
||||
|
||||
// Creates a new user in the database and returns a pointer
|
||||
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
|
||||
func CreateNewUser(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
username string,
|
||||
password string,
|
||||
) (*User, error) {
|
||||
query := `INSERT INTO users (username) VALUES (?)`
|
||||
_, err := conn.Exec(query, username)
|
||||
_, err := tx.Exec(ctx, query, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
user, err := GetUserFromUsername(conn, username)
|
||||
user, err := GetUserFromUsername(ctx, tx, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "GetUserFromUsername")
|
||||
}
|
||||
err = user.SetPassword(conn, password)
|
||||
err = user.SetPassword(ctx, tx, password)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.SetPassword")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
created_at,
|
||||
bio
|
||||
FROM users
|
||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||
column,
|
||||
)
|
||||
rows, err := conn.Query(query, value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Query")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(
|
||||
ctx context.Context,
|
||||
@@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
||||
|
||||
// Queries the database for a user matching the given username.
|
||||
// Query is case insensitive
|
||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
rows, err := oldfetchUserData(conn, "username", username)
|
||||
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "username", username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||
}
|
||||
|
||||
// Checks if the given username is unique. Returns true if not taken
|
||||
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
|
||||
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
|
||||
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
||||
rows, err := conn.Query(query, username)
|
||||
rows, err := tx.Query(ctx, query, username)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Query")
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
taken := rows.Next()
|
||||
|
||||
Reference in New Issue
Block a user