Update authentication, reauth, logout to use new transactions
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
@@ -18,17 +17,19 @@ type SafeConn struct {
|
|||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MakeSafe(db *sql.DB) *SafeConn {
|
||||||
|
return &SafeConn{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
// Extends sql.Tx for use with SafeConn
|
// Extends sql.Tx for use with SafeConn
|
||||||
type SafeTX struct {
|
type SafeTX struct {
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
sc *SafeConn
|
sc *SafeConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts a new transaction, waiting up to 10 seconds if the database is locked
|
// Starts a new transaction based on the current context. Will cancel if
|
||||||
|
// the context is closed/cancelled/done
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
lockAcquired := make(chan struct{})
|
lockAcquired := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
conn.mux.RLock()
|
conn.mux.RLock()
|
||||||
@@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error {
|
|||||||
return conn.db.Close()
|
return conn.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a database connection handle for the Turso DB
|
// Returns a database connection handle for the DB
|
||||||
|
func OldConnectToDatabase(dbName string) (*sql.DB, error) {
|
||||||
|
file := fmt.Sprintf("file:%s.db", dbName)
|
||||||
|
db, err := sql.Open("sqlite3", file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a database connection handle for the DB
|
||||||
func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||||
file := fmt.Sprintf("file:%s.db", dbName)
|
file := fmt.Sprintf("file:%s.db", dbName)
|
||||||
db, err := sql.Open("sqlite3", file)
|
db, err := sql.Open("sqlite3", file)
|
||||||
@@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
|||||||
return nil, errors.Wrap(err, "sql.Open")
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &SafeConn{db: db}
|
conn := MakeSafe(db)
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetches data from the users table using "WHERE column = 'value'"
|
// Fetches data from the users table using "WHERE column = 'value'"
|
||||||
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
`SELECT
|
`SELECT
|
||||||
id,
|
id,
|
||||||
@@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
|
|||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetches data from the users table using "WHERE column = 'value'"
|
||||||
|
func fetchUserData(
|
||||||
|
ctx context.Context,
|
||||||
|
tx *SafeTX,
|
||||||
|
column string,
|
||||||
|
value interface{},
|
||||||
|
) (*sql.Rows, error) {
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
`SELECT
|
||||||
|
id,
|
||||||
|
username,
|
||||||
|
password_hash,
|
||||||
|
created_at,
|
||||||
|
bio
|
||||||
|
FROM users
|
||||||
|
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||||
|
column,
|
||||||
|
)
|
||||||
|
rows, err := tx.Query(ctx, query, value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tx.Query")
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Scan the next row into the provided user pointer. Calls rows.Next() and
|
// Scan the next row into the provided user pointer. Calls rows.Next() and
|
||||||
// assumes only row in the result. Providing a rows object with more than 1
|
// assumes only row in the result. Providing a rows object with more than 1
|
||||||
// row may result in undefined behaviour.
|
// row may result in undefined behaviour.
|
||||||
@@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
|||||||
// Queries the database for a user matching the given username.
|
// Queries the database for a user matching the given username.
|
||||||
// Query is case insensitive
|
// Query is case insensitive
|
||||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||||
rows, err := fetchUserData(conn, "username", username)
|
rows, err := oldfetchUserData(conn, "username", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
@@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Queries the database for a user matching the given ID.
|
// Queries the database for a user matching the given ID.
|
||||||
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
|
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||||
rows, err := fetchUserData(conn, "id", id)
|
rows, err := fetchUserData(ctx, tx, "id", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,41 +1,79 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/cookies"
|
||||||
|
"projectreshoot/db"
|
||||||
"projectreshoot/jwt"
|
"projectreshoot/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func revokeAccess(
|
||||||
|
config *config.Config,
|
||||||
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
|
atStr string,
|
||||||
|
) error {
|
||||||
|
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "Token is expired") ||
|
||||||
|
strings.Contains(err.Error(), "Token has been revoked") {
|
||||||
|
return nil // Token is expired, dont need to revoke it
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "jwt.ParseAccessToken")
|
||||||
|
}
|
||||||
|
err = jwt.RevokeToken(ctx, tx, aT)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.RevokeToken")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func revokeRefresh(
|
||||||
|
config *config.Config,
|
||||||
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
|
rtStr string,
|
||||||
|
) error {
|
||||||
|
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "Token is expired") ||
|
||||||
|
strings.Contains(err.Error(), "Token has been revoked") {
|
||||||
|
return nil // Token is expired, dont need to revoke it
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||||
|
}
|
||||||
|
err = jwt.RevokeToken(ctx, tx, rT)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "jwt.RevokeToken")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Retrieve and revoke the user's tokens
|
// Retrieve and revoke the user's tokens
|
||||||
func revokeTokens(
|
func revokeTokens(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) error {
|
) error {
|
||||||
// get the tokens from the cookies
|
// get the tokens from the cookies
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
|
||||||
}
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
// revoke the refresh token first as the access token expires quicker
|
// revoke the refresh token first as the access token expires quicker
|
||||||
// only matters if there is an error revoking the tokens
|
// only matters if there is an error revoking the tokens
|
||||||
err = jwt.RevokeToken(conn, rT)
|
err := revokeRefresh(config, ctx, tx, rtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
return errors.Wrap(err, "revokeRefresh")
|
||||||
}
|
}
|
||||||
err = jwt.RevokeToken(conn, aT)
|
err = revokeAccess(config, ctx, tx, atStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
return errors.Wrap(err, "revokeAccess")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -44,19 +82,24 @@ func revokeTokens(
|
|||||||
func HandleLogout(
|
func HandleLogout(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
conn *sql.DB,
|
conn *db.SafeConn,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
err := revokeTokens(config, conn, r)
|
WithTransaction(w, r, logger, conn,
|
||||||
if err != nil {
|
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
err := revokeTokens(config, ctx, tx, r)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
if err != nil {
|
||||||
return
|
tx.Rollback()
|
||||||
}
|
logger.Error().Err(err).Msg("Error occured on user logout")
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
return
|
||||||
w.Header().Set("HX-Redirect", "/login")
|
}
|
||||||
|
tx.Commit()
|
||||||
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
|
w.Header().Set("HX-Redirect", "/login")
|
||||||
|
})
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
"projectreshoot/contexts"
|
"projectreshoot/contexts"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/cookies"
|
||||||
|
"projectreshoot/db"
|
||||||
"projectreshoot/jwt"
|
"projectreshoot/jwt"
|
||||||
"projectreshoot/view/component/form"
|
"projectreshoot/view/component/form"
|
||||||
|
|
||||||
@@ -17,16 +18,17 @@ import (
|
|||||||
// Get the tokens from the request
|
// Get the tokens from the request
|
||||||
func getTokens(
|
func getTokens(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||||
// get the existing tokens from the cookies
|
// get the existing tokens from the cookies
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
||||||
}
|
}
|
||||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||||
}
|
}
|
||||||
@@ -35,15 +37,16 @@ func getTokens(
|
|||||||
|
|
||||||
// Revoke the given token pair
|
// Revoke the given token pair
|
||||||
func revokeTokenPair(
|
func revokeTokenPair(
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
aT *jwt.AccessToken,
|
aT *jwt.AccessToken,
|
||||||
rT *jwt.RefreshToken,
|
rT *jwt.RefreshToken,
|
||||||
) error {
|
) error {
|
||||||
err := jwt.RevokeToken(conn, aT)
|
err := jwt.RevokeToken(ctx, tx, aT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
return errors.Wrap(err, "jwt.RevokeToken")
|
||||||
}
|
}
|
||||||
err = jwt.RevokeToken(conn, rT)
|
err = jwt.RevokeToken(ctx, tx, rT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
return errors.Wrap(err, "jwt.RevokeToken")
|
||||||
}
|
}
|
||||||
@@ -53,11 +56,12 @@ func revokeTokenPair(
|
|||||||
// Issue new tokens for the user, invalidating the old ones
|
// Issue new tokens for the user, invalidating the old ones
|
||||||
func refreshTokens(
|
func refreshTokens(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) error {
|
) error {
|
||||||
aT, rT, err := getTokens(config, conn, r)
|
aT, rT, err := getTokens(config, ctx, tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "getTokens")
|
return errors.Wrap(err, "getTokens")
|
||||||
}
|
}
|
||||||
@@ -71,7 +75,7 @@ func refreshTokens(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cookies.SetTokenCookies")
|
return errors.Wrap(err, "cookies.SetTokenCookies")
|
||||||
}
|
}
|
||||||
err = revokeTokenPair(conn, aT, rT)
|
err = revokeTokenPair(ctx, tx, aT, rT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "revokeTokenPair")
|
return errors.Wrap(err, "revokeTokenPair")
|
||||||
}
|
}
|
||||||
@@ -97,23 +101,29 @@ func validatePassword(
|
|||||||
func HandleReauthenticate(
|
func HandleReauthenticate(
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
conn *db.SafeConn,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
err := validatePassword(r)
|
WithTransaction(w, r, logger, conn,
|
||||||
if err != nil {
|
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(445)
|
err := validatePassword(r)
|
||||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
if err != nil {
|
||||||
return
|
tx.Rollback()
|
||||||
}
|
w.WriteHeader(445)
|
||||||
err = refreshTokens(config, conn, w, r)
|
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||||
if err != nil {
|
return
|
||||||
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
}
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
err = refreshTokens(config, ctx, tx, w, r)
|
||||||
return
|
if err != nil {
|
||||||
}
|
tx.Rollback()
|
||||||
w.WriteHeader(http.StatusOK)
|
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
47
handlers/withtransaction.go
Normal file
47
handlers/withtransaction.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"projectreshoot/db"
|
||||||
|
"projectreshoot/view/page"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A helper function to create a transaction with a cancellable context.
|
||||||
|
func WithTransaction(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
conn *db.SafeConn,
|
||||||
|
handler func(
|
||||||
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
),
|
||||||
|
) {
|
||||||
|
// Create a cancellable context from the request context
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := conn.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
page.Error(
|
||||||
|
"503",
|
||||||
|
http.StatusText(503),
|
||||||
|
"This service is currently unavailable. It could be down for maintenance").
|
||||||
|
Render(r.Context(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the context and transaction to the handler
|
||||||
|
handler(ctx, tx, w, r)
|
||||||
|
}
|
||||||
13
jwt/parse.go
13
jwt/parse.go
@@ -1,11 +1,12 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
|
"projectreshoot/db"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -17,7 +18,8 @@ import (
|
|||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func ParseAccessToken(
|
func ParseAccessToken(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*AccessToken, error) {
|
) (*AccessToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -74,7 +76,7 @@ func ParseAccessToken(
|
|||||||
Scope: scope,
|
Scope: scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(conn, token)
|
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||||
}
|
}
|
||||||
@@ -89,7 +91,8 @@ func ParseAccessToken(
|
|||||||
// has the correct scope.
|
// has the correct scope.
|
||||||
func ParseRefreshToken(
|
func ParseRefreshToken(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
tokenString string,
|
tokenString string,
|
||||||
) (*RefreshToken, error) {
|
) (*RefreshToken, error) {
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
@@ -141,7 +144,7 @@ func ParseRefreshToken(
|
|||||||
Scope: scope,
|
Scope: scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(conn, token)
|
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,34 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
|
"projectreshoot/db"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Revoke a token by adding it to the database
|
// Revoke a token by adding it to the database
|
||||||
func RevokeToken(conn *sql.DB, t Token) error {
|
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
|
||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
exp := t.GetEXP()
|
exp := t.GetEXP()
|
||||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||||
_, err := conn.Exec(query, jti, exp)
|
_, err := tx.Exec(ctx, query, jti, exp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "conn.Exec")
|
return errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a token has been revoked. Returns true if not revoked.
|
// Check if a token has been revoked. Returns true if not revoked.
|
||||||
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
|
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
|
||||||
jti := t.GetJTI()
|
jti := t.GetJTI()
|
||||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||||
rows, err := conn.Query(query, jti)
|
rows, err := tx.Query(ctx, query, jti)
|
||||||
defer rows.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "conn.Exec")
|
return false, errors.Wrap(err, "tx.Query")
|
||||||
}
|
}
|
||||||
|
// NOTE: rows.Close()
|
||||||
|
defer rows.Close()
|
||||||
revoked := rows.Next()
|
revoked := rows.Next()
|
||||||
return !revoked, nil
|
return !revoked, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"projectreshoot/db"
|
"projectreshoot/db"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -12,7 +12,7 @@ type Token interface {
|
|||||||
GetJTI() uuid.UUID
|
GetJTI() uuid.UUID
|
||||||
GetEXP() int64
|
GetEXP() int64
|
||||||
GetScope() string
|
GetScope() string
|
||||||
GetUser(conn *sql.DB) (*db.User, error)
|
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Access token
|
// Access token
|
||||||
@@ -38,15 +38,15 @@ type RefreshToken struct {
|
|||||||
Scope string // Should be "refresh"
|
Scope string // Should be "refresh"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
|
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||||
user, err := db.GetUserFromID(conn, a.SUB)
|
user, err := db.GetUserFromID(ctx, tx, a.SUB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
|
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||||
user, err := db.GetUserFromID(conn, r.SUB)
|
user, err := db.GetUserFromID(ctx, tx, r.SUB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||||
}
|
}
|
||||||
|
|||||||
9
main.go
9
main.go
@@ -77,6 +77,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
return errors.Wrap(err, "logging.GetLogger")
|
return errors.Wrap(err, "logging.GetLogger")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldconn, err := db.OldConnectToDatabase(config.DBName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||||
|
}
|
||||||
|
defer oldconn.Close()
|
||||||
conn, err := db.ConnectToDatabase(config.DBName)
|
conn, err := db.ConnectToDatabase(config.DBName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||||
@@ -88,7 +93,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
return errors.Wrap(err, "getStaticFiles")
|
return errors.Wrap(err, "getStaticFiles")
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := server.NewServer(config, logger, conn, &staticFS)
|
srv := server.NewServer(config, logger, oldconn, conn, &staticFS)
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
@@ -99,7 +104,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
|
|
||||||
// Runs function for testing in dev if --test flag true
|
// Runs function for testing in dev if --test flag true
|
||||||
if args["test"] == "true" {
|
if args["test"] == "true" {
|
||||||
test(config, logger, conn, httpServer)
|
test(config, logger, oldconn, httpServer)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"projectreshoot/contexts"
|
"projectreshoot/contexts"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/cookies"
|
||||||
"projectreshoot/db"
|
"projectreshoot/db"
|
||||||
|
"projectreshoot/handlers"
|
||||||
"projectreshoot/jwt"
|
"projectreshoot/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -18,14 +19,15 @@ import (
|
|||||||
// Attempt to use a valid refresh token to generate a new token pair
|
// Attempt to use a valid refresh token to generate a new token pair
|
||||||
func refreshAuthTokens(
|
func refreshAuthTokens(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
ref *jwt.RefreshToken,
|
ref *jwt.RefreshToken,
|
||||||
) (*db.User, error) {
|
) (*db.User, error) {
|
||||||
user, err := ref.GetUser(conn)
|
user, err := ref.GetUser(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "rT.GetUser")
|
return nil, errors.Wrap(err, "ref.GetUser")
|
||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := map[string]bool{
|
rememberMe := map[string]bool{
|
||||||
@@ -39,7 +41,7 @@ func refreshAuthTokens(
|
|||||||
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
||||||
}
|
}
|
||||||
// New tokens sent, revoke the used refresh token
|
// New tokens sent, revoke the used refresh token
|
||||||
err = jwt.RevokeToken(conn, ref)
|
err = jwt.RevokeToken(ctx, tx, ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
||||||
}
|
}
|
||||||
@@ -50,22 +52,23 @@ func refreshAuthTokens(
|
|||||||
// Check the cookies for token strings and attempt to authenticate them
|
// Check the cookies for token strings and attempt to authenticate them
|
||||||
func getAuthenticatedUser(
|
func getAuthenticatedUser(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
ctx context.Context,
|
||||||
|
tx *db.SafeTX,
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*contexts.AuthenticatedUser, error) {
|
) (*contexts.AuthenticatedUser, error) {
|
||||||
// Get token strings from cookies
|
// Get token strings from cookies
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||||
// Attempt to parse the access token
|
// Attempt to parse the access token
|
||||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Access token invalid, attempt to parse refresh token
|
// Access token invalid, attempt to parse refresh token
|
||||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||||
}
|
}
|
||||||
// Refresh token valid, attempt to get a new token pair
|
// Refresh token valid, attempt to get a new token pair
|
||||||
user, err := refreshAuthTokens(config, conn, w, r, rT)
|
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "refreshAuthTokens")
|
return nil, errors.Wrap(err, "refreshAuthTokens")
|
||||||
}
|
}
|
||||||
@@ -77,9 +80,9 @@ func getAuthenticatedUser(
|
|||||||
return &authUser, nil
|
return &authUser, nil
|
||||||
}
|
}
|
||||||
// Access token valid
|
// Access token valid
|
||||||
user, err := aT.GetUser(conn)
|
user, err := aT.GetUser(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "rT.GetUser")
|
return nil, errors.Wrap(err, "aT.GetUser")
|
||||||
}
|
}
|
||||||
authUser := contexts.AuthenticatedUser{
|
authUser := contexts.AuthenticatedUser{
|
||||||
User: user,
|
User: user,
|
||||||
@@ -93,22 +96,36 @@ func getAuthenticatedUser(
|
|||||||
func Authentication(
|
func Authentication(
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
conn *db.SafeConn,
|
||||||
next http.Handler,
|
next http.Handler,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user, err := getAuthenticatedUser(config, conn, w, r)
|
if r.URL.Path == "/static/css/output.css" ||
|
||||||
if err != nil {
|
r.URL.Path == "/static/favicon.ico" {
|
||||||
// User auth failed, delete the cookies to avoid repeat requests
|
next.ServeHTTP(w, r)
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
return
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
|
||||||
logger.Debug().
|
|
||||||
Str("remote_addr", r.RemoteAddr).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to authenticate user")
|
|
||||||
}
|
}
|
||||||
ctx := contexts.SetUser(r.Context(), user)
|
handlers.WithTransaction(w, r, logger, conn,
|
||||||
newReq := r.WithContext(ctx)
|
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||||
next.ServeHTTP(w, newReq)
|
tx, err := conn.Begin(ctx)
|
||||||
|
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
// User auth failed, delete the cookies to avoid repeat requests
|
||||||
|
cookies.DeleteCookie(w, "access", "/")
|
||||||
|
cookies.DeleteCookie(w, "refresh", "/")
|
||||||
|
logger.Debug().
|
||||||
|
Str("remote_addr", r.RemoteAddr).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to authenticate user")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
uctx := contexts.SetUser(r.Context(), user)
|
||||||
|
newReq := r.WithContext(uctx)
|
||||||
|
next.ServeHTTP(w, newReq)
|
||||||
|
},
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func TestAuthenticationMiddleware(t *testing.T) {
|
|||||||
cfg, err := tests.TestConfig()
|
cfg, err := tests.TestConfig()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
conn, err := tests.SetupTestDB()
|
conn, err := tests.SetupTestDB(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
defer tests.DeleteTestDB()
|
defer tests.DeleteTestDB()
|
||||||
|
|||||||
@@ -23,9 +23,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
|
|||||||
// Middleware to add logs to console with details of the request
|
// Middleware to add logs to console with details of the request
|
||||||
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/static/css/output.css" ||
|
||||||
|
r.URL.Path == "/static/favicon.ico" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
start, err := contexts.GetStartTime(r.Context())
|
start, err := contexts.GetStartTime(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle failure here. internal server error maybe
|
// TODO: Handle failure here. internal server error maybe
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wrapped := &wrappedWriter{
|
wrapped := &wrappedWriter{
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func TestPageLoginRequired(t *testing.T) {
|
|||||||
cfg, err := tests.TestConfig()
|
cfg, err := tests.TestConfig()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
conn, err := tests.SetupTestDB()
|
conn, err := tests.SetupTestDB(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
defer tests.DeleteTestDB()
|
defer tests.DeleteTestDB()
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func TestActionReauthRequired(t *testing.T) {
|
|||||||
cfg, err := tests.TestConfig()
|
cfg, err := tests.TestConfig()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
logger := tests.NilLogger()
|
logger := tests.NilLogger()
|
||||||
conn, err := tests.SetupTestDB()
|
conn, err := tests.SetupTestDB(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
defer tests.DeleteTestDB()
|
defer tests.DeleteTestDB()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
|
"projectreshoot/db"
|
||||||
"projectreshoot/handlers"
|
"projectreshoot/handlers"
|
||||||
"projectreshoot/middleware"
|
"projectreshoot/middleware"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/view/page"
|
||||||
@@ -17,7 +18,8 @@ func addRoutes(
|
|||||||
mux *http.ServeMux,
|
mux *http.ServeMux,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
conn *sql.DB,
|
oldconn *sql.DB,
|
||||||
|
conn *db.SafeConn,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
) {
|
) {
|
||||||
// Health check
|
// Health check
|
||||||
@@ -42,7 +44,7 @@ func addRoutes(
|
|||||||
handlers.HandleLoginRequest(
|
handlers.HandleLoginRequest(
|
||||||
config,
|
config,
|
||||||
logger,
|
logger,
|
||||||
conn,
|
oldconn,
|
||||||
)))
|
)))
|
||||||
|
|
||||||
// Register page and handlers
|
// Register page and handlers
|
||||||
@@ -55,7 +57,7 @@ func addRoutes(
|
|||||||
handlers.HandleRegisterRequest(
|
handlers.HandleRegisterRequest(
|
||||||
config,
|
config,
|
||||||
logger,
|
logger,
|
||||||
conn,
|
oldconn,
|
||||||
)))
|
)))
|
||||||
|
|
||||||
// Logout
|
// Logout
|
||||||
@@ -85,17 +87,17 @@ func addRoutes(
|
|||||||
mux.Handle("POST /change-username",
|
mux.Handle("POST /change-username",
|
||||||
middleware.RequiresLogin(
|
middleware.RequiresLogin(
|
||||||
middleware.RequiresFresh(
|
middleware.RequiresFresh(
|
||||||
handlers.HandleChangeUsername(logger, conn),
|
handlers.HandleChangeUsername(logger, oldconn),
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
mux.Handle("POST /change-bio",
|
mux.Handle("POST /change-bio",
|
||||||
middleware.RequiresLogin(
|
middleware.RequiresLogin(
|
||||||
handlers.HandleChangeBio(logger, conn),
|
handlers.HandleChangeBio(logger, oldconn),
|
||||||
))
|
))
|
||||||
mux.Handle("POST /change-password",
|
mux.Handle("POST /change-password",
|
||||||
middleware.RequiresLogin(
|
middleware.RequiresLogin(
|
||||||
middleware.RequiresFresh(
|
middleware.RequiresFresh(
|
||||||
handlers.HandleChangePassword(logger, conn),
|
handlers.HandleChangePassword(logger, oldconn),
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/config"
|
||||||
|
"projectreshoot/db"
|
||||||
"projectreshoot/middleware"
|
"projectreshoot/middleware"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
@@ -14,7 +15,8 @@ import (
|
|||||||
func NewServer(
|
func NewServer(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
conn *sql.DB,
|
oldconn *sql.DB,
|
||||||
|
conn *db.SafeConn,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -22,6 +24,7 @@ func NewServer(
|
|||||||
mux,
|
mux,
|
||||||
logger,
|
logger,
|
||||||
config,
|
config,
|
||||||
|
oldconn,
|
||||||
conn,
|
conn,
|
||||||
staticFS,
|
staticFS,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package tests
|
package tests
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"projectreshoot/db"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
@@ -32,11 +34,16 @@ func findSQLFile(filename string) (string, error) {
|
|||||||
|
|
||||||
// SetupTestDB initializes a test SQLite database with mock data
|
// SetupTestDB initializes a test SQLite database with mock data
|
||||||
// Make sure to call DeleteTestDB when finished to cleanup
|
// Make sure to call DeleteTestDB when finished to cleanup
|
||||||
func SetupTestDB() (*sql.DB, error) {
|
func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
|
||||||
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "sql.Open")
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
}
|
}
|
||||||
|
conn := db.MakeSafe(dbfile)
|
||||||
|
tx, err := conn.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "conn.Begin")
|
||||||
|
}
|
||||||
// Setup the test database
|
// Setup the test database
|
||||||
schemaPath, err := findSQLFile("schema.sql")
|
schemaPath, err := findSQLFile("schema.sql")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,9 +56,10 @@ func SetupTestDB() (*sql.DB, error) {
|
|||||||
}
|
}
|
||||||
schemaSQL := string(sqlBytes)
|
schemaSQL := string(sqlBytes)
|
||||||
|
|
||||||
_, err = conn.Exec(schemaSQL)
|
_, err = tx.Exec(ctx, schemaSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "conn.Exec")
|
tx.Rollback()
|
||||||
|
return nil, errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
// Load the test data
|
// Load the test data
|
||||||
dataPath, err := findSQLFile("testdata.sql")
|
dataPath, err := findSQLFile("testdata.sql")
|
||||||
@@ -64,10 +72,12 @@ func SetupTestDB() (*sql.DB, error) {
|
|||||||
}
|
}
|
||||||
dataSQL := string(sqlBytes)
|
dataSQL := string(sqlBytes)
|
||||||
|
|
||||||
_, err = conn.Exec(dataSQL)
|
_, err = tx.Exec(ctx, dataSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "conn.Exec")
|
tx.Rollback()
|
||||||
|
return nil, errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
|
tx.Commit()
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user