diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go index 044a194..61d5e98 100644 --- a/internal/auditlog/logger.go +++ b/internal/auditlog/logger.go @@ -104,8 +104,8 @@ func (l *Logger) log( } // GetRecentLogs retrieves recent audit logs with pagination -func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) { - var logs *db.AuditLogs +func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) { + var logs *db.List[db.AuditLog] if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { var err error logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil) @@ -120,8 +120,8 @@ func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db. } // GetLogsByUser retrieves audit logs for a specific user -func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) { - var logs *db.AuditLogs +func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) { + var logs *db.List[db.AuditLog] if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { var err error logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go index 16b9bc3..5717143 100644 --- a/internal/db/auditlog.go +++ b/internal/db/auditlog.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "encoding/json" "github.com/pkg/errors" @@ -28,12 +27,6 @@ type AuditLog struct { User *User `bun:"rel:belongs-to,join:user_id=id"` } -type AuditLogs struct { - AuditLogs []*AuditLog - Total int - PageOpts PageOpts -} - // CreateAuditLog creates a new audit log entry func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { if log == nil { @@ -50,80 +43,64 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { return nil } -type AuditLogFilters struct { - UserID *int - Action *string - ResourceType *string - Result *string +type AuditLogFilter struct { + *ListFilter +} + +func NewAuditLogFilter() *AuditLogFilter { + return &AuditLogFilter{NewListFilter()} +} + +func (a *AuditLogFilter) UserID(id int) *AuditLogFilter { + a.Add("al.user_id", id) + return a +} + +func (a *AuditLogFilter) Action(action string) *AuditLogFilter { + a.Add("al.action", action) + return a +} + +func (a *AuditLogFilter) ResourceType(resourceType string) *AuditLogFilter { + a.Add("al.resource_type", resourceType) + return a +} + +func (a *AuditLogFilter) Result(result string) *AuditLogFilter { + a.Add("al.result", result) + return a } // GetAuditLogs retrieves audit logs with optional filters and pagination -func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) { - query := tx.NewSelect(). - Model((*AuditLog)(nil)). - Relation("User") - - // Apply filters if provided - if filters != nil { - if filters.UserID != nil { - query = query.Where("al.user_id = ?", *filters.UserID) - } - if filters.Action != nil { - query = query.Where("al.action = ?", *filters.Action) - } - if filters.ResourceType != nil { - query = query.Where("al.resource_type = ?", *filters.ResourceType) - } - if filters.Result != nil { - query = query.Where("al.result = ?", *filters.Result) - } +func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilter) (*List[AuditLog], error) { + defaultPageOpts := &PageOpts{ + Page: 1, + PerPage: 50, + Order: bun.OrderDesc, + OrderBy: "created_at", } - - // Get total count - total, err := query.Count(ctx) - if err != nil { - return nil, errors.Wrap(err, "query.Count") - } - - // Get paginated results - query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at") - logs := new([]*AuditLog) - err = query.Scan(ctx, &logs) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "query.Scan") - } - - list := &AuditLogs{ - AuditLogs: *logs, - Total: total, - PageOpts: *pageOpts, - } - - return list, nil + return GetList[AuditLog](tx, pageOpts, defaultPageOpts). + Relation("User"). + Filter(filters.filters...). + GetAll(ctx) } // GetAuditLogsByUser retrieves audit logs for a specific user -func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) { +func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*List[AuditLog], error) { if userID <= 0 { return nil, errors.New("userID must be positive") } - - filters := &AuditLogFilters{ - UserID: &userID, - } + filters := NewAuditLogFilter().UserID(userID) return GetAuditLogs(ctx, tx, pageOpts, filters) } // GetAuditLogsByAction retrieves audit logs for a specific action -func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) { +func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*List[AuditLog], error) { if action == "" { return nil, errors.New("action cannot be empty") } - - filters := &AuditLogFilters{ - Action: &action, - } + filters := NewAuditLogFilter().Action(action) return GetAuditLogs(ctx, tx, pageOpts, filters) } diff --git a/internal/db/delete.go b/internal/db/delete.go new file mode 100644 index 0000000..2b6189d --- /dev/null +++ b/internal/db/delete.go @@ -0,0 +1,58 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type deleter[T any] struct { + q *bun.DeleteQuery +} + +type systemType interface { + isSystem() bool +} + +func DeleteItem[T any](tx bun.Tx) *deleter[T] { + return &deleter[T]{ + tx.NewDelete(). + Model((*T)(nil)), + } +} + +func (d *deleter[T]) Where(query string, args ...any) *deleter[T] { + d.q = d.q.Where(query, args...) + return d +} + +func (d *deleter[T]) Delete(ctx context.Context) error { + _, err := d.q.Exec(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + } + return errors.Wrap(err, "bun.DeleteQuery.Exec") +} + +func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] { + return DeleteItem[T](tx).Where("id = ?", id) +} + +func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error { + deleter := DeleteByID[T](tx, id) + item, err := GetByID[T](tx, id).GetFirst(ctx) + if err != nil { + return errors.Wrap(err, "GetByID") + } + if item == nil { + return errors.New("record not found") + } + if (*item).isSystem() { + return errors.New("record is system protected") + } + return deleter.Delete(ctx) +} diff --git a/internal/db/getbyfield.go b/internal/db/getbyfield.go new file mode 100644 index 0000000..39dd408 --- /dev/null +++ b/internal/db/getbyfield.go @@ -0,0 +1,83 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type ListFilter struct { + filters []Filter +} + +func NewListFilter() *ListFilter { + return &ListFilter{[]Filter{}} +} + +func (f *ListFilter) Add(field string, value any) { + f.filters = append(f.filters, Filter{field, value}) +} + +type fieldgetter[T any] struct { + q *bun.SelectQuery + field string + value any +} + +func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) { + if g.field == "id" && (g.value).(int) < 1 { + return nil, errors.New("invalid id") + } + model := new(T) + err := g.q. + Where("? = ?", bun.Ident(g.field), g.value). + Scan(ctx, model) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, errors.Wrap(err, "bun.SelectQuery.Scan") + } + return model, nil +} + +func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) { + g.q = g.q.Limit(1) + return g.get(ctx) +} + +func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) { + return g.get(ctx) +} + +func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] { + g.q = g.q.Relation(name, apply...) + return g +} + +func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] { + g.q = g.q.Join(join, args...) + return g +} + +// GetByField retrieves a single record by field name +func GetByField[T any]( + tx bun.Tx, + field string, + value any, +) *fieldgetter[T] { + return &fieldgetter[T]{ + tx.NewSelect().Model((*T)(nil)), + field, + value, + } +} + +func GetByID[T any]( + tx bun.Tx, + id int, +) *fieldgetter[T] { + return GetByField[T](tx, "id", id) +} diff --git a/internal/db/getlist.go b/internal/db/getlist.go new file mode 100644 index 0000000..434db0b --- /dev/null +++ b/internal/db/getlist.go @@ -0,0 +1,71 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type listgetter[T any] struct { + q *bun.SelectQuery + items *[]*T + pageOpts *PageOpts + defaults *PageOpts +} + +type List[T any] struct { + Items []*T + Total int + PageOpts PageOpts +} + +type Filter struct { + Field string + Value any +} + +func GetList[T any](tx bun.Tx, pageOpts, defaults *PageOpts) *listgetter[T] { + l := &listgetter[T]{ + items: new([]*T), + pageOpts: pageOpts, + defaults: defaults, + } + l.q = tx.NewSelect(). + Model(l.items) + return l +} + +func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] { + l.q = l.q.Relation(name, apply...) + return l +} + +func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] { + for _, filter := range filters { + l.q = l.q.Where("? = ?", bun.Ident(filter.Field), filter.Value) + } + return l +} + +func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) { + if l.defaults == nil { + return nil, errors.New("default pageopts is nil") + } + total, err := l.q.Count(ctx) + if err != nil { + return nil, errors.Wrap(err, "query.Count") + } + l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total) + err = l.q.Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "query.Scan") + } + list := &List[T]{ + Items: *l.items, + Total: total, + PageOpts: *l.pageOpts, + } + return list, nil +} diff --git a/internal/db/paginate.go b/internal/db/paginate.go index 706065e..fbd422a 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -19,21 +19,25 @@ type OrderOpts struct { Label string } -func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) { +func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) { if p == nil { p = new(PageOpts) } - if p.Page == 0 { - p.Page = page + if p.Page <= 0 { + p.Page = d.Page } if p.PerPage == 0 { - p.PerPage = perpage + p.PerPage = d.PerPage + } + maxpage := p.TotalPages(totalitems) + if p.Page > maxpage { + p.Page = maxpage } if p.Order == "" { - p.Order = order + p.Order = d.Order } if p.OrderBy == "" { - p.OrderBy = orderby + p.OrderBy = d.OrderBy } p.OrderBy = sanitiseOrderBy(p.OrderBy) q = q.OrderBy(p.OrderBy, p.Order). diff --git a/internal/db/permission.go b/internal/db/permission.go index e36f05d..a95904e 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -24,6 +24,10 @@ type Permission struct { Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"` } +func (p Permission) isSystem() bool { + return p.IsSystem +} + // GetPermissionByName queries the database for a permission matching the given name // Returns nil, nil if no permission is found func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) { @@ -150,26 +154,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error { if id <= 0 { return errors.New("id must be positive") } - - // Check if permission is system permission - perm, err := GetPermissionByID(ctx, tx, id) - if err != nil { - return errors.Wrap(err, "GetPermissionByID") - } - if perm == nil { - return errors.New("permission not found") - } - if perm.IsSystem { - return errors.New("cannot delete system permission") - } - - _, err = tx.NewDelete(). - Model((*Permission)(nil)). - Where("id = ?", id). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.NewDelete") - } - - return nil + return DeleteWithProtection[Permission](ctx, tx, id) } diff --git a/internal/db/role.go b/internal/db/role.go index 181e573..22473b2 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -33,70 +33,28 @@ type RolePermission struct { Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"` } +func (r Role) isSystem() bool { + return r.IsSystem +} + // GetRoleByName queries the database for a role matching the given name // Returns nil, nil if no role is found func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) { if name == "" { return nil, errors.New("name cannot be empty") } - - role := new(Role) - err := tx.NewSelect(). - Model(role). - Where("name = ?", name). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return role, nil + return GetByField[Role](tx, "name", name).GetFirst(ctx) } // GetRoleByID queries the database for a role matching the given ID // Returns nil, nil if no role is found func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { - if id <= 0 { - return nil, errors.New("id must be positive") - } - - role := new(Role) - err := tx.NewSelect(). - Model(role). - Where("id = ?", id). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return role, nil + return GetByID[Role](tx, id).GetFirst(ctx) } // GetRoleWithPermissions loads a role and all its permissions func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { - if id <= 0 { - return nil, errors.New("id must be positive") - } - - role := new(Role) - err := tx.NewSelect(). - Model(role). - Where("id = ?", id). - Relation("Permissions"). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return role, nil + return GetByID[Role](tx, id).Relation("Permissions").GetFirst(ctx) } // ListAllRoles returns all roles @@ -155,28 +113,7 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error { if id <= 0 { return errors.New("id must be positive") } - - // Check if role is system role - role, err := GetRoleByID(ctx, tx, id) - if err != nil { - return errors.Wrap(err, "GetRoleByID") - } - if role == nil { - return errors.New("role not found") - } - if role.IsSystem { - return errors.New("cannot delete system role") - } - - _, err = tx.NewDelete(). - Model((*Role)(nil)). - Where("id = ?", id). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.NewDelete") - } - - return nil + return DeleteWithProtection[Role](ctx, tx, id) } // AddPermissionToRole grants a permission to a role diff --git a/internal/db/season.go b/internal/db/season.go index 4a78571..32b6dc8 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "strings" "time" @@ -22,12 +21,6 @@ type Season struct { FinalsEndDate bun.NullTime `bun:"finals_end_date"` } -type SeasonList struct { - Seasons []*Season - Total int - PageOpts PageOpts -} - func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) { if name == "" { return nil, errors.New("name cannot be empty") @@ -50,40 +43,19 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim return season, nil } -func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { - seasons := new([]*Season) - query := tx.NewSelect(). - Model(seasons) - - total, err := query.Count(ctx) - if err != nil { - return nil, errors.Wrap(err, "query.Count") +func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) { + defaults := &PageOpts{ + 1, + 10, + bun.OrderDesc, + "start_date", } - query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date") - err = query.Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "query.Scan") - } - sl := &SeasonList{ - Seasons: *seasons, - Total: total, - PageOpts: *pageOpts, - } - return sl, nil + return GetList[Season](tx, pageOpts, defaults).GetAll(ctx) } func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) { - season := new(Season) - err := tx.NewSelect(). - Model(season). - Where("short_name = ?", strings.ToUpper(shortname)). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") + if shortname == "" { + return nil, errors.New("short_name not provided") } - return season, nil + return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx) } diff --git a/internal/db/user.go b/internal/db/user.go index 16a53c8..9a7f5b3 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "time" "git.haelnorr.com/h/golib/hwsauth" @@ -26,12 +25,6 @@ type User struct { Roles []*Role `bun:"m2m:user_roles,join:User=Role"` } -type Users struct { - Users []*User - Total int - PageOpts PageOpts -} - func (u *User) GetID() int { return u.ID } @@ -61,55 +54,25 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { - user := new(User) - err := tx.NewSelect(). - Model(user). - Where("id = ?", id). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return user, nil + return GetByID[User](tx, id).GetFirst(ctx) } // GetUserByUsername queries the database for a user matching the given username // Returns nil, nil if no user is found func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) { - user := new(User) - err := tx.NewSelect(). - Model(user). - Where("username = ?", username). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") + if username == "" { + return nil, errors.New("username not provided") } - return user, nil + return GetByField[User](tx, "username", username).GetFirst(ctx) } // GetUserByDiscordID queries the database for a user matching the given discord id // Returns nil, nil if no user is found func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { - user := new(User) - err := tx.NewSelect(). - Model(user). - Where("discord_id = ?", discordID). - Limit(1). - Scan(ctx) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") + if discordID == "" { + return nil, errors.New("discord_id not provided") } - return user, nil + return GetByField[User](tx, "discord_id", discordID).GetFirst(ctx) } // GetRoles loads all the roles for this user @@ -117,12 +80,8 @@ func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { if u == nil { return nil, errors.New("user cannot be nil") } - - err := tx.NewSelect(). - Model(u). - Relation("Roles"). - Where("id = ?", u.ID). - Scan(ctx) + u, err := GetByField[User](tx, "id", u.ID). + Relation("Roles").GetFirst(ctx) if err != nil { return nil, errors.Wrap(err, "tx.NewSelect") } @@ -134,18 +93,11 @@ func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, er if u == nil { return nil, errors.New("user cannot be nil") } - - var permissions []*Permission - err := tx.NewSelect(). - Model(&permissions). + permissions, err := GetByField[[]*Permission](tx, "ur.user_id", u.ID). Join("JOIN role_permissions AS rp on rp.permission_id = p.id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). - Where("ur.user_id = ?", u.ID). - Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return permissions, nil + GetAll(ctx) + return *permissions, err } // HasPermission checks if user has a specific permission (including wildcard check) @@ -186,23 +138,7 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) { return u.HasRole(ctx, tx, "admin") } -func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { - users := new([]*User) - query := tx.NewSelect(). - Model(users) - total, err := query.Count(ctx) - if err != nil { - return nil, errors.Wrap(err, "query.Count") - } - query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id") - err = query.Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "query.Scan") - } - list := &Users{ - Users: *users, - Total: total, - PageOpts: *pageOpts, - } - return list, nil +func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) { + defaults := &PageOpts{1, 50, bun.OrderAsc, "id"} + return GetList[User](tx, pageOpts, defaults).GetAll(ctx) } diff --git a/internal/handlers/admin_dashboard.go b/internal/handlers/admin_dashboard.go index 00f3883..cf49316 100644 --- a/internal/handlers/admin_dashboard.go +++ b/internal/handlers/admin_dashboard.go @@ -13,7 +13,7 @@ import ( func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var users *db.Users + var users *db.List[db.User] if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error users, err = db.GetUsers(ctx, tx, nil) diff --git a/internal/handlers/admin_users.go b/internal/handlers/admin_users.go index 9075302..6bab1e2 100644 --- a/internal/handlers/admin_users.go +++ b/internal/handlers/admin_users.go @@ -14,7 +14,7 @@ import ( // AdminUsersList shows all users func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var users *db.Users + var users *db.List[db.User] pageOpts := pageOptsFromForm(s, w, r) if pageOpts == nil { return diff --git a/internal/handlers/seasons.go b/internal/handlers/seasons.go index e772c01..7402f1e 100644 --- a/internal/handlers/seasons.go +++ b/internal/handlers/seasons.go @@ -20,7 +20,7 @@ func SeasonsPage( if pageOpts == nil { return } - var seasons *db.SeasonList + var seasons *db.List[db.Season] if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error seasons, err = db.ListSeasons(ctx, tx, pageOpts) @@ -44,7 +44,7 @@ func SeasonsList( if pageOpts == nil { return } - var seasons *db.SeasonList + var seasons *db.List[db.Season] if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { var err error seasons, err = db.ListSeasons(ctx, tx, pageOpts) diff --git a/internal/rbac/checker.go b/internal/rbac/checker.go index 886cf89..fccbbf1 100644 --- a/internal/rbac/checker.go +++ b/internal/rbac/checker.go @@ -45,17 +45,17 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi } // Fallback to database - tx, err := c.conn.BeginTx(ctx, nil) - if err != nil { - return false, errors.Wrap(err, "conn.BeginTx") - } - defer func() { _ = tx.Rollback() }() - - has, err := user.HasPermission(ctx, tx, permission) - if err != nil { + var has bool + if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + has, err = user.HasPermission(ctx, tx, permission) + if err != nil { + return errors.Wrap(err, "user.HasPermission") + } + return nil + }); err != nil { return false, err } - return has, nil } @@ -73,13 +73,18 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol } // Fallback to database - tx, err := c.conn.BeginTx(ctx, nil) - if err != nil { - return false, errors.Wrap(err, "conn.BeginTx") + var has bool + if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + has, err = user.HasRole(ctx, tx, role) + if err != nil { + return errors.Wrap(err, "user.HasPermission") + } + return nil + }); err != nil { + return false, err } - defer func() { _ = tx.Rollback() }() - - return user.HasRole(ctx, tx, role) + return has, nil } // UserHasAnyPermission checks if user has ANY of the given permissions diff --git a/internal/view/component/admin/user_list.templ b/internal/view/component/admin/user_list.templ index bd3c63c..55d914e 100644 --- a/internal/view/component/admin/user_list.templ +++ b/internal/view/component/admin/user_list.templ @@ -2,5 +2,5 @@ package admin import "git.haelnorr.com/h/oslstats/internal/db" -templ UserList(users *db.Users) { +templ UserList(users *db.List[db.User]) { } diff --git a/internal/view/page/admin_dashboard.templ b/internal/view/page/admin_dashboard.templ index 68189ba..d72899a 100644 --- a/internal/view/page/admin_dashboard.templ +++ b/internal/view/page/admin_dashboard.templ @@ -4,7 +4,7 @@ import "git.haelnorr.com/h/oslstats/internal/view/layout" import "git.haelnorr.com/h/oslstats/internal/view/component/admin" import "git.haelnorr.com/h/oslstats/internal/db" -templ AdminDashboard(users *db.Users) { +templ AdminDashboard(users *db.List[db.User]) { @layout.AdminDashboard() @admin.UserList(users) } diff --git a/internal/view/page/seasons_list.templ b/internal/view/page/seasons_list.templ index f300edf..a859dc7 100644 --- a/internal/view/page/seasons_list.templ +++ b/internal/view/page/seasons_list.templ @@ -9,7 +9,7 @@ import "fmt" import "time" import "github.com/uptrace/bun" -templ SeasonsPage(seasons *db.SeasonList) { +templ SeasonsPage(seasons *db.List[db.Season]) { @layout.Global("Seasons") {
@SeasonsList(seasons) @@ -17,7 +17,7 @@ templ SeasonsPage(seasons *db.SeasonList) { } } -templ SeasonsList(seasons *db.SeasonList) { +templ SeasonsList(seasons *db.List[db.Season]) { {{ sortOpts := []db.OrderOpts{ { @@ -69,14 +69,14 @@ templ SeasonsList(seasons *db.SeasonList) { @sort.Dropdown(seasons.PageOpts, sortOpts)
- if len(seasons.Seasons) == 0 { + if len(seasons.Items) == 0 {

No seasons found

} else {
- for _, s := range seasons.Seasons { + for _, s := range seasons.Items {