Update authentication, reauth, logout to use new transactions
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -18,17 +17,19 @@ type SafeConn struct {
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func MakeSafe(db *sql.DB) *SafeConn {
|
||||
return &SafeConn{db: db}
|
||||
}
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
// Starts a new transaction, waiting up to 10 seconds if the database is locked
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lockAcquired := make(chan struct{})
|
||||
go func() {
|
||||
conn.mux.RLock()
|
||||
@@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error {
|
||||
return conn.db.Close()
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the Turso DB
|
||||
// Returns a database connection handle for the DB
|
||||
func OldConnectToDatabase(dbName string) (*sql.DB, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite3", file)
|
||||
@@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
conn := &SafeConn{db: db}
|
||||
conn := MakeSafe(db)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
@@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
@@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
column string,
|
||||
value interface{},
|
||||
) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
created_at,
|
||||
bio
|
||||
FROM users
|
||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||
column,
|
||||
)
|
||||
rows, err := tx.Query(ctx, query, value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Scan the next row into the provided user pointer. Calls rows.Next() and
|
||||
// assumes only row in the result. Providing a rows object with more than 1
|
||||
// row may result in undefined behaviour.
|
||||
@@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
||||
// Queries the database for a user matching the given username.
|
||||
// Query is case insensitive
|
||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "username", username)
|
||||
rows, err := oldfetchUserData(conn, "username", username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
}
|
||||
|
||||
// Queries the database for a user matching the given ID.
|
||||
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "id", id)
|
||||
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "id", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
|
||||
@@ -1,41 +1,79 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func revokeAccess(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
atStr string,
|
||||
) error {
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "Token is expired") ||
|
||||
strings.Contains(err.Error(), "Token has been revoked") {
|
||||
return nil // Token is expired, dont need to revoke it
|
||||
}
|
||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
err = jwt.RevokeToken(ctx, tx, aT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func revokeRefresh(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
rtStr string,
|
||||
) error {
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "Token is expired") ||
|
||||
strings.Contains(err.Error(), "Token has been revoked") {
|
||||
return nil // Token is expired, dont need to revoke it
|
||||
}
|
||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
err = jwt.RevokeToken(ctx, tx, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve and revoke the user's tokens
|
||||
func revokeTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) error {
|
||||
// get the tokens from the cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
// revoke the refresh token first as the access token expires quicker
|
||||
// only matters if there is an error revoking the tokens
|
||||
err = jwt.RevokeToken(conn, rT)
|
||||
err := revokeRefresh(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
return errors.Wrap(err, "revokeRefresh")
|
||||
}
|
||||
err = jwt.RevokeToken(conn, aT)
|
||||
err = revokeAccess(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
return errors.Wrap(err, "revokeAccess")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -44,19 +82,24 @@ func revokeTokens(
|
||||
func HandleLogout(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
err := revokeTokens(config, conn, r)
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
err := revokeTokens(config, ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
w.Header().Set("HX-Redirect", "/login")
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/jwt"
|
||||
"projectreshoot/view/component/form"
|
||||
|
||||
@@ -17,16 +18,17 @@ import (
|
||||
// Get the tokens from the request
|
||||
func getTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||
// get the existing tokens from the cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
@@ -35,15 +37,16 @@ func getTokens(
|
||||
|
||||
// Revoke the given token pair
|
||||
func revokeTokenPair(
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
aT *jwt.AccessToken,
|
||||
rT *jwt.RefreshToken,
|
||||
) error {
|
||||
err := jwt.RevokeToken(conn, aT)
|
||||
err := jwt.RevokeToken(ctx, tx, aT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
err = jwt.RevokeToken(conn, rT)
|
||||
err = jwt.RevokeToken(ctx, tx, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
@@ -53,11 +56,12 @@ func revokeTokenPair(
|
||||
// Issue new tokens for the user, invalidating the old ones
|
||||
func refreshTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) error {
|
||||
aT, rT, err := getTokens(config, conn, r)
|
||||
aT, rT, err := getTokens(config, ctx, tx, r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getTokens")
|
||||
}
|
||||
@@ -71,7 +75,7 @@ func refreshTokens(
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cookies.SetTokenCookies")
|
||||
}
|
||||
err = revokeTokenPair(conn, aT, rT)
|
||||
err = revokeTokenPair(ctx, tx, aT, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "revokeTokenPair")
|
||||
}
|
||||
@@ -97,23 +101,29 @@ func validatePassword(
|
||||
func HandleReauthenticate(
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
err := validatePassword(r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(445)
|
||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
err = refreshTokens(config, conn, w, r)
|
||||
err = refreshTokens(config, ctx, tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
47
handlers/withtransaction.go
Normal file
47
handlers/withtransaction.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/view/page"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// A helper function to create a transaction with a cancellable context.
|
||||
func WithTransaction(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
logger *zerolog.Logger,
|
||||
conn *db.SafeConn,
|
||||
handler func(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
),
|
||||
) {
|
||||
// Create a cancellable context from the request context
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Warn().Err(err).Msg("Request failed to start a transaction")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
page.Error(
|
||||
"503",
|
||||
http.StatusText(503),
|
||||
"This service is currently unavailable. It could be down for maintenance").
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
|
||||
// Pass the context and transaction to the handler
|
||||
handler(ctx, tx, w, r)
|
||||
}
|
||||
13
jwt/parse.go
13
jwt/parse.go
@@ -1,11 +1,12 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
@@ -17,7 +18,8 @@ import (
|
||||
// has the correct scope.
|
||||
func ParseAccessToken(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -74,7 +76,7 @@ func ParseAccessToken(
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(conn, token)
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
@@ -89,7 +91,8 @@ func ParseAccessToken(
|
||||
// has the correct scope.
|
||||
func ParseRefreshToken(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -141,7 +144,7 @@ func ParseRefreshToken(
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(conn, token)
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func RevokeToken(conn *sql.DB, t Token) error {
|
||||
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := conn.Exec(query, jti, exp)
|
||||
_, err := tx.Exec(ctx, query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
|
||||
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := conn.Query(query, jti)
|
||||
defer rows.Close()
|
||||
rows, err := tx.Query(ctx, query, jti)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Exec")
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
// NOTE: rows.Close()
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -12,7 +12,7 @@ type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
GetUser(conn *sql.DB) (*db.User, error)
|
||||
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
@@ -38,15 +38,15 @@ type RefreshToken struct {
|
||||
Scope string // Should be "refresh"
|
||||
}
|
||||
|
||||
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(conn, a.SUB)
|
||||
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(ctx, tx, a.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(conn, r.SUB)
|
||||
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(ctx, tx, r.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
|
||||
9
main.go
9
main.go
@@ -77,6 +77,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "logging.GetLogger")
|
||||
}
|
||||
|
||||
oldconn, err := db.OldConnectToDatabase(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
defer oldconn.Close()
|
||||
conn, err := db.ConnectToDatabase(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
@@ -88,7 +93,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
srv := server.NewServer(config, logger, conn, &staticFS)
|
||||
srv := server.NewServer(config, logger, oldconn, conn, &staticFS)
|
||||
httpServer := &http.Server{
|
||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||
Handler: srv,
|
||||
@@ -99,7 +104,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
|
||||
// Runs function for testing in dev if --test flag true
|
||||
if args["test"] == "true" {
|
||||
test(config, logger, conn, httpServer)
|
||||
test(config, logger, oldconn, httpServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/handlers"
|
||||
"projectreshoot/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -18,14 +19,15 @@ import (
|
||||
// Attempt to use a valid refresh token to generate a new token pair
|
||||
func refreshAuthTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
req *http.Request,
|
||||
ref *jwt.RefreshToken,
|
||||
) (*db.User, error) {
|
||||
user, err := ref.GetUser(conn)
|
||||
user, err := ref.GetUser(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "rT.GetUser")
|
||||
return nil, errors.Wrap(err, "ref.GetUser")
|
||||
}
|
||||
|
||||
rememberMe := map[string]bool{
|
||||
@@ -39,7 +41,7 @@ func refreshAuthTokens(
|
||||
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
||||
}
|
||||
// New tokens sent, revoke the used refresh token
|
||||
err = jwt.RevokeToken(conn, ref)
|
||||
err = jwt.RevokeToken(ctx, tx, ref)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
@@ -50,22 +52,23 @@ func refreshAuthTokens(
|
||||
// Check the cookies for token strings and attempt to authenticate them
|
||||
func getAuthenticatedUser(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) (*contexts.AuthenticatedUser, error) {
|
||||
// Get token strings from cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
// Attempt to parse the access token
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
// Access token invalid, attempt to parse refresh token
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
// Refresh token valid, attempt to get a new token pair
|
||||
user, err := refreshAuthTokens(config, conn, w, r, rT)
|
||||
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "refreshAuthTokens")
|
||||
}
|
||||
@@ -77,9 +80,9 @@ func getAuthenticatedUser(
|
||||
return &authUser, nil
|
||||
}
|
||||
// Access token valid
|
||||
user, err := aT.GetUser(conn)
|
||||
user, err := aT.GetUser(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "rT.GetUser")
|
||||
return nil, errors.Wrap(err, "aT.GetUser")
|
||||
}
|
||||
authUser := contexts.AuthenticatedUser{
|
||||
User: user,
|
||||
@@ -93,12 +96,21 @@ func getAuthenticatedUser(
|
||||
func Authentication(
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
next http.Handler,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := getAuthenticatedUser(config, conn, w, r)
|
||||
if r.URL.Path == "/static/css/output.css" ||
|
||||
r.URL.Path == "/static/favicon.ico" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
handlers.WithTransaction(w, r, logger, conn,
|
||||
func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
|
||||
tx, err := conn.Begin(ctx)
|
||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
// User auth failed, delete the cookies to avoid repeat requests
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
@@ -106,9 +118,14 @@ func Authentication(
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Err(err).
|
||||
Msg("Failed to authenticate user")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ctx := contexts.SetUser(r.Context(), user)
|
||||
newReq := r.WithContext(ctx)
|
||||
tx.Commit()
|
||||
uctx := contexts.SetUser(r.Context(), user)
|
||||
newReq := r.WithContext(uctx)
|
||||
next.ServeHTTP(w, newReq)
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAuthenticationMiddleware(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
|
||||
@@ -23,9 +23,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
|
||||
// Middleware to add logs to console with details of the request
|
||||
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/static/css/output.css" ||
|
||||
r.URL.Path == "/static/favicon.ico" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
start, err := contexts.GetStartTime(r.Context())
|
||||
if err != nil {
|
||||
// Handle failure here. internal server error maybe
|
||||
// TODO: Handle failure here. internal server error maybe
|
||||
return
|
||||
}
|
||||
wrapped := &wrappedWriter{
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestPageLoginRequired(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestActionReauthRequired(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
conn, err := tests.SetupTestDB(t.Context())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/handlers"
|
||||
"projectreshoot/middleware"
|
||||
"projectreshoot/view/page"
|
||||
@@ -17,7 +18,8 @@ func addRoutes(
|
||||
mux *http.ServeMux,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
oldconn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
) {
|
||||
// Health check
|
||||
@@ -42,7 +44,7 @@ func addRoutes(
|
||||
handlers.HandleLoginRequest(
|
||||
config,
|
||||
logger,
|
||||
conn,
|
||||
oldconn,
|
||||
)))
|
||||
|
||||
// Register page and handlers
|
||||
@@ -55,7 +57,7 @@ func addRoutes(
|
||||
handlers.HandleRegisterRequest(
|
||||
config,
|
||||
logger,
|
||||
conn,
|
||||
oldconn,
|
||||
)))
|
||||
|
||||
// Logout
|
||||
@@ -85,17 +87,17 @@ func addRoutes(
|
||||
mux.Handle("POST /change-username",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangeUsername(logger, conn),
|
||||
handlers.HandleChangeUsername(logger, oldconn),
|
||||
),
|
||||
))
|
||||
mux.Handle("POST /change-bio",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleChangeBio(logger, conn),
|
||||
handlers.HandleChangeBio(logger, oldconn),
|
||||
))
|
||||
mux.Handle("POST /change-password",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangePassword(logger, conn),
|
||||
handlers.HandleChangePassword(logger, oldconn),
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/middleware"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
@@ -14,7 +15,8 @@ import (
|
||||
func NewServer(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
oldconn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
@@ -22,6 +24,7 @@ func NewServer(
|
||||
mux,
|
||||
logger,
|
||||
config,
|
||||
oldconn,
|
||||
conn,
|
||||
staticFS,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -32,11 +34,16 @@ func findSQLFile(filename string) (string, error) {
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
// Make sure to call DeleteTestDB when finished to cleanup
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||
func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
|
||||
dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
conn := db.MakeSafe(dbfile)
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Begin")
|
||||
}
|
||||
// Setup the test database
|
||||
schemaPath, err := findSQLFile("schema.sql")
|
||||
if err != nil {
|
||||
@@ -49,9 +56,10 @@ func SetupTestDB() (*sql.DB, error) {
|
||||
}
|
||||
schemaSQL := string(sqlBytes)
|
||||
|
||||
_, err = conn.Exec(schemaSQL)
|
||||
_, err = tx.Exec(ctx, schemaSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
tx.Rollback()
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
// Load the test data
|
||||
dataPath, err := findSQLFile("testdata.sql")
|
||||
@@ -64,10 +72,12 @@ func SetupTestDB() (*sql.DB, error) {
|
||||
}
|
||||
dataSQL := string(sqlBytes)
|
||||
|
||||
_, err = conn.Exec(dataSQL)
|
||||
_, err = tx.Exec(ctx, dataSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
tx.Rollback()
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
tx.Commit()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user