118 lines
2.7 KiB
Go
118 lines
2.7 KiB
Go
package models
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"git.haelnorr.com/h/golib/hwsauth"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Creates a new user in the database and returns a pointer
|
|
func CreateNewUser(
|
|
tx *sql.Tx,
|
|
username string,
|
|
password string,
|
|
) (*User, error) {
|
|
query := `INSERT INTO users (username) VALUES (?)`
|
|
_, err := tx.Exec(query, username)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "tx.Exec")
|
|
}
|
|
user, err := GetUserFromUsername(tx, username)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "GetUserFromUsername")
|
|
}
|
|
err = user.SetPassword(tx, password)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "user.SetPassword")
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// Fetches data from the users table using "WHERE column = 'value'"
|
|
func fetchUserData(
|
|
tx interface {
|
|
Query(query string, args ...any) (*sql.Rows, error)
|
|
},
|
|
column string,
|
|
value any,
|
|
) (*sql.Rows, error) {
|
|
query := fmt.Sprintf(
|
|
`SELECT
|
|
id,
|
|
username,
|
|
created_at,
|
|
bio
|
|
FROM users
|
|
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
|
column,
|
|
)
|
|
rows, err := tx.Query(query, value)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "tx.Query")
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
// Calls rows.Next() and scans the row into the provided user pointer.
|
|
// Will error if no row available
|
|
func scanUserRow(user *User, rows *sql.Rows) error {
|
|
if !rows.Next() {
|
|
return errors.New("User not found")
|
|
}
|
|
err := rows.Scan(
|
|
&user.id,
|
|
&user.Username,
|
|
&user.Created_at,
|
|
&user.Bio,
|
|
)
|
|
if err != nil {
|
|
return errors.Wrap(err, "rows.Scan")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Queries the database for a user matching the given username.
|
|
// Query is case insensitive
|
|
func GetUserFromUsername(tx *sql.Tx, username string) (*User, error) {
|
|
rows, err := fetchUserData(tx, "username", username)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "fetchUserData")
|
|
}
|
|
defer rows.Close()
|
|
var user User
|
|
err = scanUserRow(&user, rows)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "scanUserRow")
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// Queries the database for a user matching the given ID.
|
|
func GetUserFromID(tx hwsauth.DBTransaction, id int) (*User, error) {
|
|
rows, err := fetchUserData(tx, "id", id)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "fetchUserData")
|
|
}
|
|
defer rows.Close()
|
|
var user User
|
|
err = scanUserRow(&user, rows)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "scanUserRow")
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// Checks if the given username is unique. Returns true if not taken
|
|
func CheckUsernameUnique(tx *sql.Tx, username string) (bool, error) {
|
|
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
|
rows, err := tx.Query(query, username)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "tx.Query")
|
|
}
|
|
defer rows.Close()
|
|
taken := rows.Next()
|
|
return !taken, nil
|
|
}
|