91 lines
1.9 KiB
Go
91 lines
1.9 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
// UserRole is one row linking a user to a role.
//
// Both UserID and RoleID carry the bun "pk" tag, so together they form a
// composite primary key; the helpers in this file address rows by the
// (user_id, role_id) pair. No bun.BaseModel is embedded, so the table name
// and alias come from bun's naming defaults.
// NOTE(review): presumably this is the join table behind User.Roles (HasRole
// loads that relation) — confirm against the User model and the schema.
type UserRole struct {
	// UserID is the user half of the composite primary key.
	UserID int `bun:",pk"`
	// User is the owning user, joined on user_id = id.
	User *User `bun:"rel:belongs-to,join:user_id=id"`
	// RoleID is the role half of the composite primary key.
	RoleID int `bun:",pk"`
	// Role is the granted role, joined on role_id = id.
	Role *Role `bun:"rel:belongs-to,join:role_id=id"`
}
|
|
|
|
// AssignRole grants a role to a user
|
|
func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|
if userID <= 0 {
|
|
return errors.New("userID must be positive")
|
|
}
|
|
if roleID <= 0 {
|
|
return errors.New("roleID must be positive")
|
|
}
|
|
|
|
userRole := &UserRole{
|
|
UserID: userID,
|
|
RoleID: roleID,
|
|
}
|
|
_, err := tx.NewInsert().
|
|
Model(userRole).
|
|
On("CONFLICT (user_id, role_id) DO NOTHING").
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "tx.NewInsert")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RevokeRole removes a role from a user
|
|
func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
|
|
if userID <= 0 {
|
|
return errors.New("userID must be positive")
|
|
}
|
|
if roleID <= 0 {
|
|
return errors.New("roleID must be positive")
|
|
}
|
|
|
|
_, err := tx.NewDelete().
|
|
Model((*UserRole)(nil)).
|
|
Where("user_id = ?", userID).
|
|
Where("role_id = ?", roleID).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errors.Wrap(err, "tx.NewDelete")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HasRole checks if a user has a specific role
|
|
func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (bool, error) {
|
|
if userID <= 0 {
|
|
return false, errors.New("userID must be positive")
|
|
}
|
|
if roleName == "" {
|
|
return false, errors.New("roleName cannot be empty")
|
|
}
|
|
user := new(User)
|
|
err := tx.NewSelect().
|
|
Model(user).
|
|
Relation("Roles").
|
|
Where("u.id = ?", userID).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
return false, errors.Wrap(err, "tx.NewSelect")
|
|
}
|
|
for _, role := range user.Roles {
|
|
if role.Name == roleName {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|