Files
oslstats/internal/db/userrole.go
2026-02-09 22:14:38 +11:00

90 lines
1.9 KiB
Go

package db
import (
"context"
"database/sql"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// UserRole is the bun model for the user↔role join table: each row records
// that one user holds one role.
//
// UserID and RoleID are both tagged `pk`, so together they form the composite
// primary key. AssignRole relies on this via ON CONFLICT (user_id, role_id).
// User and Role are belongs-to relations resolved through user_id and role_id
// respectively; they are only populated when a query asks for the relation.
type UserRole struct {
UserID int `bun:",pk"`
User *User `bun:"rel:belongs-to,join:user_id=id"`
RoleID int `bun:",pk"`
Role *Role `bun:"rel:belongs-to,join:role_id=id"`
}
// AssignRole grants the role roleID to the user userID within tx.
//
// Both IDs must be positive; otherwise an error is returned before any SQL
// runs. Granting a role the user already holds is not an error: the insert
// carries ON CONFLICT (user_id, role_id) DO NOTHING, so no duplicate row is
// created and nil is returned.
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
	switch {
	case userID <= 0:
		return errors.New("userID must be positive")
	case roleID <= 0:
		return errors.New("roleID must be positive")
	}

	link := &UserRole{UserID: userID, RoleID: roleID}
	query := Insert(tx, link).On("CONFLICT (user_id, role_id) DO NOTHING")
	if err := query.Exec(ctx); err != nil {
		return errors.Wrap(err, "db.Insert")
	}
	return nil
}
// RevokeRole removes the role roleID from the user userID within tx.
//
// Both IDs must be positive; otherwise an error is returned before any SQL
// runs. Revoking a role the user does not hold is not an error. The number
// of affected rows is not inspected, so a DELETE that matches nothing still
// returns nil.
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
	switch {
	case userID <= 0:
		return errors.New("userID must be positive")
	case roleID <= 0:
		return errors.New("roleID must be positive")
	}

	del := tx.NewDelete().
		Model((*UserRole)(nil)).
		Where("user_id = ?", userID).
		Where("role_id = ?", roleID)
	if _, err := del.Exec(ctx); err != nil {
		return errors.Wrap(err, "tx.NewDelete")
	}
	return nil
}
// HasRole reports whether the user identified by userID holds the role named
// roleName, querying within tx.
//
// userID must be positive and roleName non-empty; otherwise an error is
// returned before any SQL runs. A user that does not exist, or a role name
// that matches no role, yields (false, nil). Any other database failure is
// returned wrapped.
//
// The check is a single lookup on the user_roles join table instead of
// loading the user and all of its roles and filtering in Go.
func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (bool, error) {
	if userID <= 0 {
		return false, errors.New("userID must be positive")
	}
	if roleName == "" {
		return false, errors.New("roleName cannot be empty")
	}

	// IDs of the role(s) carrying the requested name.
	// NOTE(review): assumes Role.Name maps to the "name" column (bun's default
	// for that field) — confirm against the Role model's tags.
	roleIDs := tx.NewSelect().
		Model((*Role)(nil)).
		Column("id").
		Where("name = ?", roleName)

	// Limit(1) turns this into an existence probe: Scan yields sql.ErrNoRows
	// exactly when the user holds no role with that name.
	var matchedRoleID int
	err := tx.NewSelect().
		Model((*UserRole)(nil)).
		Column("role_id").
		Where("user_id = ?", userID).
		Where("role_id IN (?)", roleIDs).
		Limit(1).
		Scan(ctx, &matchedRoleID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return false, nil
		}
		return false, errors.Wrap(err, "tx.NewSelect")
	}
	return true, nil
}