diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..a330357 --- /dev/null +++ b/db/user.go @@ -0,0 +1,50 @@ +package db + +import ( + "database/sql" + + "github.com/pkg/errors" + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int // Integer ID (index primary key) + Username string // Username (unique) + Password_hash string // Bcrypt password hash + Created_at int64 // Epoch timestamp when the user was added to the database + Bio string // Short byline set by the user +} + +// Uses bcrypt to set the users Password_hash from the given password +func (user *User) SetPassword(conn *sql.DB, password string) error { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return errors.Wrap(err, "bcrypt.GenerateFromPassword") + } + user.Password_hash = string(hashedPassword) + query := `UPDATE users SET password_hash = ? WHERE id = ?` + _, err = conn.Exec(query, user.Password_hash, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} + +// Uses bcrypt to check if the given password matches the users Password_hash +func (user *User) CheckPassword(password string) error { + err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) + if err != nil { + return errors.Wrap(err, "bcrypt.CompareHashAndPassword") + } + return nil +} + +// Change the user's username +func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error { + query := `UPDATE users SET username = ? WHERE id = ?` + _, err := conn.Exec(query, newUsername, user.ID) + if err != nil { + return errors.Wrap(err, "conn.Exec") + } + return nil +} diff --git a/db/user_functions.go b/db/user_functions.go new file mode 100644 index 0000000..3c3623e --- /dev/null +++ b/db/user_functions.go @@ -0,0 +1,108 @@ +package db + +import ( + "database/sql" + "fmt" + + "github.com/pkg/errors" +) + +// Creates a new user in the database and returns a pointer +func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { + query := `INSERT INTO users (username) VALUES (?)` + _, err := conn.Exec(query, username) + if err != nil { + return nil, errors.Wrap(err, "conn.Exec") + } + user, err := GetUserFromUsername(conn, username) + if err != nil { + return nil, errors.Wrap(err, "GetUserFromUsername") + } + err = user.SetPassword(conn, password) + if err != nil { + return nil, errors.Wrap(err, "user.SetPassword") + } + return user, nil +} + +// Fetches data from the users table using "WHERE column = 'value'" +func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { + query := fmt.Sprintf( + `SELECT + id, + username, + password_hash, + created_at, + bio + FROM users + WHERE %s = ? COLLATE NOCASE LIMIT 1`, + column, + ) + rows, err := conn.Query(query, value) + if err != nil { + return nil, errors.Wrap(err, "conn.Query") + } + return rows, nil +} + +// Scan the next row into the provided user pointer. Calls rows.Next() and +// assumes only row in the result. Providing a rows object with more than 1 +// row may result in undefined behaviour. +func scanUserRow(user *User, rows *sql.Rows) error { + for rows.Next() { + err := rows.Scan( + &user.ID, + &user.Username, + &user.Password_hash, + &user.Created_at, + &user.Bio, + ) + if err != nil { + return errors.Wrap(err, "rows.Scan") + } + } + return nil +} + +// Queries the database for a user matching the given username. +// Query is case insensitive +func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { + rows, err := fetchUserData(conn, "username", username) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Queries the database for a user matching the given ID. +func GetUserFromID(conn *sql.DB, id int) (*User, error) { + rows, err := fetchUserData(conn, "id", id) + if err != nil { + return nil, errors.Wrap(err, "fetchUserData") + } + defer rows.Close() + var user User + err = scanUserRow(&user, rows) + if err != nil { + return nil, errors.Wrap(err, "scanUserRow") + } + return &user, nil +} + +// Checks if the given username is unique. Returns true if not taken +func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { + query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` + rows, err := conn.Query(query, username) + if err != nil { + return false, errors.Wrap(err, "conn.Query") + } + defer rows.Close() + taken := rows.Next() + return !taken, nil +} diff --git a/db/users.go b/db/users.go deleted file mode 100644 index 23a207e..0000000 --- a/db/users.go +++ /dev/null @@ -1,125 +0,0 @@ -package db - -import ( - "database/sql" - - "github.com/pkg/errors" - "golang.org/x/crypto/bcrypt" -) - -type User struct { - ID int // Integer ID (index primary key) - Username string // Username (unique) - Password_hash string // Bcrypt password hash - Created_at int64 // Epoch timestamp when the user was added to the database -} - -// Uses bcrypt to set the users Password_hash from the given password -func (user *User) SetPassword(conn *sql.DB, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - user.Password_hash = string(hashedPassword) - query := `UPDATE users SET password_hash = ? WHERE id = ?` - result, err := conn.Exec(query, user.Password_hash, user.ID) - if err != nil { - return errors.Wrap(err, "conn.Exec") - } - ra, err := result.RowsAffected() - if err != nil { - return errors.Wrap(err, "result.RowsAffected") - } - if ra != 1 { - return errors.New("Password was not updated") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users Password_hash -func (user *User) CheckPassword(password string) error { - err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) - if err != nil { - return errors.Wrap(err, "bcrypt.CompareHashAndPassword") - } - return nil -} - -// Creates a new user in the database and returns a pointer -func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) { - query := `INSERT INTO users (username) VALUES (?)` - _, err := conn.Exec(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Exec") - } - user, err := GetUserFromUsername(conn, username) - if err != nil { - return nil, errors.Wrap(err, "GetUserFromUsername") - } - err = user.SetPassword(conn, password) - if err != nil { - return nil, errors.Wrap(err, "user.SetPassword") - } - return user, nil -} - -// Queries the database for a user matching the given username. -// Query is case insensitive -func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE username = ? COLLATE NOCASE` - rows, err := conn.Query(query, username) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Queries the database for a user matching the given ID. -func GetUserFromID(conn *sql.DB, id int) (*User, error) { - query := `SELECT id, username, password_hash, created_at FROM users - WHERE id = ?` - rows, err := conn.Query(query, id) - if err != nil { - return nil, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - var user User - for rows.Next() { - err := rows.Scan( - &user.ID, - &user.Username, - &user.Password_hash, - &user.Created_at, - ) - if err != nil { - return nil, errors.Wrap(err, "rows.Scan") - } - } - return &user, nil -} - -// Checks if the given username is unique. Returns true if not taken -func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) { - query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` - rows, err := conn.Query(query, username) - if err != nil { - return false, errors.Wrap(err, "conn.Query") - } - defer rows.Close() - taken := rows.Next() - return !taken, nil -} diff --git a/schema.sql b/schema.sql index 80f9970..986d312 100644 --- a/schema.sql +++ b/schema.sql @@ -8,7 +8,8 @@ CREATE TABLE IF NOT EXISTS "users" ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, password_hash TEXT DEFAULT "", - created_at INTEGER DEFAULT (unixepoch()) + created_at INTEGER DEFAULT (unixepoch()), + bio TEXT DEFAULT "" ) STRICT; CREATE TRIGGER cleanup_expired_tokens AFTER INSERT ON jwtblacklist