more refactors
This commit is contained in:
@@ -104,8 +104,8 @@ func (l *Logger) log(
|
||||
}
|
||||
|
||||
// GetRecentLogs retrieves recent audit logs with pagination
|
||||
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||
var logs *db.AuditLogs
|
||||
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
|
||||
var logs *db.List[db.AuditLog]
|
||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
||||
@@ -120,8 +120,8 @@ func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.
|
||||
}
|
||||
|
||||
// GetLogsByUser retrieves audit logs for a specific user
|
||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||
var logs *db.AuditLogs
|
||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.List[db.AuditLog], error) {
|
||||
var logs *db.List[db.AuditLog]
|
||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
||||
|
||||
@@ -2,7 +2,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -28,12 +27,6 @@ type AuditLog struct {
|
||||
User *User `bun:"rel:belongs-to,join:user_id=id"`
|
||||
}
|
||||
|
||||
type AuditLogs struct {
|
||||
AuditLogs []*AuditLog
|
||||
Total int
|
||||
PageOpts PageOpts
|
||||
}
|
||||
|
||||
// CreateAuditLog creates a new audit log entry
|
||||
func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
|
||||
if log == nil {
|
||||
@@ -50,80 +43,64 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuditLogFilters struct {
|
||||
UserID *int
|
||||
Action *string
|
||||
ResourceType *string
|
||||
Result *string
|
||||
type AuditLogFilter struct {
|
||||
*ListFilter
|
||||
}
|
||||
|
||||
func NewAuditLogFilter() *AuditLogFilter {
|
||||
return &AuditLogFilter{NewListFilter()}
|
||||
}
|
||||
|
||||
func (a *AuditLogFilter) UserID(id int) *AuditLogFilter {
|
||||
a.Add("al.user_id", id)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *AuditLogFilter) Action(action string) *AuditLogFilter {
|
||||
a.Add("al.action", action)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *AuditLogFilter) ResourceType(resourceType string) *AuditLogFilter {
|
||||
a.Add("al.resource_type", resourceType)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *AuditLogFilter) Result(result string) *AuditLogFilter {
|
||||
a.Add("al.result", result)
|
||||
return a
|
||||
}
|
||||
|
||||
// GetAuditLogs retrieves audit logs with optional filters and pagination
|
||||
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
|
||||
query := tx.NewSelect().
|
||||
Model((*AuditLog)(nil)).
|
||||
Relation("User")
|
||||
|
||||
// Apply filters if provided
|
||||
if filters != nil {
|
||||
if filters.UserID != nil {
|
||||
query = query.Where("al.user_id = ?", *filters.UserID)
|
||||
}
|
||||
if filters.Action != nil {
|
||||
query = query.Where("al.action = ?", *filters.Action)
|
||||
}
|
||||
if filters.ResourceType != nil {
|
||||
query = query.Where("al.resource_type = ?", *filters.ResourceType)
|
||||
}
|
||||
if filters.Result != nil {
|
||||
query = query.Where("al.result = ?", *filters.Result)
|
||||
}
|
||||
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilter) (*List[AuditLog], error) {
|
||||
defaultPageOpts := &PageOpts{
|
||||
Page: 1,
|
||||
PerPage: 50,
|
||||
Order: bun.OrderDesc,
|
||||
OrderBy: "created_at",
|
||||
}
|
||||
|
||||
// Get total count
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query.Count")
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at")
|
||||
logs := new([]*AuditLog)
|
||||
err = query.Scan(ctx, &logs)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "query.Scan")
|
||||
}
|
||||
|
||||
list := &AuditLogs{
|
||||
AuditLogs: *logs,
|
||||
Total: total,
|
||||
PageOpts: *pageOpts,
|
||||
}
|
||||
|
||||
return list, nil
|
||||
return GetList[AuditLog](tx, pageOpts, defaultPageOpts).
|
||||
Relation("User").
|
||||
Filter(filters.filters...).
|
||||
GetAll(ctx)
|
||||
}
|
||||
|
||||
// GetAuditLogsByUser retrieves audit logs for a specific user
|
||||
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*AuditLogs, error) {
|
||||
func GetAuditLogsByUser(ctx context.Context, tx bun.Tx, userID int, pageOpts *PageOpts) (*List[AuditLog], error) {
|
||||
if userID <= 0 {
|
||||
return nil, errors.New("userID must be positive")
|
||||
}
|
||||
|
||||
filters := &AuditLogFilters{
|
||||
UserID: &userID,
|
||||
}
|
||||
filters := NewAuditLogFilter().UserID(userID)
|
||||
|
||||
return GetAuditLogs(ctx, tx, pageOpts, filters)
|
||||
}
|
||||
|
||||
// GetAuditLogsByAction retrieves audit logs for a specific action
|
||||
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*AuditLogs, error) {
|
||||
func GetAuditLogsByAction(ctx context.Context, tx bun.Tx, action string, pageOpts *PageOpts) (*List[AuditLog], error) {
|
||||
if action == "" {
|
||||
return nil, errors.New("action cannot be empty")
|
||||
}
|
||||
|
||||
filters := &AuditLogFilters{
|
||||
Action: &action,
|
||||
}
|
||||
filters := NewAuditLogFilter().Action(action)
|
||||
|
||||
return GetAuditLogs(ctx, tx, pageOpts, filters)
|
||||
}
|
||||
|
||||
58
internal/db/delete.go
Normal file
58
internal/db/delete.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type deleter[T any] struct {
|
||||
q *bun.DeleteQuery
|
||||
}
|
||||
|
||||
type systemType interface {
|
||||
isSystem() bool
|
||||
}
|
||||
|
||||
func DeleteItem[T any](tx bun.Tx) *deleter[T] {
|
||||
return &deleter[T]{
|
||||
tx.NewDelete().
|
||||
Model((*T)(nil)),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
|
||||
d.q = d.q.Where(query, args...)
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *deleter[T]) Delete(ctx context.Context) error {
|
||||
_, err := d.q.Exec(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.Wrap(err, "bun.DeleteQuery.Exec")
|
||||
}
|
||||
|
||||
func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] {
|
||||
return DeleteItem[T](tx).Where("id = ?", id)
|
||||
}
|
||||
|
||||
func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error {
|
||||
deleter := DeleteByID[T](tx, id)
|
||||
item, err := GetByID[T](tx, id).GetFirst(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetByID")
|
||||
}
|
||||
if item == nil {
|
||||
return errors.New("record not found")
|
||||
}
|
||||
if (*item).isSystem() {
|
||||
return errors.New("record is system protected")
|
||||
}
|
||||
return deleter.Delete(ctx)
|
||||
}
|
||||
83
internal/db/getbyfield.go
Normal file
83
internal/db/getbyfield.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type ListFilter struct {
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
func NewListFilter() *ListFilter {
|
||||
return &ListFilter{[]Filter{}}
|
||||
}
|
||||
|
||||
func (f *ListFilter) Add(field string, value any) {
|
||||
f.filters = append(f.filters, Filter{field, value})
|
||||
}
|
||||
|
||||
type fieldgetter[T any] struct {
|
||||
q *bun.SelectQuery
|
||||
field string
|
||||
value any
|
||||
}
|
||||
|
||||
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
|
||||
if g.field == "id" && (g.value).(int) < 1 {
|
||||
return nil, errors.New("invalid id")
|
||||
}
|
||||
model := new(T)
|
||||
err := g.q.
|
||||
Where("? = ?", bun.Ident(g.field), g.value).
|
||||
Scan(ctx, model)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) {
|
||||
g.q = g.q.Limit(1)
|
||||
return g.get(ctx)
|
||||
}
|
||||
|
||||
func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) {
|
||||
return g.get(ctx)
|
||||
}
|
||||
|
||||
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
|
||||
g.q = g.q.Relation(name, apply...)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
|
||||
g.q = g.q.Join(join, args...)
|
||||
return g
|
||||
}
|
||||
|
||||
// GetByField retrieves a single record by field name
|
||||
func GetByField[T any](
|
||||
tx bun.Tx,
|
||||
field string,
|
||||
value any,
|
||||
) *fieldgetter[T] {
|
||||
return &fieldgetter[T]{
|
||||
tx.NewSelect().Model((*T)(nil)),
|
||||
field,
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
func GetByID[T any](
|
||||
tx bun.Tx,
|
||||
id int,
|
||||
) *fieldgetter[T] {
|
||||
return GetByField[T](tx, "id", id)
|
||||
}
|
||||
71
internal/db/getlist.go
Normal file
71
internal/db/getlist.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type listgetter[T any] struct {
|
||||
q *bun.SelectQuery
|
||||
items *[]*T
|
||||
pageOpts *PageOpts
|
||||
defaults *PageOpts
|
||||
}
|
||||
|
||||
type List[T any] struct {
|
||||
Items []*T
|
||||
Total int
|
||||
PageOpts PageOpts
|
||||
}
|
||||
|
||||
type Filter struct {
|
||||
Field string
|
||||
Value any
|
||||
}
|
||||
|
||||
func GetList[T any](tx bun.Tx, pageOpts, defaults *PageOpts) *listgetter[T] {
|
||||
l := &listgetter[T]{
|
||||
items: new([]*T),
|
||||
pageOpts: pageOpts,
|
||||
defaults: defaults,
|
||||
}
|
||||
l.q = tx.NewSelect().
|
||||
Model(l.items)
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
|
||||
l.q = l.q.Relation(name, apply...)
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
|
||||
for _, filter := range filters {
|
||||
l.q = l.q.Where("? = ?", bun.Ident(filter.Field), filter.Value)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) {
|
||||
if l.defaults == nil {
|
||||
return nil, errors.New("default pageopts is nil")
|
||||
}
|
||||
total, err := l.q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query.Count")
|
||||
}
|
||||
l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total)
|
||||
err = l.q.Scan(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "query.Scan")
|
||||
}
|
||||
list := &List[T]{
|
||||
Items: *l.items,
|
||||
Total: total,
|
||||
PageOpts: *l.pageOpts,
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
@@ -19,21 +19,25 @@ type OrderOpts struct {
|
||||
Label string
|
||||
}
|
||||
|
||||
func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) {
|
||||
func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.SelectQuery, *PageOpts) {
|
||||
if p == nil {
|
||||
p = new(PageOpts)
|
||||
}
|
||||
if p.Page == 0 {
|
||||
p.Page = page
|
||||
if p.Page <= 0 {
|
||||
p.Page = d.Page
|
||||
}
|
||||
if p.PerPage == 0 {
|
||||
p.PerPage = perpage
|
||||
p.PerPage = d.PerPage
|
||||
}
|
||||
maxpage := p.TotalPages(totalitems)
|
||||
if p.Page > maxpage {
|
||||
p.Page = maxpage
|
||||
}
|
||||
if p.Order == "" {
|
||||
p.Order = order
|
||||
p.Order = d.Order
|
||||
}
|
||||
if p.OrderBy == "" {
|
||||
p.OrderBy = orderby
|
||||
p.OrderBy = d.OrderBy
|
||||
}
|
||||
p.OrderBy = sanitiseOrderBy(p.OrderBy)
|
||||
q = q.OrderBy(p.OrderBy, p.Order).
|
||||
|
||||
@@ -24,6 +24,10 @@ type Permission struct {
|
||||
Roles []Role `bun:"m2m:role_permissions,join:Permission=Role"`
|
||||
}
|
||||
|
||||
func (p Permission) isSystem() bool {
|
||||
return p.IsSystem
|
||||
}
|
||||
|
||||
// GetPermissionByName queries the database for a permission matching the given name
|
||||
// Returns nil, nil if no permission is found
|
||||
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
|
||||
@@ -150,26 +154,5 @@ func DeletePermission(ctx context.Context, tx bun.Tx, id int) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id must be positive")
|
||||
}
|
||||
|
||||
// Check if permission is system permission
|
||||
perm, err := GetPermissionByID(ctx, tx, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetPermissionByID")
|
||||
}
|
||||
if perm == nil {
|
||||
return errors.New("permission not found")
|
||||
}
|
||||
if perm.IsSystem {
|
||||
return errors.New("cannot delete system permission")
|
||||
}
|
||||
|
||||
_, err = tx.NewDelete().
|
||||
Model((*Permission)(nil)).
|
||||
Where("id = ?", id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.NewDelete")
|
||||
}
|
||||
|
||||
return nil
|
||||
return DeleteWithProtection[Permission](ctx, tx, id)
|
||||
}
|
||||
|
||||
@@ -33,70 +33,28 @@ type RolePermission struct {
|
||||
Permission *Permission `bun:"rel:belongs-to,join:permission_id=id"`
|
||||
}
|
||||
|
||||
func (r Role) isSystem() bool {
|
||||
return r.IsSystem
|
||||
}
|
||||
|
||||
// GetRoleByName queries the database for a role matching the given name
|
||||
// Returns nil, nil if no role is found
|
||||
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name cannot be empty")
|
||||
}
|
||||
|
||||
role := new(Role)
|
||||
err := tx.NewSelect().
|
||||
Model(role).
|
||||
Where("name = ?", name).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return role, nil
|
||||
return GetByField[Role](tx, "name", name).GetFirst(ctx)
|
||||
}
|
||||
|
||||
// GetRoleByID queries the database for a role matching the given ID
|
||||
// Returns nil, nil if no role is found
|
||||
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id must be positive")
|
||||
}
|
||||
|
||||
role := new(Role)
|
||||
err := tx.NewSelect().
|
||||
Model(role).
|
||||
Where("id = ?", id).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return role, nil
|
||||
return GetByID[Role](tx, id).GetFirst(ctx)
|
||||
}
|
||||
|
||||
// GetRoleWithPermissions loads a role and all its permissions
|
||||
func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id must be positive")
|
||||
}
|
||||
|
||||
role := new(Role)
|
||||
err := tx.NewSelect().
|
||||
Model(role).
|
||||
Where("id = ?", id).
|
||||
Relation("Permissions").
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return role, nil
|
||||
return GetByID[Role](tx, id).Relation("Permissions").GetFirst(ctx)
|
||||
}
|
||||
|
||||
// ListAllRoles returns all roles
|
||||
@@ -155,28 +113,7 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id must be positive")
|
||||
}
|
||||
|
||||
// Check if role is system role
|
||||
role, err := GetRoleByID(ctx, tx, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetRoleByID")
|
||||
}
|
||||
if role == nil {
|
||||
return errors.New("role not found")
|
||||
}
|
||||
if role.IsSystem {
|
||||
return errors.New("cannot delete system role")
|
||||
}
|
||||
|
||||
_, err = tx.NewDelete().
|
||||
Model((*Role)(nil)).
|
||||
Where("id = ?", id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.NewDelete")
|
||||
}
|
||||
|
||||
return nil
|
||||
return DeleteWithProtection[Role](ctx, tx, id)
|
||||
}
|
||||
|
||||
// AddPermissionToRole grants a permission to a role
|
||||
|
||||
@@ -2,7 +2,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -22,12 +21,6 @@ type Season struct {
|
||||
FinalsEndDate bun.NullTime `bun:"finals_end_date"`
|
||||
}
|
||||
|
||||
type SeasonList struct {
|
||||
Seasons []*Season
|
||||
Total int
|
||||
PageOpts PageOpts
|
||||
}
|
||||
|
||||
func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start time.Time) (*Season, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name cannot be empty")
|
||||
@@ -50,40 +43,19 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
|
||||
return season, nil
|
||||
}
|
||||
|
||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
|
||||
seasons := new([]*Season)
|
||||
query := tx.NewSelect().
|
||||
Model(seasons)
|
||||
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query.Count")
|
||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Season], error) {
|
||||
defaults := &PageOpts{
|
||||
1,
|
||||
10,
|
||||
bun.OrderDesc,
|
||||
"start_date",
|
||||
}
|
||||
query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date")
|
||||
err = query.Scan(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "query.Scan")
|
||||
}
|
||||
sl := &SeasonList{
|
||||
Seasons: *seasons,
|
||||
Total: total,
|
||||
PageOpts: *pageOpts,
|
||||
}
|
||||
return sl, nil
|
||||
return GetList[Season](tx, pageOpts, defaults).GetAll(ctx)
|
||||
}
|
||||
|
||||
func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) {
|
||||
season := new(Season)
|
||||
err := tx.NewSelect().
|
||||
Model(season).
|
||||
Where("short_name = ?", strings.ToUpper(shortname)).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
if shortname == "" {
|
||||
return nil, errors.New("short_name not provided")
|
||||
}
|
||||
return season, nil
|
||||
return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
@@ -26,12 +25,6 @@ type User struct {
|
||||
Roles []*Role `bun:"m2m:user_roles,join:User=Role"`
|
||||
}
|
||||
|
||||
type Users struct {
|
||||
Users []*User
|
||||
Total int
|
||||
PageOpts PageOpts
|
||||
}
|
||||
|
||||
func (u *User) GetID() int {
|
||||
return u.ID
|
||||
}
|
||||
@@ -61,55 +54,25 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
|
||||
// GetUserByID queries the database for a user matching the given ID
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
||||
user := new(User)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("id = ?", id).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return user, nil
|
||||
return GetByID[User](tx, id).GetFirst(ctx)
|
||||
}
|
||||
|
||||
// GetUserByUsername queries the database for a user matching the given username
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
|
||||
user := new(User)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("username = ?", username).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
if username == "" {
|
||||
return nil, errors.New("username not provided")
|
||||
}
|
||||
return user, nil
|
||||
return GetByField[User](tx, "username", username).GetFirst(ctx)
|
||||
}
|
||||
|
||||
// GetUserByDiscordID queries the database for a user matching the given discord id
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
|
||||
user := new(User)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("discord_id = ?", discordID).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
if discordID == "" {
|
||||
return nil, errors.New("discord_id not provided")
|
||||
}
|
||||
return user, nil
|
||||
return GetByField[User](tx, "discord_id", discordID).GetFirst(ctx)
|
||||
}
|
||||
|
||||
// GetRoles loads all the roles for this user
|
||||
@@ -117,12 +80,8 @@ func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
||||
if u == nil {
|
||||
return nil, errors.New("user cannot be nil")
|
||||
}
|
||||
|
||||
err := tx.NewSelect().
|
||||
Model(u).
|
||||
Relation("Roles").
|
||||
Where("id = ?", u.ID).
|
||||
Scan(ctx)
|
||||
u, err := GetByField[User](tx, "id", u.ID).
|
||||
Relation("Roles").GetFirst(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
@@ -134,18 +93,11 @@ func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, er
|
||||
if u == nil {
|
||||
return nil, errors.New("user cannot be nil")
|
||||
}
|
||||
|
||||
var permissions []*Permission
|
||||
err := tx.NewSelect().
|
||||
Model(&permissions).
|
||||
permissions, err := GetByField[[]*Permission](tx, "ur.user_id", u.ID).
|
||||
Join("JOIN role_permissions AS rp on rp.permission_id = p.id").
|
||||
Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id").
|
||||
Where("ur.user_id = ?", u.ID).
|
||||
Scan(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return permissions, nil
|
||||
GetAll(ctx)
|
||||
return *permissions, err
|
||||
}
|
||||
|
||||
// HasPermission checks if user has a specific permission (including wildcard check)
|
||||
@@ -186,23 +138,7 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
return u.HasRole(ctx, tx, "admin")
|
||||
}
|
||||
|
||||
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
|
||||
users := new([]*User)
|
||||
query := tx.NewSelect().
|
||||
Model(users)
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query.Count")
|
||||
}
|
||||
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id")
|
||||
err = query.Scan(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, errors.Wrap(err, "query.Scan")
|
||||
}
|
||||
list := &Users{
|
||||
Users: *users,
|
||||
Total: total,
|
||||
PageOpts: *pageOpts,
|
||||
}
|
||||
return list, nil
|
||||
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
|
||||
defaults := &PageOpts{1, 50, bun.OrderAsc, "id"}
|
||||
return GetList[User](tx, pageOpts, defaults).GetAll(ctx)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var users *db.Users
|
||||
var users *db.List[db.User]
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
users, err = db.GetUsers(ctx, tx, nil)
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// AdminUsersList shows all users
|
||||
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var users *db.Users
|
||||
var users *db.List[db.User]
|
||||
pageOpts := pageOptsFromForm(s, w, r)
|
||||
if pageOpts == nil {
|
||||
return
|
||||
|
||||
@@ -20,7 +20,7 @@ func SeasonsPage(
|
||||
if pageOpts == nil {
|
||||
return
|
||||
}
|
||||
var seasons *db.SeasonList
|
||||
var seasons *db.List[db.Season]
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||
@@ -44,7 +44,7 @@ func SeasonsList(
|
||||
if pageOpts == nil {
|
||||
return
|
||||
}
|
||||
var seasons *db.SeasonList
|
||||
var seasons *db.List[db.Season]
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||
|
||||
@@ -45,17 +45,17 @@ func (c *Checker) UserHasPermission(ctx context.Context, user *db.User, permissi
|
||||
}
|
||||
|
||||
// Fallback to database
|
||||
tx, err := c.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.BeginTx")
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
has, err := user.HasPermission(ctx, tx, permission)
|
||||
if err != nil {
|
||||
var has bool
|
||||
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
has, err = user.HasPermission(ctx, tx, permission)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "user.HasPermission")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return has, nil
|
||||
}
|
||||
|
||||
@@ -73,13 +73,18 @@ func (c *Checker) UserHasRole(ctx context.Context, user *db.User, role roles.Rol
|
||||
}
|
||||
|
||||
// Fallback to database
|
||||
tx, err := c.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.BeginTx")
|
||||
var has bool
|
||||
if err := db.WithTxFailSilently(ctx, c.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
has, err = user.HasRole(ctx, tx, role)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "user.HasPermission")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
return user.HasRole(ctx, tx, role)
|
||||
return has, nil
|
||||
}
|
||||
|
||||
// UserHasAnyPermission checks if user has ANY of the given permissions
|
||||
|
||||
@@ -2,5 +2,5 @@ package admin
|
||||
|
||||
import "git.haelnorr.com/h/oslstats/internal/db"
|
||||
|
||||
templ UserList(users *db.Users) {
|
||||
templ UserList(users *db.List[db.User]) {
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import "git.haelnorr.com/h/oslstats/internal/view/layout"
|
||||
import "git.haelnorr.com/h/oslstats/internal/view/component/admin"
|
||||
import "git.haelnorr.com/h/oslstats/internal/db"
|
||||
|
||||
templ AdminDashboard(users *db.Users) {
|
||||
templ AdminDashboard(users *db.List[db.User]) {
|
||||
@layout.AdminDashboard()
|
||||
@admin.UserList(users)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import "fmt"
|
||||
import "time"
|
||||
import "github.com/uptrace/bun"
|
||||
|
||||
templ SeasonsPage(seasons *db.SeasonList) {
|
||||
templ SeasonsPage(seasons *db.List[db.Season]) {
|
||||
@layout.Global("Seasons") {
|
||||
<div class="max-w-screen-2xl mx-auto px-2">
|
||||
@SeasonsList(seasons)
|
||||
@@ -17,7 +17,7 @@ templ SeasonsPage(seasons *db.SeasonList) {
|
||||
}
|
||||
}
|
||||
|
||||
templ SeasonsList(seasons *db.SeasonList) {
|
||||
templ SeasonsList(seasons *db.List[db.Season]) {
|
||||
{{
|
||||
sortOpts := []db.OrderOpts{
|
||||
{
|
||||
@@ -69,14 +69,14 @@ templ SeasonsList(seasons *db.SeasonList) {
|
||||
@sort.Dropdown(seasons.PageOpts, sortOpts)
|
||||
</div>
|
||||
<!-- Results section -->
|
||||
if len(seasons.Seasons) == 0 {
|
||||
if len(seasons.Items) == 0 {
|
||||
<div class="bg-mantle border border-surface1 rounded-lg p-8 text-center">
|
||||
<p class="text-subtext0 text-lg">No seasons found</p>
|
||||
</div>
|
||||
} else {
|
||||
<!-- Card grid -->
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
for _, s := range seasons.Seasons {
|
||||
for _, s := range seasons.Items {
|
||||
<a
|
||||
class="bg-mantle border border-surface1 rounded-lg p-6 hover:bg-surface0 transition-colors"
|
||||
href={ fmt.Sprintf("/seasons/%s", s.ShortName) }
|
||||
|
||||
Reference in New Issue
Block a user