more refactors

This commit is contained in:
2026-02-08 19:57:12 +11:00
parent e94a212a08
commit d038f7a42d
17 changed files with 334 additions and 308 deletions

View File

@@ -104,8 +104,8 @@ func (l *Logger) log(
} }
// GetRecentLogs retrieves recent audit logs with pagination // GetRecentLogs retrieves recent audit logs with pagination
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) { func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
var logs *db.AuditLogs var logs *db.List[db.AuditLog]
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error var err error
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil) logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
@@ -120,8 +120,8 @@ func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.
} }
// GetLogsByUser retrieves audit logs for a specific user // GetLogsByUser retrieves audit logs for a specific user
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) { func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
var logs *db.AuditLogs var logs *db.List[db.AuditLog]
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error var err error
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -28,12 +27,6 @@ type AuditLog struct {
User *User `bun:"rel:belongs-to,join:user_id=id"` User *User `bun:"rel:belongs-to,join:user_id=id"`
} }
type AuditLogs struct {
AuditLogs []*AuditLog
Total int
PageOpts PageOpts
}
// CreateAuditLog creates a new audit log entry // CreateAuditLog creates a new audit log entry
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
if log == nil { if log == nil {
@@ -50,80 +43,64 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
return nil return nil
} }
type AuditLogFilters struct { type AuditLogFilter struct {
UserID *int *ListFilter
Action *string }
ResourceType *string
Result *string func NewAuditLogFilter() *AuditLogFilter {
return &AuditLogFilter{NewListFilter()}
}
func (a *AuditLogFilter) UserID(id int) *AuditLogFilter {
a.Add("al.user_id", id)
return a
}
func (a *AuditLogFilter) Action(action string) *AuditLogFilter {
a.Add("al.action", action)
return a
}
func (a *AuditLogFilter) ResourceType(resourceType string) *AuditLogFilter {
a.Add("al.resource_type", resourceType)
return a
}
func (a *AuditLogFilter) Result(result string) *AuditLogFilter {
a.Add("al.result", result)
return a
} }
// GetAuditLogs retrieves audit logs with optional filters and pagination // GetAuditLogs retrieves audit logs with optional filters and pagination
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) { func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilter) (*List[AuditLog], error) {
query := tx.NewSelect(). defaultPageOpts := &PageOpts{
Model((*AuditLog)(nil)). Page: 1,
Relation("User") PerPage: 50,
Order: bun.OrderDesc,
// Apply filters if provided OrderBy: "created_at",
if filters != nil {
if filters.UserID != nil {
query = query.Where("al.user_id = ?", *filters.UserID)
}
if filters.Action != nil {
query = query.Where("al.action = ?", *filters.Action)
}
if filters.ResourceType != nil {
query = query.Where("al.resource_type = ?", *filters.ResourceType)
}
if filters.Result != nil {
query = query.Where("al.result = ?", *filters.Result)
}
} }
return GetList[AuditLog](tx, pageOpts, defaultPageOpts).
// Get total count Relation("User").
total, err := query.Count(ctx) Filter(filters.filters...).
if err != nil { GetAll(ctx)
return nil, errors.Wrap(err, "query.Count")
}
// Get paginated results
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at")
logs := new([]*AuditLog)
err = query.Scan(ctx, &logs)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
}
list := &AuditLogs{
AuditLogs: *logs,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
} }
// GetAuditLogsByUser retrieves audit logs for a specific user // GetAuditLogsByUser retrieves audit logs for a specific user
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) { func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*List[AuditLog], error) {
if userID <= 0 { if userID <= 0 {
return nil, errors.New("userID must be positive") return nil, errors.New("userID must be positive")
} }
filters := NewAuditLogFilter().UserID(userID)
filters := &AuditLogFilters{
UserID: &userID,
}
return GetAuditLogs(ctx, tx, pageOpts, filters) return GetAuditLogs(ctx, tx, pageOpts, filters)
} }
// GetAuditLogsByAction retrieves audit logs for a specific action // GetAuditLogsByAction retrieves audit logs for a specific action
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) { func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*List[AuditLog], error) {
if action == "" { if action == "" {
return nil, errors.New("action cannot be empty") return nil, errors.New("action cannot be empty")
} }
filters := NewAuditLogFilter().Action(action)
filters := &AuditLogFilters{
Action: &action,
}
return GetAuditLogs(ctx, tx, pageOpts, filters) return GetAuditLogs(ctx, tx, pageOpts, filters)
} }

58
internal/db/delete.go Normal file
View File

@@ -0,0 +1,58 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type deleter[T any] struct {
q *bun.DeleteQuery
}
type systemType interface {
isSystem() bool
}
func DeleteItem[T any](tx bun.Tx) *deleter[T] {
return &deleter[T]{
tx.NewDelete().
Model((*T)(nil)),
}
}
func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
d.q = d.q.Where(query, args...)
return d
}
func (d *deleter[T]) Delete(ctx context.Context) error {
_, err := d.q.Exec(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
}
return errors.Wrap(err, "bun.DeleteQuery.Exec")
}
func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
return DeleteItem[T](tx).Where("id = ?", id)
}
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error {
deleter := DeleteByID[T](tx, id)
item, err := GetByID[T](tx, id).GetFirst(ctx)
if err != nil {
return errors.Wrap(err, "GetByID")
}
if item == nil {
return errors.New("record not found")
}
if (*item).isSystem() {
return errors.New("record is system protected")
}
return deleter.Delete(ctx)
}

83
internal/db/getbyfield.go Normal file
View File

@@ -0,0 +1,83 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type ListFilter struct {
filters []Filter
}
func NewListFilter() *ListFilter {
return &ListFilter{[]Filter{}}
}
func (f *ListFilter) Add(field string, value any) {
f.filters = append(f.filters, Filter{field, value})
}
type fieldgetter[T any] struct {
q *bun.SelectQuery
field string
value any
}
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
if g.field == "id" && (g.value).(int) < 1 {
return nil, errors.New("invalid id")
}
model := new(T)
err := g.q.
Where("? = ?", bun.Ident(g.field), g.value).
Scan(ctx, model)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
}
return model, nil
}
func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) {
g.q = g.q.Limit(1)
return g.get(ctx)
}
func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) {
return g.get(ctx)
}
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
g.q = g.q.Relation(name, apply...)
return g
}
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
g.q = g.q.Join(join, args...)
return g
}
// GetByField retrieves a single record by field name
func GetByField[T any](
tx bun.Tx,
field string,
value any,
) *fieldgetter[T] {
return &fieldgetter[T]{
tx.NewSelect().Model((*T)(nil)),
field,
value,
}
}
func GetByID[T any](
tx bun.Tx,
id int,
) *fieldgetter[T] {
return GetByField[T](tx, "id", id)
}

71
internal/db/getlist.go Normal file
View File

@@ -0,0 +1,71 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type listgetter[T any] struct {
q *bun.SelectQuery
items *[]*T
pageOpts *PageOpts
defaults *PageOpts
}
type List[T any] struct {
Items []*T
Total int
PageOpts PageOpts
}
type Filter struct {
Field string
Value any
}
func GetList[T any](tx bun.Tx, pageOpts, defaults *PageOpts) *listgetter[T] {
l := &listgetter[T]{
items: new([]*T),
pageOpts: pageOpts,
defaults: defaults,
}
l.q = tx.NewSelect().
Model(l.items)
return l
}
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
l.q = l.q.Relation(name, apply...)
return l
}
func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
for _, filter := range filters {
l.q = l.q.Where("? = ?", bun.Ident(filter.Field), filter.Value)
}
return l
}
func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) {
if l.defaults == nil {
return nil, errors.New("default pageopts is nil")
}
total, err := l.q.Count(ctx)
if err != nil {
return nil, errors.Wrap(err, "query.Count")
}
l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total)
err = l.q.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
}
list := &List[T]{
Items: *l.items,
Total: total,
PageOpts: *l.pageOpts,
}
return list, nil
}

View File

@@ -19,21 +19,25 @@ type OrderOpts struct {
Label string Label string
} }
func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) { func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
if p == nil { if p == nil {
p = new(PageOpts) p = new(PageOpts)
} }
if p.Page == 0 { if p.Page <= 0 {
p.Page = page p.Page = d.Page
} }
if p.PerPage == 0 { if p.PerPage == 0 {
p.PerPage = perpage p.PerPage = d.PerPage
}
maxpage := p.TotalPages(totalitems)
if p.Page > maxpage {
p.Page = maxpage
} }
if p.Order == "" { if p.Order == "" {
p.Order = order p.Order = d.Order
} }
if p.OrderBy == "" { if p.OrderBy == "" {
p.OrderBy = orderby p.OrderBy = d.OrderBy
} }
p.OrderBy = sanitiseOrderBy(p.OrderBy) p.OrderBy = sanitiseOrderBy(p.OrderBy)
q = q.OrderBy(p.OrderBy, p.Order). q = q.OrderBy(p.OrderBy, p.Order).

View File

@@ -24,6 +24,10 @@ type Permission struct {
Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"` Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"`
} }
func (p Permission) isSystem() bool {
return p.IsSystem
}
// GetPermissionByName queries the database for a permission matching the given name // GetPermissionByName queries the database for a permission matching the given name
// Returns nil, nil if no permission is found // Returns nil, nil if no permission is found
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) { func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
@@ -150,26 +154,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 { if id <= 0 {
return errors.New("id must be positive") return errors.New("id must be positive")
} }
return DeleteWithProtection[Permission](ctx, tx, id)
// Check if permission is system permission
perm, err := GetPermissionByID(ctx, tx, id)
if err != nil {
return errors.Wrap(err, "GetPermissionByID")
}
if perm == nil {
return errors.New("permission not found")
}
if perm.IsSystem {
return errors.New("cannot delete system permission")
}
_, err = tx.NewDelete().
Model((*Permission)(nil)).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewDelete")
}
return nil
} }

View File

@@ -33,70 +33,28 @@ type RolePermission struct {
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"` Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
} }
func (r Role) isSystem() bool {
return r.IsSystem
}
// GetRoleByName queries the database for a role matching the given name // GetRoleByName queries the database for a role matching the given name
// Returns nil, nil if no role is found // Returns nil, nil if no role is found
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) { func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
} }
return GetByField[Role](tx, "name", name).GetFirst(ctx)
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
} }
// GetRoleByID queries the database for a role matching the given ID // GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found // Returns nil, nil if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
if id <= 0 { return GetByID[Role](tx, id).GetFirst(ctx)
return nil, errors.New("id must be positive")
}
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
} }
// GetRoleWithPermissions loads a role and all its permissions // GetRoleWithPermissions loads a role and all its permissions
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
if id <= 0 { return GetByID[Role](tx, id).Relation("Permissions").GetFirst(ctx)
return nil, errors.New("id must be positive")
}
role := new(Role)
err := tx.NewSelect().
Model(role).
Where("id = ?", id).
Relation("Permissions").
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return role, nil
} }
// ListAllRoles returns all roles // ListAllRoles returns all roles
@@ -155,28 +113,7 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
if id <= 0 { if id <= 0 {
return errors.New("id must be positive") return errors.New("id must be positive")
} }
return DeleteWithProtection[Role](ctx, tx, id)
// Check if role is system role
role, err := GetRoleByID(ctx, tx, id)
if err != nil {
return errors.Wrap(err, "GetRoleByID")
}
if role == nil {
return errors.New("role not found")
}
if role.IsSystem {
return errors.New("cannot delete system role")
}
_, err = tx.NewDelete().
Model((*Role)(nil)).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewDelete")
}
return nil
} }
// AddPermissionToRole grants a permission to a role // AddPermissionToRole grants a permission to a role

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"strings" "strings"
"time" "time"
@@ -22,12 +21,6 @@ type Season struct {
FinalsEndDate bun.NullTime `bun:"finals_end_date"` FinalsEndDate bun.NullTime `bun:"finals_end_date"`
} }
type SeasonList struct {
Seasons []*Season
Total int
PageOpts PageOpts
}
func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) { func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) {
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
@@ -50,40 +43,19 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
return season, nil return season, nil
} }
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
seasons := new([]*Season) defaults := &PageOpts{
query := tx.NewSelect(). 1,
Model(seasons) 10,
bun.OrderDesc,
total, err := query.Count(ctx) "start_date",
if err != nil {
return nil, errors.Wrap(err, "query.Count")
} }
query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date") return GetList[Season](tx, pageOpts, defaults).GetAll(ctx)
err = query.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
}
sl := &SeasonList{
Seasons: *seasons,
Total: total,
PageOpts: *pageOpts,
}
return sl, nil
} }
func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) { func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) {
season := new(Season) if shortname == "" {
err := tx.NewSelect(). return nil, errors.New("short_name not provided")
Model(season).
Where("short_name = ?", strings.ToUpper(shortname)).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
} }
return season, nil return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx)
} }

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"time" "time"
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
@@ -26,12 +25,6 @@ type User struct {
Roles []*Role `bun:"m2m:user_roles,join:User=Role"` Roles []*Role `bun:"m2m:user_roles,join:User=Role"`
} }
type Users struct {
Users []*User
Total int
PageOpts PageOpts
}
func (u *User) GetID() int { func (u *User) GetID() int {
return u.ID return u.ID
} }
@@ -61,55 +54,25 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found // Returns nil, nil if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
user := new(User) return GetByID[User](tx, id).GetFirst(ctx)
err := tx.NewSelect().
Model(user).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return user, nil
} }
// GetUserByUsername queries the database for a user matching the given username // GetUserByUsername queries the database for a user matching the given username
// Returns nil, nil if no user is found // Returns nil, nil if no user is found
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) { func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
user := new(User) if username == "" {
err := tx.NewSelect(). return nil, errors.New("username not provided")
Model(user).
Where("username = ?", username).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
} }
return user, nil return GetByField[User](tx, "username", username).GetFirst(ctx)
} }
// GetUserByDiscordID queries the database for a user matching the given discord id // GetUserByDiscordID queries the database for a user matching the given discord id
// Returns nil, nil if no user is found // Returns nil, nil if no user is found
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
user := new(User) if discordID == "" {
err := tx.NewSelect(). return nil, errors.New("discord_id not provided")
Model(user).
Where("discord_id = ?", discordID).
Limit(1).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
} }
return user, nil return GetByField[User](tx, "discord_id", discordID).GetFirst(ctx)
} }
// GetRoles loads all the roles for this user // GetRoles loads all the roles for this user
@@ -117,12 +80,8 @@ func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
if u == nil { if u == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
u, err := GetByField[User](tx, "id", u.ID).
err := tx.NewSelect(). Relation("Roles").GetFirst(ctx)
Model(u).
Relation("Roles").
Where("id = ?", u.ID).
Scan(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
@@ -134,18 +93,11 @@ func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, er
if u == nil { if u == nil {
return nil, errors.New("user cannot be nil") return nil, errors.New("user cannot be nil")
} }
permissions, err := GetByField[[]*Permission](tx, "ur.user_id", u.ID).
var permissions []*Permission
err := tx.NewSelect().
Model(&permissions).
Join("JOIN role_permissions AS rp on rp.permission_id = p.id"). Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
Where("ur.user_id = ?", u.ID). GetAll(ctx)
Scan(ctx) return *permissions, err
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return permissions, nil
} }
// HasPermission checks if user has a specific permission (including wildcard check) // HasPermission checks if user has a specific permission (including wildcard check)
@@ -186,23 +138,7 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
return u.HasRole(ctx, tx, "admin") return u.HasRole(ctx, tx, "admin")
} }
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
users := new([]*User) defaults := &PageOpts{1, 50, bun.OrderAsc, "id"}
query := tx.NewSelect(). return GetList[User](tx, pageOpts, defaults).GetAll(ctx)
Model(users)
total, err := query.Count(ctx)
if err != nil {
return nil, errors.Wrap(err, "query.Count")
}
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id")
err = query.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
}
list := &Users{
Users: *users,
Total: total,
PageOpts: *pageOpts,
}
return list, nil
} }

View File

@@ -13,7 +13,7 @@ import (
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var users *db.Users var users *db.List[db.User]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
users, err = db.GetUsers(ctx, tx, nil) users, err = db.GetUsers(ctx, tx, nil)

View File

@@ -14,7 +14,7 @@ import (
// AdminUsersList shows all users // AdminUsersList shows all users
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler { func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var users *db.Users var users *db.List[db.User]
pageOpts := pageOptsFromForm(s, w, r) pageOpts := pageOptsFromForm(s, w, r)
if pageOpts == nil { if pageOpts == nil {
return return

View File

@@ -20,7 +20,7 @@ func SeasonsPage(
if pageOpts == nil { if pageOpts == nil {
return return
} }
var seasons *db.SeasonList var seasons *db.List[db.Season]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts) seasons, err = db.ListSeasons(ctx, tx, pageOpts)
@@ -44,7 +44,7 @@ func SeasonsList(
if pageOpts == nil { if pageOpts == nil {
return return
} }
var seasons *db.SeasonList var seasons *db.List[db.Season]
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts) seasons, err = db.ListSeasons(ctx, tx, pageOpts)

View File

@@ -45,17 +45,17 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi
} }
// Fallback to database // Fallback to database
tx, err := c.conn.BeginTx(ctx, nil) var has bool
if err != nil { if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
return false, errors.Wrap(err, "conn.BeginTx") var err error
} has, err = user.HasPermission(ctx, tx, permission)
defer func() { _ = tx.Rollback() }() if err != nil {
return errors.Wrap(err, "user.HasPermission")
has, err := user.HasPermission(ctx, tx, permission) }
if err != nil { return nil
}); err != nil {
return false, err return false, err
} }
return has, nil return has, nil
} }
@@ -73,13 +73,18 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol
} }
// Fallback to database // Fallback to database
tx, err := c.conn.BeginTx(ctx, nil) var has bool
if err != nil { if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
return false, errors.Wrap(err, "conn.BeginTx") var err error
has, err = user.HasRole(ctx, tx, role)
if err != nil {
return errors.Wrap(err, "user.HasPermission")
}
return nil
}); err != nil {
return false, err
} }
defer func() { _ = tx.Rollback() }() return has, nil
return user.HasRole(ctx, tx, role)
} }
// UserHasAnyPermission checks if user has ANY of the given permissions // UserHasAnyPermission checks if user has ANY of the given permissions

View File

@@ -2,5 +2,5 @@ package admin
import "git.haelnorr.com/h/oslstats/internal/db" import "git.haelnorr.com/h/oslstats/internal/db"
templ UserList(users *db.Users) { templ UserList(users *db.List[db.User]) {
} }

View File

@@ -4,7 +4,7 @@ import "git.haelnorr.com/h/oslstats/internal/view/layout"
import "git.haelnorr.com/h/oslstats/internal/view/component/admin" import "git.haelnorr.com/h/oslstats/internal/view/component/admin"
import "git.haelnorr.com/h/oslstats/internal/db" import "git.haelnorr.com/h/oslstats/internal/db"
templ AdminDashboard(users *db.Users) { templ AdminDashboard(users *db.List[db.User]) {
@layout.AdminDashboard() @layout.AdminDashboard()
@admin.UserList(users) @admin.UserList(users)
} }

View File

@@ -9,7 +9,7 @@ import "fmt"
import "time" import "time"
import "github.com/uptrace/bun" import "github.com/uptrace/bun"
templ SeasonsPage(seasons *db.SeasonList) { templ SeasonsPage(seasons *db.List[db.Season]) {
@layout.Global("Seasons") { @layout.Global("Seasons") {
<div class="max-w-screen-2xl mx-auto px-2"> <div class="max-w-screen-2xl mx-auto px-2">
@SeasonsList(seasons) @SeasonsList(seasons)
@@ -17,7 +17,7 @@ templ SeasonsPage(seasons *db.SeasonList) {
} }
} }
templ SeasonsList(seasons *db.SeasonList) { templ SeasonsList(seasons *db.List[db.Season]) {
{{ {{
sortOpts := []db.OrderOpts{ sortOpts := []db.OrderOpts{
{ {
@@ -69,14 +69,14 @@ templ SeasonsList(seasons *db.SeasonList) {
@sort.Dropdown(seasons.PageOpts, sortOpts) @sort.Dropdown(seasons.PageOpts, sortOpts)
</div> </div>
<!-- Results section --> <!-- Results section -->
if len(seasons.Seasons) == 0 { if len(seasons.Items) == 0 {
<div class="bg-mantle border border-surface1 rounded-lg p-8 text-center"> <div class="bg-mantle border border-surface1 rounded-lg p-8 text-center">
<p class="text-subtext0 text-lg">No seasons found</p> <p class="text-subtext0 text-lg">No seasons found</p>
</div> </div>
} else { } else {
<!-- Card grid --> <!-- Card grid -->
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4"> <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
for _, s := range seasons.Seasons { for _, s := range seasons.Items {
<a <a
class="bg-mantle border border-surface1 rounded-lg p-6 hover:bg-surface0 transition-colors" class="bg-mantle border border-surface1 rounded-lg p-6 hover:bg-surface0 transition-colors"
href={ fmt.Sprintf("/seasons/%s", s.ShortName) } href={ fmt.Sprintf("/seasons/%s", s.ShortName) }