110 lines
3.0 KiB
Go
110 lines
3.0 KiB
Go
package hwsauth
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.haelnorr.com/h/golib/hws"
|
|
"github.com/gobwas/glob"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Authenticate returns the main authentication middleware.
|
|
// This middleware validates JWT tokens, refreshes expired tokens, and adds
|
|
// the authenticated user to the request context.
|
|
//
|
|
// Example:
|
|
//
|
|
// server.AddMiddleware(auth.Authenticate(nil))
|
|
//
|
|
// If extraCheck is provided, it will run just before the user is added to the context,
|
|
// and the return will determine if the user will be added, or the request passed on
|
|
// without the user.
|
|
func (auth *Authenticator[T, TX]) Authenticate(
|
|
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
|
) hws.Middleware {
|
|
return auth.server.NewMiddleware(auth.authenticate(extraCheck))
|
|
}
|
|
|
|
func (auth *Authenticator[T, TX]) authenticate(
|
|
extraCheck func(ctx context.Context, model T, tx TX, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError),
|
|
) hws.MiddlewareFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) (*http.Request, *hws.HWSError) {
|
|
if globTest(r.URL.Path, auth.ignoredPaths) {
|
|
return r, nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
// Start the transaction
|
|
tx, err := auth.beginTx(ctx)
|
|
if err != nil {
|
|
return nil, &hws.HWSError{
|
|
Message: "Unable to start transaction",
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Error: errors.Wrap(err, "auth.beginTx"),
|
|
}
|
|
}
|
|
defer func() {
|
|
_ = tx.Rollback()
|
|
}()
|
|
// Type assert to TX - safe because user's beginTx should return their TX type
|
|
txTyped, ok := tx.(TX)
|
|
if !ok {
|
|
return nil, &hws.HWSError{
|
|
Message: "Transaction type mismatch",
|
|
StatusCode: http.StatusInternalServerError,
|
|
Error: errors.Wrap(err, "TX type not ok"),
|
|
}
|
|
}
|
|
model, err := auth.getAuthenticatedUser(txTyped, w, r)
|
|
if err != nil {
|
|
rberr := tx.Rollback()
|
|
if rberr != nil {
|
|
return nil, &hws.HWSError{
|
|
Message: "Failed rolling back after error",
|
|
StatusCode: http.StatusInternalServerError,
|
|
Error: errors.Wrap(err, "tx.Rollback"),
|
|
}
|
|
}
|
|
auth.logger.Debug().
|
|
Str("remote_addr", r.RemoteAddr).
|
|
Err(err).
|
|
Msg("Failed to authenticate user")
|
|
return r, nil
|
|
}
|
|
var check bool
|
|
if extraCheck != nil {
|
|
var err *hws.HWSError
|
|
check, err = extraCheck(ctx, model.model, txTyped, w, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return nil, &hws.HWSError{
|
|
Message: "Failed to commit transaction",
|
|
StatusCode: http.StatusInternalServerError,
|
|
Error: errors.Wrap(err, "tx.Commit"),
|
|
}
|
|
}
|
|
authContext := setAuthenticatedModel(r.Context(), model)
|
|
newReq := r.WithContext(authContext)
|
|
if extraCheck == nil || check {
|
|
return newReq, nil
|
|
}
|
|
return r, nil
|
|
}
|
|
}
|
|
|
|
func globTest(testPath string, globs []glob.Glob) bool {
|
|
for _, g := range globs {
|
|
if g.Match(testPath) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|