100 lines
2.6 KiB
Go
100 lines
2.6 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"slices"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
type RolePermission struct {
|
|
RoleID int `bun:",pk"`
|
|
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
|
|
PermissionID int `bun:",pk"`
|
|
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
|
|
}
|
|
|
|
func (r *Role) UpdatePermissions(ctx context.Context, tx bun.Tx, newPermissionsIDs []int, audit *AuditMeta) error {
|
|
addPerms, removePerms, err := detectChangedPermissions(ctx, tx, r, newPermissionsIDs)
|
|
if err != nil {
|
|
return errors.Wrap(err, "detectChangedPermissions")
|
|
}
|
|
addedPerms := []string{}
|
|
removedPerms := []string{}
|
|
for _, perm := range addPerms {
|
|
rolePerm := &RolePermission{
|
|
RoleID: r.ID,
|
|
PermissionID: perm.ID,
|
|
}
|
|
err := Insert(tx, rolePerm).
|
|
ConflictNothing("role_id", "permission_id").
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "db.Insert")
|
|
}
|
|
addedPerms = append(addedPerms, perm.Name.String())
|
|
}
|
|
for _, perm := range removePerms {
|
|
err := DeleteItem[RolePermission](tx).
|
|
Where("role_id = ?", r.ID).
|
|
Where("permission_id = ?", perm.ID).
|
|
Delete(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "DeleteItem")
|
|
}
|
|
removedPerms = append(removedPerms, perm.Name.String())
|
|
}
|
|
// Log the permission changes
|
|
if len(addedPerms) > 0 || len(removedPerms) > 0 {
|
|
details := map[string]any{
|
|
"role_name": string(r.Name),
|
|
}
|
|
if len(addedPerms) > 0 {
|
|
details["added_permissions"] = addedPerms
|
|
}
|
|
if len(removedPerms) > 0 {
|
|
details["removed_permissions"] = removedPerms
|
|
}
|
|
info := &AuditInfo{
|
|
"roles.update_permissions",
|
|
"role",
|
|
r.ID,
|
|
details,
|
|
}
|
|
err = LogSuccess(ctx, tx, audit, info)
|
|
if err != nil {
|
|
return errors.Wrap(err, "LogSuccess")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func detectChangedPermissions(ctx context.Context, tx bun.Tx, role *Role, permissionIDs []int) ([]*Permission, []*Permission, error) {
|
|
allPermissions, err := ListAllPermissions(ctx, tx)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, "ListAllPermissions")
|
|
}
|
|
// Build map of current permissions
|
|
currentPermIDs := make(map[int]bool)
|
|
for _, perm := range role.Permissions {
|
|
currentPermIDs[perm.ID] = true
|
|
}
|
|
|
|
var addedPerms []*Permission
|
|
var removedPerms []*Permission
|
|
|
|
// Determine what to add and remove
|
|
for _, perm := range allPermissions {
|
|
hasNow := currentPermIDs[perm.ID]
|
|
shouldHave := slices.Contains(permissionIDs, perm.ID)
|
|
|
|
if shouldHave && !hasNow {
|
|
addedPerms = append(addedPerms, perm)
|
|
} else if !shouldHave && hasNow {
|
|
removedPerms = append(removedPerms, perm)
|
|
}
|
|
}
|
|
return addedPerms, removedPerms, nil
|
|
}
|