diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go index 414d583..564798f 100644 --- a/internal/db/auditlog.go +++ b/internal/db/auditlog.go @@ -75,10 +75,10 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *A Order: bun.OrderDesc, OrderBy: "created_at", } - return GetList[AuditLog](tx, pageOpts, defaultPageOpts). + return GetList[AuditLog](tx). Relation("User"). Filter(filters.filters...). - GetAll(ctx) + GetPaged(ctx, pageOpts, defaultPageOpts) } // GetAuditLogsByUser retrieves audit logs for a specific user diff --git a/internal/db/delete.go b/internal/db/delete.go index d8f8458..ac2607f 100644 --- a/internal/db/delete.go +++ b/internal/db/delete.go @@ -84,7 +84,7 @@ func DeleteByID[T any](tx bun.Tx, id int) *deleter[T] { func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int) error { deleter := DeleteByID[T](tx, id) - item, err := GetByID[T](tx, id).GetFirst(ctx) + item, err := GetByID[T](tx, id).Get(ctx) if err != nil { return errors.Wrap(err, "GetByID") } diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index c69e58a..568fdd8 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -68,7 +68,7 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke // GetDiscordToken retrieves the users discord token from the database func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { - return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).GetFirst(ctx) + return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).Get(ctx) } // Convert reverts the token back into a *discord.Token diff --git a/internal/db/getbyfield.go b/internal/db/getbyfield.go index 74cdba7..4cb8f48 100644 --- a/internal/db/getbyfield.go +++ b/internal/db/getbyfield.go @@ -31,15 +31,11 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) { return g.model, nil } -func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) { +func (g *fieldgetter[T]) Get(ctx context.Context) (*T, error) { g.q = g.q.Limit(1) return g.get(ctx) } -func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) { - return g.get(ctx) -} - func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] { g.q = g.q.Relation(name, apply...) return g diff --git a/internal/db/getlist.go b/internal/db/getlist.go index 8e4719c..cd414f0 100644 --- a/internal/db/getlist.go +++ b/internal/db/getlist.go @@ -9,10 +9,8 @@ import ( ) type listgetter[T any] struct { - q *bun.SelectQuery - items *[]*T - pageOpts *PageOpts - defaults *PageOpts + q *bun.SelectQuery + items *[]*T } type List[T any] struct { @@ -38,17 +36,25 @@ func (f *ListFilter) Add(field string, value any) { f.filters = append(f.filters, Filter{field, value}) } -func GetList[T any](tx bun.Tx, pageOpts, defaults *PageOpts) *listgetter[T] { +func GetList[T any](tx bun.Tx) *listgetter[T] { l := &listgetter[T]{ - items: new([]*T), - pageOpts: pageOpts, - defaults: defaults, + items: new([]*T), } l.q = tx.NewSelect(). Model(l.items) return l } +func (l *listgetter[T]) Join(join string, args ...any) *listgetter[T] { + l.q = l.q.Join(join, args...) + return l +} + +func (l *listgetter[T]) Where(query string, args ...any) *listgetter[T] { + l.q = l.q.Where(query, args...) + return l +} + func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] { l.q = l.q.Relation(name, apply...) return l @@ -61,15 +67,15 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] { return l } -func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) { - if l.defaults == nil { +func (l *listgetter[T]) GetPaged(ctx context.Context, pageOpts, defaults *PageOpts) (*List[T], error) { + if defaults == nil { return nil, errors.New("default pageopts is nil") } total, err := l.q.Count(ctx) if err != nil { return nil, errors.Wrap(err, "query.Count") } - l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total) + l.q, pageOpts = setPageOpts(l.q, pageOpts, defaults, total) err = l.q.Scan(ctx) if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "query.Scan") @@ -77,7 +83,15 @@ func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) { list := &List[T]{ Items: *l.items, Total: total, - PageOpts: *l.pageOpts, + PageOpts: *pageOpts, } return list, nil } + +func (l *listgetter[T]) GetAll(ctx context.Context) ([]*T, error) { + err := l.q.Scan(ctx) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, errors.Wrap(err, "query.Scan") + } + return *l.items, nil +} diff --git a/internal/db/permission.go b/internal/db/permission.go index 2512f72..10194ce 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "git.haelnorr.com/h/oslstats/internal/permissions" "github.com/pkg/errors" @@ -34,7 +33,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis if name == "" { return nil, errors.New("name cannot be empty") } - return GetByField[Permission](tx, "name", name).GetFirst(ctx) + return GetByField[Permission](tx, "name", name).Get(ctx) } // GetPermissionByID queries the database for a permission matching the given ID @@ -43,7 +42,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err if id <= 0 { return nil, errors.New("id must be positive") } - return GetByID[Permission](tx, id).GetFirst(ctx) + return GetByID[Permission](tx, id).Get(ctx) } // GetPermissionsByResource queries for all permissions for a given resource @@ -51,21 +50,13 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ( if resource == "" { return nil, errors.New("resource cannot be empty") } - perms, err := GetByField[[]*Permission](tx, "resource", resource).GetAll(ctx) - return *perms, err + return GetList[Permission](tx). + Where("resource = ?", resource).GetAll(ctx) } // ListAllPermissions returns all permissions func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { - var perms []*Permission - err := tx.NewSelect(). - Model(&perms). - Order("resource ASC", "action ASC"). - Scan(ctx) - if err != nil && errors.Is(err, sql.ErrNoRows) { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return perms, nil + return GetList[Permission](tx).GetAll(ctx) } // CreatePermission creates a new permission diff --git a/internal/db/role.go b/internal/db/role.go index 335b651..12f88c6 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "time" "git.haelnorr.com/h/oslstats/internal/roles" @@ -43,31 +42,23 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro if name == "" { return nil, errors.New("name cannot be empty") } - return GetByField[Role](tx, "name", name).GetFirst(ctx) + return GetByField[Role](tx, "name", name).Get(ctx) } // GetRoleByID queries the database for a role matching the given ID // Returns nil, nil if no role is found func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { - return GetByID[Role](tx, id).GetFirst(ctx) + return GetByID[Role](tx, id).Get(ctx) } // GetRoleWithPermissions loads a role and all its permissions func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, error) { - return GetByID[Role](tx, id).Relation("Permissions").GetFirst(ctx) + return GetByID[Role](tx, id).Relation("Permissions").Get(ctx) } // ListAllRoles returns all roles func ListAllRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { - var roles []*Role - err := tx.NewSelect(). - Model(&roles). - Order("name ASC"). - Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return roles, nil + return GetList[Role](tx).GetAll(ctx) } // CreateRole creates a new role diff --git a/internal/db/season.go b/internal/db/season.go index bff84ce..ae3d6e9 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -38,14 +38,14 @@ func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[Seas bun.OrderDesc, "start_date", } - return GetList[Season](tx, pageOpts, defaults).GetAll(ctx) + return GetList[Season](tx).GetPaged(ctx, pageOpts, defaults) } func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error) { if shortname == "" { return nil, errors.New("short_name not provided") } - return GetByField[Season](tx, "short_name", shortname).GetFirst(ctx) + return GetByField[Season](tx, "short_name", shortname).Get(ctx) } // Update updates the season struct. It does not insert to the database diff --git a/internal/db/user.go b/internal/db/user.go index 64b7a18..6bb435b 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -53,7 +53,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { - return GetByID[User](tx, id).GetFirst(ctx) + return GetByID[User](tx, id).Get(ctx) } // GetUserByUsername queries the database for a user matching the given username @@ -62,7 +62,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, if username == "" { return nil, errors.New("username not provided") } - return GetByField[User](tx, "username", username).GetFirst(ctx) + return GetByField[User](tx, "username", username).Get(ctx) } // GetUserByDiscordID queries the database for a user matching the given discord id @@ -71,7 +71,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User if discordID == "" { return nil, errors.New("discord_id not provided") } - return GetByField[User](tx, "discord_id", discordID).GetFirst(ctx) + return GetByField[User](tx, "discord_id", discordID).Get(ctx) } // GetRoles loads all the roles for this user @@ -80,9 +80,9 @@ func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { return nil, errors.New("user cannot be nil") } u, err := GetByField[User](tx, "id", u.ID). - Relation("Roles").GetFirst(ctx) + Relation("Roles").Get(ctx) if err != nil { - return nil, errors.Wrap(err, "tx.NewSelect") + return nil, errors.Wrap(err, "GetByField") } return u.Roles, nil } @@ -92,11 +92,11 @@ func (u *User) GetPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, er if u == nil { return nil, errors.New("user cannot be nil") } - permissions, err := GetByField[[]*Permission](tx, "ur.user_id", u.ID). + return GetList[Permission](tx). Join("JOIN role_permissions AS rp on rp.permission_id = p.id"). Join("JOIN user_roles AS ur ON ur.role_id = rp.role_id"). + Where("ur.user_id = ?", u.ID). GetAll(ctx) - return *permissions, err } // HasPermission checks if user has a specific permission (including wildcard check) @@ -139,5 +139,5 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) { defaults := &PageOpts{1, 50, bun.OrderAsc, "id"} - return GetList[User](tx, pageOpts, defaults).GetAll(ctx) + return GetList[User](tx).GetPaged(ctx, pageOpts, defaults) } diff --git a/internal/db/userrole.go b/internal/db/userrole.go index e30332e..ccd70bb 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -51,7 +51,7 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { Where("role_id = ?", roleID). Delete(ctx) if err != nil { - return errors.Wrap(err, "tx.NewDelete") + return errors.Wrap(err, "DeleteItem") } return nil @@ -66,7 +66,7 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b return false, errors.New("roleName cannot be empty") } user, err := GetByID[User](tx, userID). - Relation("Roles").GetFirst(ctx) + Relation("Roles").Get(ctx) if err != nil { return false, errors.Wrap(err, "GetByID") }