Files
oslstats/internal/rbac/cache_middleware.go
2026-03-05 18:32:55 +11:00

96 lines
2.6 KiB
Go

package rbac
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/contexts"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// LoadPermissionsMiddleware loads user permissions into context after authentication
// MUST run AFTER auth.Authenticate() middleware and LoadPreviewRoleMiddleware
func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			cache := &contexts.PermissionCache{
				Permissions: make(map[permissions.Permission]bool),
				Roles:       make(map[roles.Role]bool),
			}
			c.loadPermissionCache(r.Context(), cache)
			// Call next explicitly instead of from a defer: a deferred call would
			// also run while a panic in the lookup is unwinding, serving the
			// request with a half-built cache before the panic resurfaces.
			ctx := context.WithValue(r.Context(), contexts.PermissionCacheKey, cache)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// loadPermissionCache fills cache with the roles and permissions of the
// current user, or of the active preview role when one is set in ctx.
//
// cache is left empty when there is no current user, when the preview role
// lookup returns nil, or when the database lookup fails; database failures are
// logged via c.s.LogError. Callers therefore always get a usable (possibly
// empty) cache, so permission checks fail closed.
func (c *Checker) loadPermissionCache(ctx context.Context, cache *contexts.PermissionCache) {
	user := db.CurrentUser(ctx)
	if user == nil {
		return
	}
	previewRole := contexts.GetPreviewRole(ctx)

	var userRoles []*db.Role
	var perms []*db.Permission
	err := c.conn.WithTxFailSilently(ctx, func(ctx context.Context, tx bun.Tx) error {
		if previewRole != nil {
			// Preview mode: substitute the preview role for the user's own roles.
			role, err := db.GetRoleByID(ctx, tx, previewRole.ID)
			if err != nil {
				return errors.Wrap(err, "db.GetRoleByID")
			}
			if role == nil {
				// NOTE(review): presumably means the preview role no longer
				// exists; the cache is left empty in that case.
				return nil
			}
			userRoles = []*db.Role{role}
			// role.Permissions holds values; take element addresses so perms has
			// the same []*db.Permission shape as user.GetPermissions returns.
			perms = make([]*db.Permission, len(role.Permissions))
			for i := range role.Permissions {
				perms[i] = &role.Permissions[i]
			}
			return nil
		}

		// Normal mode: use the user's actual roles and permissions.
		var err error
		userRoles, err = user.GetRoles(ctx, tx)
		if err != nil {
			return errors.Wrap(err, "user.GetRoles")
		}
		perms, err = user.GetPermissions(ctx, tx)
		if err != nil {
			return errors.Wrap(err, "user.GetPermissions")
		}
		return nil
	})
	if err != nil {
		// Do not populate anything from a partially completed lookup.
		c.s.LogError(hws.HWSError{
			Message: "Database error",
			Error:   err,
			Level:   hws.ErrorERROR,
		})
		return
	}

	for _, perm := range perms {
		cache.Permissions[perm.Name] = true
	}
	// Every stored entry is true, so a single lookup tells us whether the
	// wildcard permission was granted.
	cache.HasWildcard = cache.Permissions[permissions.Wildcard]
	for _, role := range userRoles {
		cache.Roles[role.Name] = true
	}
}