Updated all code to use SafeConn and SafeTX

This commit is contained in:
2025-02-17 21:39:12 +11:00
parent 6faf168a6d
commit a8d112fdd5
17 changed files with 265 additions and 218 deletions

View File

@@ -9,43 +9,28 @@ import (
)
// Creates a new user in the database and returns a pointer
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
func CreateNewUser(
ctx context.Context,
tx *SafeTX,
username string,
password string,
) (*User, error) {
query := `INSERT INTO users (username) VALUES (?)`
_, err := conn.Exec(query, username)
_, err := tx.Exec(ctx, query, username)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
return nil, errors.Wrap(err, "tx.Exec")
}
user, err := GetUserFromUsername(conn, username)
user, err := GetUserFromUsername(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "GetUserFromUsername")
}
err = user.SetPassword(conn, password)
err = user.SetPassword(ctx, tx, password)
if err != nil {
return nil, errors.Wrap(err, "user.SetPassword")
}
return user, nil
}
// Fetches data from the users table using "WHERE column = 'value'"
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
username,
password_hash,
created_at,
bio
FROM users
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column,
)
rows, err := conn.Query(query, value)
if err != nil {
return nil, errors.Wrap(err, "conn.Query")
}
return rows, nil
}
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(
ctx context.Context,
@@ -92,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
rows, err := oldfetchUserData(conn, "username", username)
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
rows, err := fetchUserData(ctx, tx, "username", username)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}
@@ -122,11 +107,11 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
}
// Checks if the given username is unique. Returns true if not taken
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
rows, err := conn.Query(query, username)
rows, err := tx.Query(ctx, query, username)
if err != nil {
return false, errors.Wrap(err, "conn.Query")
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
taken := rows.Next()