Update authentication, reauth, logout to use new transactions

This commit is contained in:
2025-02-17 18:58:34 +11:00
parent 417daf0028
commit 2c61cec55c
17 changed files with 306 additions and 121 deletions

View File

@@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"sync" "sync"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -18,17 +17,19 @@ type SafeConn struct {
mux sync.RWMutex mux sync.RWMutex
} }
func MakeSafe(db *sql.DB) *SafeConn {
return &SafeConn{db: db}
}
// Extends sql.Tx for use with SafeConn // Extends sql.Tx for use with SafeConn
type SafeTX struct { type SafeTX struct {
tx *sql.Tx tx *sql.Tx
sc *SafeConn sc *SafeConn
} }
// Starts a new transaction, waiting up to 10 seconds if the database is locked // Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) { func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
lockAcquired := make(chan struct{}) lockAcquired := make(chan struct{})
go func() { go func() {
conn.mux.RLock() conn.mux.RLock()
@@ -119,7 +120,18 @@ func (conn *SafeConn) Close() error {
return conn.db.Close() return conn.db.Close()
} }
// Returns a database connection handle for the Turso DB // Returns a database connection handle for the DB
func OldConnectToDatabase(dbName string) (*sql.DB, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite3", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
return db, nil
}
// Returns a database connection handle for the DB
func ConnectToDatabase(dbName string) (*SafeConn, error) { func ConnectToDatabase(dbName string) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName) file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite3", file) db, err := sql.Open("sqlite3", file)
@@ -127,7 +139,7 @@ func ConnectToDatabase(dbName string) (*SafeConn, error) {
return nil, errors.Wrap(err, "sql.Open") return nil, errors.Wrap(err, "sql.Open")
} }
conn := &SafeConn{db: db} conn := MakeSafe(db)
return conn, nil return conn, nil
} }

View File

@@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
@@ -26,7 +27,7 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
} }
// Fetches data from the users table using "WHERE column = 'value'" // Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) { func oldfetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
query := fmt.Sprintf( query := fmt.Sprintf(
`SELECT `SELECT
id, id,
@@ -45,6 +46,31 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
return rows, nil return rows, nil
} }
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(
ctx context.Context,
tx *SafeTX,
column string,
value interface{},
) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
username,
password_hash,
created_at,
bio
FROM users
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column,
)
rows, err := tx.Query(ctx, query, value)
if err != nil {
return nil, errors.Wrap(err, "tx.Query")
}
return rows, nil
}
// Scan the next row into the provided user pointer. Calls rows.Next() and // Scan the next row into the provided user pointer. Calls rows.Next() and
// assumes only row in the result. Providing a rows object with more than 1 // assumes only row in the result. Providing a rows object with more than 1
// row may result in undefined behaviour. // row may result in undefined behaviour.
@@ -67,7 +93,7 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username. // Queries the database for a user matching the given username.
// Query is case insensitive // Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) { func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
rows, err := fetchUserData(conn, "username", username) rows, err := oldfetchUserData(conn, "username", username)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetchUserData") return nil, errors.Wrap(err, "fetchUserData")
} }
@@ -81,8 +107,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
} }
// Queries the database for a user matching the given ID. // Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (*User, error) { func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
rows, err := fetchUserData(conn, "id", id) rows, err := fetchUserData(ctx, tx, "id", id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetchUserData") return nil, errors.Wrap(err, "fetchUserData")
} }

View File

@@ -1,41 +1,79 @@
package handlers package handlers
import ( import (
"database/sql" "context"
"net/http" "net/http"
"strings"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/cookies" "projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/jwt" "projectreshoot/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
func revokeAccess(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
atStr string,
) error {
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil {
if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") {
return nil // Token is expired, dont need to revoke it
}
return errors.Wrap(err, "jwt.ParseAccessToken")
}
err = jwt.RevokeToken(ctx, tx, aT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
return nil
}
func revokeRefresh(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
rtStr string,
) error {
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil {
if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") {
return nil // Token is expired, dont need to revoke it
}
return errors.Wrap(err, "jwt.ParseRefreshToken")
}
err = jwt.RevokeToken(ctx, tx, rT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
return nil
}
// Retrieve and revoke the user's tokens // Retrieve and revoke the user's tokens
func revokeTokens( func revokeTokens(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) error { ) error {
// get the tokens from the cookies // get the tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
aT, err := jwt.ParseAccessToken(config, conn, atStr)
if err != nil {
return errors.Wrap(err, "jwt.ParseAccessToken")
}
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
if err != nil {
return errors.Wrap(err, "jwt.ParseRefreshToken")
}
// revoke the refresh token first as the access token expires quicker // revoke the refresh token first as the access token expires quicker
// only matters if there is an error revoking the tokens // only matters if there is an error revoking the tokens
err = jwt.RevokeToken(conn, rT) err := revokeRefresh(config, ctx, tx, rtStr)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "revokeRefresh")
} }
err = jwt.RevokeToken(conn, aT) err = revokeAccess(config, ctx, tx, atStr)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "revokeAccess")
} }
return nil return nil
} }
@@ -44,19 +82,24 @@ func revokeTokens(
func HandleLogout( func HandleLogout(
config *config.Config, config *config.Config,
logger *zerolog.Logger, logger *zerolog.Logger,
conn *sql.DB, conn *db.SafeConn,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
err := revokeTokens(config, conn, r) WithTransaction(w, r, logger, conn,
if err != nil { func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
logger.Error().Err(err).Msg("Error occured on user logout") err := revokeTokens(config, ctx, tx, r)
w.WriteHeader(http.StatusInternalServerError) if err != nil {
return tx.Rollback()
} logger.Error().Err(err).Msg("Error occured on user logout")
cookies.DeleteCookie(w, "access", "/") w.WriteHeader(http.StatusInternalServerError)
cookies.DeleteCookie(w, "refresh", "/") return
w.Header().Set("HX-Redirect", "/login") }
tx.Commit()
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
w.Header().Set("HX-Redirect", "/login")
})
}, },
) )
} }

View File

@@ -1,12 +1,13 @@
package handlers package handlers
import ( import (
"database/sql" "context"
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/contexts" "projectreshoot/contexts"
"projectreshoot/cookies" "projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/jwt" "projectreshoot/jwt"
"projectreshoot/view/component/form" "projectreshoot/view/component/form"
@@ -17,16 +18,17 @@ import (
// Get the tokens from the request // Get the tokens from the request
func getTokens( func getTokens(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) (*jwt.AccessToken, *jwt.RefreshToken, error) { ) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies // get the existing tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
aT, err := jwt.ParseAccessToken(config, conn, atStr) aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
} }
rT, err := jwt.ParseRefreshToken(config, conn, rtStr) rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
} }
@@ -35,15 +37,16 @@ func getTokens(
// Revoke the given token pair // Revoke the given token pair
func revokeTokenPair( func revokeTokenPair(
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
aT *jwt.AccessToken, aT *jwt.AccessToken,
rT *jwt.RefreshToken, rT *jwt.RefreshToken,
) error { ) error {
err := jwt.RevokeToken(conn, aT) err := jwt.RevokeToken(ctx, tx, aT)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "jwt.RevokeToken")
} }
err = jwt.RevokeToken(conn, rT) err = jwt.RevokeToken(ctx, tx, rT)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "jwt.RevokeToken")
} }
@@ -53,11 +56,12 @@ func revokeTokenPair(
// Issue new tokens for the user, invalidating the old ones // Issue new tokens for the user, invalidating the old ones
func refreshTokens( func refreshTokens(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
) error { ) error {
aT, rT, err := getTokens(config, conn, r) aT, rT, err := getTokens(config, ctx, tx, r)
if err != nil { if err != nil {
return errors.Wrap(err, "getTokens") return errors.Wrap(err, "getTokens")
} }
@@ -71,7 +75,7 @@ func refreshTokens(
if err != nil { if err != nil {
return errors.Wrap(err, "cookies.SetTokenCookies") return errors.Wrap(err, "cookies.SetTokenCookies")
} }
err = revokeTokenPair(conn, aT, rT) err = revokeTokenPair(ctx, tx, aT, rT)
if err != nil { if err != nil {
return errors.Wrap(err, "revokeTokenPair") return errors.Wrap(err, "revokeTokenPair")
} }
@@ -97,23 +101,29 @@ func validatePassword(
func HandleReauthenticate( func HandleReauthenticate(
logger *zerolog.Logger, logger *zerolog.Logger,
config *config.Config, config *config.Config,
conn *sql.DB, conn *db.SafeConn,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
err := validatePassword(r) WithTransaction(w, r, logger, conn,
if err != nil { func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(445) err := validatePassword(r)
form.ConfirmPassword("Incorrect password").Render(r.Context(), w) if err != nil {
return tx.Rollback()
} w.WriteHeader(445)
err = refreshTokens(config, conn, w, r) form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
if err != nil { return
logger.Error().Err(err).Msg("Failed to refresh user tokens") }
w.WriteHeader(http.StatusInternalServerError) err = refreshTokens(config, ctx, tx, w, r)
return if err != nil {
} tx.Rollback()
w.WriteHeader(http.StatusOK) logger.Error().Err(err).Msg("Failed to refresh user tokens")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.WriteHeader(http.StatusOK)
})
}, },
) )
} }

View File

@@ -0,0 +1,47 @@
package handlers
import (
"context"
"net/http"
"time"
"projectreshoot/db"
"projectreshoot/view/page"
"github.com/rs/zerolog"
)
// A helper function to create a transaction with a cancellable context.
func WithTransaction(
w http.ResponseWriter,
r *http.Request,
logger *zerolog.Logger,
conn *db.SafeConn,
handler func(
ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter,
r *http.Request,
),
) {
// Create a cancellable context from the request context
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
tx.Rollback()
logger.Warn().Err(err).Msg("Request failed to start a transaction")
w.WriteHeader(http.StatusServiceUnavailable)
page.Error(
"503",
http.StatusText(503),
"This service is currently unavailable. It could be down for maintenance").
Render(r.Context(), w)
return
}
// Pass the context and transaction to the handler
handler(ctx, tx, w, r)
}

View File

@@ -1,11 +1,12 @@
package jwt package jwt
import ( import (
"database/sql" "context"
"fmt" "fmt"
"time" "time"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/db"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/google/uuid" "github.com/google/uuid"
@@ -17,7 +18,8 @@ import (
// has the correct scope. // has the correct scope.
func ParseAccessToken( func ParseAccessToken(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
tokenString string, tokenString string,
) (*AccessToken, error) { ) (*AccessToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -74,7 +76,7 @@ func ParseAccessToken(
Scope: scope, Scope: scope,
} }
valid, err := CheckTokenNotRevoked(conn, token) valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked") return nil, errors.Wrap(err, "CheckTokenNotRevoked")
} }
@@ -89,7 +91,8 @@ func ParseAccessToken(
// has the correct scope. // has the correct scope.
func ParseRefreshToken( func ParseRefreshToken(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
tokenString string, tokenString string,
) (*RefreshToken, error) { ) (*RefreshToken, error) {
if tokenString == "" { if tokenString == "" {
@@ -141,7 +144,7 @@ func ParseRefreshToken(
Scope: scope, Scope: scope,
} }
valid, err := CheckTokenNotRevoked(conn, token) valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked") return nil, errors.Wrap(err, "CheckTokenNotRevoked")
} }

View File

@@ -1,32 +1,34 @@
package jwt package jwt
import ( import (
"database/sql" "context"
"projectreshoot/db"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// Revoke a token by adding it to the database // Revoke a token by adding it to the database
func RevokeToken(conn *sql.DB, t Token) error { func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
jti := t.GetJTI() jti := t.GetJTI()
exp := t.GetEXP() exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)` query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := conn.Exec(query, jti, exp) _, err := tx.Exec(ctx, query, jti, exp)
if err != nil { if err != nil {
return errors.Wrap(err, "conn.Exec") return errors.Wrap(err, "tx.Exec")
} }
return nil return nil
} }
// Check if a token has been revoked. Returns true if not revoked. // Check if a token has been revoked. Returns true if not revoked.
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
jti := t.GetJTI() jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti) rows, err := tx.Query(ctx, query, jti)
defer rows.Close()
if err != nil { if err != nil {
return false, errors.Wrap(err, "conn.Exec") return false, errors.Wrap(err, "tx.Query")
} }
// NOTE: rows.Close()
defer rows.Close()
revoked := rows.Next() revoked := rows.Next()
return !revoked, nil return !revoked, nil
} }

View File

@@ -1,7 +1,7 @@
package jwt package jwt
import ( import (
"database/sql" "context"
"projectreshoot/db" "projectreshoot/db"
"github.com/google/uuid" "github.com/google/uuid"
@@ -12,7 +12,7 @@ type Token interface {
GetJTI() uuid.UUID GetJTI() uuid.UUID
GetEXP() int64 GetEXP() int64
GetScope() string GetScope() string
GetUser(conn *sql.DB) (*db.User, error) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
} }
// Access token // Access token
@@ -38,15 +38,15 @@ type RefreshToken struct {
Scope string // Should be "refresh" Scope string // Should be "refresh"
} }
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) { func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(conn, a.SUB) user, err := db.GetUserFromID(ctx, tx, a.SUB)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID") return nil, errors.Wrap(err, "db.GetUserFromID")
} }
return user, nil return user, nil
} }
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) { func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB) user, err := db.GetUserFromID(ctx, tx, r.SUB)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID") return nil, errors.Wrap(err, "db.GetUserFromID")
} }

View File

@@ -77,6 +77,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "logging.GetLogger") return errors.Wrap(err, "logging.GetLogger")
} }
oldconn, err := db.OldConnectToDatabase(config.DBName)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
defer oldconn.Close()
conn, err := db.ConnectToDatabase(config.DBName) conn, err := db.ConnectToDatabase(config.DBName)
if err != nil { if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase") return errors.Wrap(err, "db.ConnectToDatabase")
@@ -88,7 +93,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "getStaticFiles") return errors.Wrap(err, "getStaticFiles")
} }
srv := server.NewServer(config, logger, conn, &staticFS) srv := server.NewServer(config, logger, oldconn, conn, &staticFS)
httpServer := &http.Server{ httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port), Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv, Handler: srv,
@@ -99,7 +104,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Runs function for testing in dev if --test flag true // Runs function for testing in dev if --test flag true
if args["test"] == "true" { if args["test"] == "true" {
test(config, logger, conn, httpServer) test(config, logger, oldconn, httpServer)
return nil return nil
} }

View File

@@ -1,7 +1,7 @@
package middleware package middleware
import ( import (
"database/sql" "context"
"net/http" "net/http"
"time" "time"
@@ -9,6 +9,7 @@ import (
"projectreshoot/contexts" "projectreshoot/contexts"
"projectreshoot/cookies" "projectreshoot/cookies"
"projectreshoot/db" "projectreshoot/db"
"projectreshoot/handlers"
"projectreshoot/jwt" "projectreshoot/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -18,14 +19,15 @@ import (
// Attempt to use a valid refresh token to generate a new token pair // Attempt to use a valid refresh token to generate a new token pair
func refreshAuthTokens( func refreshAuthTokens(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter, w http.ResponseWriter,
req *http.Request, req *http.Request,
ref *jwt.RefreshToken, ref *jwt.RefreshToken,
) (*db.User, error) { ) (*db.User, error) {
user, err := ref.GetUser(conn) user, err := ref.GetUser(ctx, tx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "rT.GetUser") return nil, errors.Wrap(err, "ref.GetUser")
} }
rememberMe := map[string]bool{ rememberMe := map[string]bool{
@@ -39,7 +41,7 @@ func refreshAuthTokens(
return nil, errors.Wrap(err, "cookies.SetTokenCookies") return nil, errors.Wrap(err, "cookies.SetTokenCookies")
} }
// New tokens sent, revoke the used refresh token // New tokens sent, revoke the used refresh token
err = jwt.RevokeToken(conn, ref) err = jwt.RevokeToken(ctx, tx, ref)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "jwt.RevokeToken") return nil, errors.Wrap(err, "jwt.RevokeToken")
} }
@@ -50,22 +52,23 @@ func refreshAuthTokens(
// Check the cookies for token strings and attempt to authenticate them // Check the cookies for token strings and attempt to authenticate them
func getAuthenticatedUser( func getAuthenticatedUser(
config *config.Config, config *config.Config,
conn *sql.DB, ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
) (*contexts.AuthenticatedUser, error) { ) (*contexts.AuthenticatedUser, error) {
// Get token strings from cookies // Get token strings from cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
// Attempt to parse the access token // Attempt to parse the access token
aT, err := jwt.ParseAccessToken(config, conn, atStr) aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil { if err != nil {
// Access token invalid, attempt to parse refresh token // Access token invalid, attempt to parse refresh token
rT, err := jwt.ParseRefreshToken(config, conn, rtStr) rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "jwt.ParseRefreshToken") return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
} }
// Refresh token valid, attempt to get a new token pair // Refresh token valid, attempt to get a new token pair
user, err := refreshAuthTokens(config, conn, w, r, rT) user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "refreshAuthTokens") return nil, errors.Wrap(err, "refreshAuthTokens")
} }
@@ -77,9 +80,9 @@ func getAuthenticatedUser(
return &authUser, nil return &authUser, nil
} }
// Access token valid // Access token valid
user, err := aT.GetUser(conn) user, err := aT.GetUser(ctx, tx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "rT.GetUser") return nil, errors.Wrap(err, "aT.GetUser")
} }
authUser := contexts.AuthenticatedUser{ authUser := contexts.AuthenticatedUser{
User: user, User: user,
@@ -93,22 +96,36 @@ func getAuthenticatedUser(
func Authentication( func Authentication(
logger *zerolog.Logger, logger *zerolog.Logger,
config *config.Config, config *config.Config,
conn *sql.DB, conn *db.SafeConn,
next http.Handler, next http.Handler,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, err := getAuthenticatedUser(config, conn, w, r) if r.URL.Path == "/static/css/output.css" ||
if err != nil { r.URL.Path == "/static/favicon.ico" {
// User auth failed, delete the cookies to avoid repeat requests next.ServeHTTP(w, r)
cookies.DeleteCookie(w, "access", "/") return
cookies.DeleteCookie(w, "refresh", "/")
logger.Debug().
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
} }
ctx := contexts.SetUser(r.Context(), user) handlers.WithTransaction(w, r, logger, conn,
newReq := r.WithContext(ctx) func(ctx context.Context, tx *db.SafeTX, w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, newReq) tx, err := conn.Begin(ctx)
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
if err != nil {
tx.Rollback()
// User auth failed, delete the cookies to avoid repeat requests
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
logger.Debug().
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
next.ServeHTTP(w, r)
return
}
tx.Commit()
uctx := contexts.SetUser(r.Context(), user)
newReq := r.WithContext(uctx)
next.ServeHTTP(w, newReq)
},
)
}) })
} }

View File

@@ -19,7 +19,7 @@ func TestAuthenticationMiddleware(t *testing.T) {
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.NilLogger() logger := tests.NilLogger()
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn) require.NotNil(t, conn)
defer tests.DeleteTestDB() defer tests.DeleteTestDB()

View File

@@ -23,9 +23,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
// Middleware to add logs to console with details of the request // Middleware to add logs to console with details of the request
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" {
next.ServeHTTP(w, r)
return
}
start, err := contexts.GetStartTime(r.Context()) start, err := contexts.GetStartTime(r.Context())
if err != nil { if err != nil {
// Handle failure here. internal server error maybe // TODO: Handle failure here. internal server error maybe
return return
} }
wrapped := &wrappedWriter{ wrapped := &wrappedWriter{

View File

@@ -16,7 +16,7 @@ func TestPageLoginRequired(t *testing.T) {
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.NilLogger() logger := tests.NilLogger()
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn) require.NotNil(t, conn)
defer tests.DeleteTestDB() defer tests.DeleteTestDB()

View File

@@ -16,7 +16,7 @@ func TestActionReauthRequired(t *testing.T) {
cfg, err := tests.TestConfig() cfg, err := tests.TestConfig()
require.NoError(t, err) require.NoError(t, err)
logger := tests.NilLogger() logger := tests.NilLogger()
conn, err := tests.SetupTestDB() conn, err := tests.SetupTestDB(t.Context())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn) require.NotNil(t, conn)
defer tests.DeleteTestDB() defer tests.DeleteTestDB()

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/db"
"projectreshoot/handlers" "projectreshoot/handlers"
"projectreshoot/middleware" "projectreshoot/middleware"
"projectreshoot/view/page" "projectreshoot/view/page"
@@ -17,7 +18,8 @@ func addRoutes(
mux *http.ServeMux, mux *http.ServeMux,
logger *zerolog.Logger, logger *zerolog.Logger,
config *config.Config, config *config.Config,
conn *sql.DB, oldconn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem, staticFS *http.FileSystem,
) { ) {
// Health check // Health check
@@ -42,7 +44,7 @@ func addRoutes(
handlers.HandleLoginRequest( handlers.HandleLoginRequest(
config, config,
logger, logger,
conn, oldconn,
))) )))
// Register page and handlers // Register page and handlers
@@ -55,7 +57,7 @@ func addRoutes(
handlers.HandleRegisterRequest( handlers.HandleRegisterRequest(
config, config,
logger, logger,
conn, oldconn,
))) )))
// Logout // Logout
@@ -85,17 +87,17 @@ func addRoutes(
mux.Handle("POST /change-username", mux.Handle("POST /change-username",
middleware.RequiresLogin( middleware.RequiresLogin(
middleware.RequiresFresh( middleware.RequiresFresh(
handlers.HandleChangeUsername(logger, conn), handlers.HandleChangeUsername(logger, oldconn),
), ),
)) ))
mux.Handle("POST /change-bio", mux.Handle("POST /change-bio",
middleware.RequiresLogin( middleware.RequiresLogin(
handlers.HandleChangeBio(logger, conn), handlers.HandleChangeBio(logger, oldconn),
)) ))
mux.Handle("POST /change-password", mux.Handle("POST /change-password",
middleware.RequiresLogin( middleware.RequiresLogin(
middleware.RequiresFresh( middleware.RequiresFresh(
handlers.HandleChangePassword(logger, conn), handlers.HandleChangePassword(logger, oldconn),
), ),
)) ))
} }

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/config"
"projectreshoot/db"
"projectreshoot/middleware" "projectreshoot/middleware"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -14,7 +15,8 @@ import (
func NewServer( func NewServer(
config *config.Config, config *config.Config,
logger *zerolog.Logger, logger *zerolog.Logger,
conn *sql.DB, oldconn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem, staticFS *http.FileSystem,
) http.Handler { ) http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -22,6 +24,7 @@ func NewServer(
mux, mux,
logger, logger,
config, config,
oldconn,
conn, conn,
staticFS, staticFS,
) )

View File

@@ -1,10 +1,12 @@
package tests package tests
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"projectreshoot/db"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -32,11 +34,16 @@ func findSQLFile(filename string) (string, error) {
// SetupTestDB initializes a test SQLite database with mock data // SetupTestDB initializes a test SQLite database with mock data
// Make sure to call DeleteTestDB when finished to cleanup // Make sure to call DeleteTestDB when finished to cleanup
func SetupTestDB() (*sql.DB, error) { func SetupTestDB(ctx context.Context) (*db.SafeConn, error) {
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") dbfile, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sql.Open") return nil, errors.Wrap(err, "sql.Open")
} }
conn := db.MakeSafe(dbfile)
tx, err := conn.Begin(ctx)
if err != nil {
return nil, errors.Wrap(err, "conn.Begin")
}
// Setup the test database // Setup the test database
schemaPath, err := findSQLFile("schema.sql") schemaPath, err := findSQLFile("schema.sql")
if err != nil { if err != nil {
@@ -49,9 +56,10 @@ func SetupTestDB() (*sql.DB, error) {
} }
schemaSQL := string(sqlBytes) schemaSQL := string(sqlBytes)
_, err = conn.Exec(schemaSQL) _, err = tx.Exec(ctx, schemaSQL)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "conn.Exec") tx.Rollback()
return nil, errors.Wrap(err, "tx.Exec")
} }
// Load the test data // Load the test data
dataPath, err := findSQLFile("testdata.sql") dataPath, err := findSQLFile("testdata.sql")
@@ -64,10 +72,12 @@ func SetupTestDB() (*sql.DB, error) {
} }
dataSQL := string(sqlBytes) dataSQL := string(sqlBytes)
_, err = conn.Exec(dataSQL) _, err = tx.Exec(ctx, dataSQL)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "conn.Exec") tx.Rollback()
return nil, errors.Wrap(err, "tx.Exec")
} }
tx.Commit()
return conn, nil return conn, nil
} }