Files
oslstats/internal/db/user.go

209 lines
5.2 KiB
Go

package db
import (
"context"
"database/sql"
"time"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// CurrentUser is an hwsauth.ContextLoader parameterised on *User.
// NOTE(review): presumably used to retrieve the authenticated *User from a
// request context; confirm against the hwsauth package and its callers.
var CurrentUser hwsauth.ContextLoader[*User]
// User is a row of the users table (aliased "u" in queries).
type User struct {
	bun.BaseModel `bun:"table:users,alias:u"`
	ID int `bun:"id,pk,autoincrement"` // Integer ID (auto-incremented primary key)
	Username string `bun:"username,unique"` // Username (unique)
	CreatedAt int64 `bun:"created_at"` // Unix timestamp (seconds) when the user was added to the database
	DiscordID string `bun:"discord_id,unique"` // ID of the linked Discord account (unique)
	// Roles assigned to the user through the user_roles join table.
	// Populated by GetRoles; the GetUserBy* lookups do not load it.
	Roles []*Role `bun:"m2m:user_roles,join:User=Role"`
}
// Users is a single page of users, as returned by GetUsers.
type Users struct {
	Users []*User // Users on the current page
	Total int // Total number of users across all pages (counted before pagination is applied)
	PageOpts PageOpts // Pagination options that were applied, as returned by setPageOpts
}
// GetID returns the user's integer primary key.
// NOTE(review): presumably exists so *User satisfies an interface required by
// hwsauth (see CurrentUser); confirm before renaming.
func (u *User) GetID() int {
	id := u.ID
	return id
}
// CreateUser inserts a new user with the given username, linked to the given
// Discord account, and returns it with its database-generated ID filled in.
//
// CreatedAt is set to the current Unix time in seconds and DiscordID is taken
// from discorduser.ID. An error is returned if discorduser is nil or if the
// insert fails (presumably including unique-constraint violations on the
// username or discord_id columns — confirm against the schema).
func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) {
	if discorduser == nil {
		return nil, errors.New("discorduser cannot be nil")
	}
	user := &User{
		Username:  username,
		CreatedAt: time.Now().Unix(),
		DiscordID: discorduser.ID,
	}
	// Returning("id") scans the generated primary key back into user.ID.
	_, err := tx.NewInsert().
		Model(user).
		Returning("id").
		Exec(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "tx.NewInsert")
	}
	return user, nil
}
// GetUserByID queries the database for a user matching the given ID.
// Returns nil, nil if no user is found. Roles are not loaded.
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
	return getUserWhere(ctx, tx, "id = ?", id)
}

// GetUserByUsername queries the database for a user matching the given username.
// Returns nil, nil if no user is found. Roles are not loaded.
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
	return getUserWhere(ctx, tx, "username = ?", username)
}

// GetUserByDiscordID queries the database for a user matching the given Discord ID.
// Returns nil, nil if no user is found. Roles are not loaded.
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
	return getUserWhere(ctx, tx, "discord_id = ?", discordID)
}

// getUserWhere selects at most one user matching cond, with arg bound to the
// single placeholder in cond. It returns nil, nil when no row matches.
// cond must be a constant, trusted SQL fragment; only arg is parameterised.
func getUserWhere(ctx context.Context, tx bun.Tx, cond string, arg any) (*User, error) {
	user := new(User)
	err := tx.NewSelect().
		Model(user).
		Where(cond, arg).
		Limit(1).
		Scan(ctx)
	if errors.Is(err, sql.ErrNoRows) {
		// "Not found" is a normal outcome for a lookup, not a failure.
		return nil, nil
	}
	if err != nil {
		return nil, errors.Wrap(err, "tx.NewSelect")
	}
	return user, nil
}
// GetRoles loads the user's roles through the user_roles many-to-many
// relation, stores them in u.Roles, and returns them.
//
// The query re-selects the user's own row by u.ID into u, so the other fields
// of u are refreshed from the database as a side effect. Any Scan error
// (presumably including sql.ErrNoRows when no row matches u.ID) is wrapped and
// returned.
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
	if u == nil {
		return nil, errors.New("user cannot be nil")
	}
	q := tx.NewSelect().Model(u).Relation("Roles").Where("id = ?", u.ID)
	if err := q.Scan(ctx); err != nil {
		return nil, errors.Wrap(err, "tx.NewSelect")
	}
	return u.Roles, nil
}
// GetPermissions returns the permissions granted to the user through any of
// their roles (user_roles -> role_permissions -> permissions).
//
// Each permission appears at most once, even when several of the user's roles
// grant it. An empty result means the user has no permissions.
//
// NOTE(review): the outer query refers to the permissions table by alias "p";
// this assumes the Permission model declares alias:p — confirm.
func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
	if u == nil {
		return nil, errors.New("user cannot be nil")
	}
	var perms []*Permission
	// Filter with a subquery instead of joining directly: a direct join yields
	// one row per (role, permission) pair, duplicating permissions shared by
	// several of the user's roles.
	err := tx.NewSelect().
		Model(&perms).
		Where(`p.id IN (
			SELECT rp.permission_id
			FROM role_permissions AS rp
			JOIN user_roles AS ur ON ur.role_id = rp.role_id
			WHERE ur.user_id = ?
		)`, u.ID).
		Scan(ctx)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, errors.Wrap(err, "tx.NewSelect")
	}
	return perms, nil
}
// HasPermission reports whether any of the user's roles grants permissionName,
// or grants the wildcard permission (permissions.Wildcard).
//
// It loads all of the user's permissions via GetPermissions and scans them;
// errors from GetPermissions are returned unchanged.
func (u *User) HasPermission(ctx context.Context, tx bun.Tx, permissionName permissions.Permission) (bool, error) {
	switch {
	case u == nil:
		return false, errors.New("user cannot be nil")
	case permissionName == "":
		return false, errors.New("permissionName cannot be empty")
	}
	granted, err := u.GetPermissions(ctx, tx)
	if err != nil {
		return false, err
	}
	for i := range granted {
		switch granted[i].Name {
		case permissionName, permissions.Wildcard:
			return true, nil
		}
	}
	return false, nil
}
// HasRole reports whether the user has the role roleName. It delegates to the
// package-level HasRole using the user's ID.
func (u *User) HasRole(ctx context.Context, tx bun.Tx, roleName roles.Role) (bool, error) {
	if u != nil {
		return HasRole(ctx, tx, u.ID, roleName)
	}
	return false, errors.New("user cannot be nil")
}
// IsAdmin reports whether the user holds the "admin" role; it is shorthand for
// u.HasRole(ctx, tx, "admin").
func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
	if u != nil {
		return u.HasRole(ctx, tx, "admin")
	}
	return false, errors.New("user cannot be nil")
}
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
users := new([]*User)
query := tx.NewSelect().
Model(users)
total, err := query.Count(ctx)
if err != nil {
return nil, errors.Wrap(err, "query.Count")
}
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id")
err = query.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
}
list := &Users{
Users: *users,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
}