and another one

This commit is contained in:
2026-02-10 18:07:44 +11:00
parent 299c775aba
commit ac5e38d82b
10 changed files with 52 additions and 60 deletions

View File

@@ -75,10 +75,10 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *A
Order: bun.OrderDesc, Order: bun.OrderDesc,
OrderBy: "created_at", OrderBy: "created_at",
} }
return GetList[AuditLog](tx, pageOpts, defaultPageOpts). return GetList[AuditLog](tx).
Relation("User"). Relation("User").
Filter(filters.filters...). Filter(filters.filters...).
GetAll(ctx) GetPaged(ctx, pageOpts, defaultPageOpts)
} }
// GetAuditLogsByUser retrieves audit logs for a specific user // GetAuditLogsByUser retrieves audit logs for a specific user

View File

@@ -84,7 +84,7 @@ func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error { func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error {
deleter := DeleteByID[T](tx, id) deleter := DeleteByID[T](tx, id)
item, err := GetByID[T](tx, id).GetFirst(ctx) item, err := GetByID[T](tx, id).Get(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "GetByID") return errors.Wrap(err, "GetByID")
} }

View File

@@ -68,7 +68,7 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke
// GetDiscordToken retrieves the users discord token from the database // GetDiscordToken retrieves the users discord token from the database
func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).GetFirst(ctx) return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).Get(ctx)
} }
// Convert reverts the token back into a *discord.Token // Convert reverts the token back into a *discord.Token

View File

@@ -31,15 +31,11 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
return g.model, nil return g.model, nil
} }
func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) { func (g *fieldgetter[T]) Get(ctx context.Context) (*T, error) {
g.q = g.q.Limit(1) g.q = g.q.Limit(1)
return g.get(ctx) return g.get(ctx)
} }
func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) {
return g.get(ctx)
}
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] { func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
g.q = g.q.Relation(name, apply...) g.q = g.q.Relation(name, apply...)
return g return g

View File

@@ -9,10 +9,8 @@ import (
) )
type listgetter[T any] struct { type listgetter[T any] struct {
q *bun.SelectQuery q *bun.SelectQuery
items *[]*T items *[]*T
pageOpts *PageOpts
defaults *PageOpts
} }
type List[T any] struct { type List[T any] struct {
@@ -38,17 +36,25 @@ func (f *ListFilter) Add(field string, value any) {
f.filters = append(f.filters, Filter{field, value}) f.filters = append(f.filters, Filter{field, value})
} }
func GetList[T any](tx bun.Tx, pageOpts, defaults *PageOpts) *listgetter[T] { func GetList[T any](tx bun.Tx) *listgetter[T] {
l := &listgetter[T]{ l := &listgetter[T]{
items: new([]*T), items: new([]*T),
pageOpts: pageOpts,
defaults: defaults,
} }
l.q = tx.NewSelect(). l.q = tx.NewSelect().
Model(l.items) Model(l.items)
return l return l
} }
func (l *listgetter[T]) Join(join string, args ...any) *listgetter[T] {
l.q = l.q.Join(join, args...)
return l
}
func (l *listgetter[T]) Where(query string, args ...any) *listgetter[T] {
l.q = l.q.Where(query, args...)
return l
}
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] { func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
l.q = l.q.Relation(name, apply...) l.q = l.q.Relation(name, apply...)
return l return l
@@ -61,15 +67,15 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
return l return l
} }
func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) { func (l *listgetter[T]) GetPaged(ctx context.Context, pageOpts, defaults *PageOpts) (*List[T], error) {
if l.defaults == nil { if defaults == nil {
return nil, errors.New("default pageopts is nil") return nil, errors.New("default pageopts is nil")
} }
total, err := l.q.Count(ctx) total, err := l.q.Count(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "query.Count") return nil, errors.Wrap(err, "query.Count")
} }
l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total) l.q, pageOpts = setPageOpts(l.q, pageOpts, defaults, total)
err = l.q.Scan(ctx) err = l.q.Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) { if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "query.Scan") return nil, errors.Wrap(err, "query.Scan")
@@ -77,7 +83,15 @@ func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) {
list := &List[T]{ list := &List[T]{
Items: *l.items, Items: *l.items,
Total: total, Total: total,
PageOpts: *l.pageOpts, PageOpts: *pageOpts,
} }
return list, nil return list, nil
} }
func (l *listgetter[T]) GetAll(ctx context.Context) ([]*T, error) {
err := l.q.Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "query.Scan")
}
return *l.items, nil
}

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"git.haelnorr.com/h/oslstats/internal/permissions" "git.haelnorr.com/h/oslstats/internal/permissions"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -34,7 +33,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
} }
return GetByField[Permission](tx, "name", name).GetFirst(ctx) return GetByField[Permission](tx, "name", name).Get(ctx)
} }
// GetPermissionByID queries the database for a permission matching the given ID // GetPermissionByID queries the database for a permission matching the given ID
@@ -43,7 +42,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
if id <= 0 { if id <= 0 {
return nil, errors.New("id must be positive") return nil, errors.New("id must be positive")
} }
return GetByID[Permission](tx, id).GetFirst(ctx) return GetByID[Permission](tx, id).Get(ctx)
} }
// GetPermissionsByResource queries for all permissions for a given resource // GetPermissionsByResource queries for all permissions for a given resource
@@ -51,21 +50,13 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) (
if resource == "" { if resource == "" {
return nil, errors.New("resource cannot be empty") return nil, errors.New("resource cannot be empty")
} }
perms, err := GetByField[[]*Permission](tx, "resource", resource).GetAll(ctx) return GetList[Permission](tx).
return *perms, err Where("resource = ?", resource).GetAll(ctx)
} }
// ListAllPermissions returns all permissions // ListAllPermissions returns all permissions
func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
var perms []*Permission return GetList[Permission](tx).GetAll(ctx)
err := tx.NewSelect().
Model(&perms).
Order("resource ASC", "action ASC").
Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
} }
// CreatePermission creates a new permission // CreatePermission creates a new permission

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"time" "time"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
@@ -43,31 +42,23 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
} }
return GetByField[Role](tx, "name", name).GetFirst(ctx) return GetByField[Role](tx, "name", name).Get(ctx)
} }
// GetRoleByID queries the database for a role matching the given ID // GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found // Returns nil, nil if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).GetFirst(ctx) return GetByID[Role](tx, id).Get(ctx)
} }
// GetRoleWithPermissions loads a role and all its permissions // GetRoleWithPermissions loads a role and all its permissions
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Relation("Permissions").GetFirst(ctx) return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
} }
// ListAllRoles returns all roles // ListAllRoles returns all roles
func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
var roles []*Role return GetList[Role](tx).GetAll(ctx)
err := tx.NewSelect().
Model(&roles).
Order("name ASC").
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return roles, nil
} }
// CreateRole creates a new role // CreateRole creates a new role

View File

@@ -38,14 +38,14 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Seas
bun.OrderDesc, bun.OrderDesc,
"start_date", "start_date",
} }
return GetList[Season](tx, pageOpts, defaults).GetAll(ctx) return GetList[Season](tx).GetPaged(ctx, pageOpts, defaults)
} }
func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) { func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) {
if shortname == "" { if shortname == "" {
return nil, errors.New("short_name not provided") return nil, errors.New("short_name not provided")
} }
return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx) return GetByField[Season](tx, "short_name", shortname).Get(ctx)
} }
// Update updates the season struct. It does not insert to the database // Update updates the season struct. It does not insert to the database

View File

@@ -53,7 +53,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found // Returns nil, nil if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
return GetByID[User](tx, id).GetFirst(ctx) return GetByID[User](tx, id).Get(ctx)
} }
// GetUserByUsername queries the database for a user matching the given username // GetUserByUsername queries the database for a user matching the given username
@@ -62,7 +62,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
if username == "" { if username == "" {
return nil, errors.New("username not provided") return nil, errors.New("username not provided")
} }
return GetByField[User](tx, "username", username).GetFirst(ctx) return GetByField[User](tx, "username", username).Get(ctx)
} }
// GetUserByDiscordID queries the database for a user matching the given discord id // GetUserByDiscordID queries the database for a user matching the given discord id
@@ -71,7 +71,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
if discordID == "" { if discordID == "" {
return nil, errors.New("discord_id not provided") return nil, errors.New("discord_id not provided")
} }
return GetByField[User](tx, "discord_id", discordID).GetFirst(ctx) return GetByField[User](tx, "discord_id", discordID).Get(ctx)
} }
// GetRoles loads all the roles for this user // GetRoles loads all the roles for this user
@@ -80,9 +80,9 @@ func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
u, err := GetByField[User](tx, "id", u.ID). u, err := GetByField[User](tx, "id", u.ID).
Relation("Roles").GetFirst(ctx) Relation("Roles").Get(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "GetByField")
} }
return u.Roles, nil return u.Roles, nil
} }
@@ -92,11 +92,11 @@ func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, er
if u == nil { if u == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
permissions, err := GetByField[[]*Permission](tx, "ur.user_id", u.ID). return GetList[Permission](tx).
Join("JOIN role_permissions AS rp on rp.permission_id = p.id"). Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
Where("ur.user_id = ?", u.ID).
GetAll(ctx) GetAll(ctx)
return *permissions, err
} }
// HasPermission checks if user has a specific permission (including wildcard check) // HasPermission checks if user has a specific permission (including wildcard check)
@@ -139,5 +139,5 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
defaults := &PageOpts{1, 50, bun.OrderAsc, "id"} defaults := &PageOpts{1, 50, bun.OrderAsc, "id"}
return GetList[User](tx, pageOpts, defaults).GetAll(ctx) return GetList[User](tx).GetPaged(ctx, pageOpts, defaults)
} }

View File

@@ -51,7 +51,7 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
Where("role_id = ?", roleID). Where("role_id = ?", roleID).
Delete(ctx) Delete(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewDelete") return errors.Wrap(err, "DeleteItem")
} }
return nil return nil
@@ -66,7 +66,7 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
return false, errors.New("roleName cannot be empty") return false, errors.New("roleName cannot be empty")
} }
user, err := GetByID[User](tx, userID). user, err := GetByID[User](tx, userID).
Relation("Roles").GetFirst(ctx) Relation("Roles").Get(ctx)
if err != nil { if err != nil {
return false, errors.Wrap(err, "GetByID") return false, errors.Wrap(err, "GetByID")
} }