Update authentication, reauth, logout to use new transactions
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -18,17 +17,19 @@ type SafeConn struct {
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func MakeSafe(db *sql.DB) *SafeConn {
|
||||
return &SafeConn{db: db}
|
||||
}
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
// Starts a new transaction, waiting up to 10 seconds if the database is locked
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lockAcquired := make(chan struct{})
|
||||
go func() {
|
||||
conn.mux.RLock()
|
||||
@@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error {
|
||||
return conn.db.Close()
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the Turso DB
|
||||
// Returns a database connection handle for the DB
|
||||
func OldConnectToDatabase(dbName string) (*sql.DB, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite3", file)
|
||||
@@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
conn := &SafeConn{db: db}
|
||||
conn := MakeSafe(db)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
@@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
@@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
column string,
|
||||
value interface{},
|
||||
) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
created_at,
|
||||
bio
|
||||
FROM users
|
||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||
column,
|
||||
)
|
||||
rows, err := tx.Query(ctx, query, value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Scan the next row into the provided user pointer. Calls rows.Next() and
|
||||
// assumes only row in the result. Providing a rows object with more than 1
|
||||
// row may result in undefined behaviour.
|
||||
@@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
||||
// Queries the database for a user matching the given username.
|
||||
// Query is case insensitive
|
||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "username", username)
|
||||
rows, err := oldfetchUserData(conn, "username", username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
}
|
||||
|
||||
// Queries the database for a user matching the given ID.
|
||||
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "id", id)
|
||||
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "id", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user