Update authentication, reauth, logout to use new transactions

This commit is contained in:
2025-02-17 18:58:34 +11:00
parent 417daf0028
commit 2c61cec55c
17 changed files with 306 additions and 121 deletions

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
@@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
}
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
@@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
return rows, nil
}
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(
ctx context.Context,
tx *SafeTX,
column string,
value interface{},
) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
username,
password_hash,
created_at,
bio
FROM users
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column,
)
rows, err := tx.Query(ctx, query, value)
if err != nil {
return nil, errors.Wrap(err, "tx.Query")
}
return rows, nil
}
// Scan the next row into the provided user pointer. Calls rows.Next() and
// assumes only row in the result. Providing a rows object with more than 1
// row may result in undefined behaviour.
@@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
rows, err := fetchUserData(conn, "username", username)
rows, err := oldfetchUserData(conn, "username", username)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}
@@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
}
// Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
rows, err := fetchUserData(conn, "id", id)
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
rows, err := fetchUserData(ctx, tx, "id", id)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}