From a4b4f4f4af803a84dd7f1903b2f9748910443f57 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Mon, 9 Feb 2026 22:14:38 +1100 Subject: [PATCH] refactored db code --- internal/db/auditlog.go | 8 ++--- internal/db/discordtokens.go | 20 ++--------- internal/db/insert.go | 6 ++-- internal/db/permission.go | 65 ++++-------------------------------- internal/db/role.go | 15 ++++----- internal/db/user.go | 5 ++- internal/db/userrole.go | 5 ++- 7 files changed, 24 insertions(+), 100 deletions(-) diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go index 5717143..414d583 100644 --- a/internal/db/auditlog.go +++ b/internal/db/auditlog.go @@ -32,14 +32,10 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error { if log == nil { return errors.New("log cannot be nil") } - - _, err := tx.NewInsert(). - Model(log). - Exec(ctx) + err := Insert(tx, log).Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } - return nil } diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index c0cd596..9c6093f 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "time" "git.haelnorr.com/h/oslstats/internal/discord" @@ -38,15 +37,14 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord TokenType: token.TokenType, } - _, err := tx.NewInsert(). - Model(discordToken). + err := Insert(tx, discordToken). On("CONFLICT (discord_id) DO UPDATE"). Set("access_token = EXCLUDED.access_token"). Set("refresh_token = EXCLUDED.refresh_token"). Set("expires_at = EXCLUDED.expires_at"). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } return nil } @@ -73,19 +71,7 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke // GetDiscordToken retrieves the users discord token from the database func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { - token := new(DiscordToken) - err := tx.NewSelect(). - Model(token). - Where("discord_id = ?", u.DiscordID). - Limit(1). - Scan(ctx) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return token, nil + return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).GetFirst(ctx) } // Convert reverts the token back into a *discord.Token diff --git a/internal/db/insert.go b/internal/db/insert.go index c2211f1..bf00a57 100644 --- a/internal/db/insert.go +++ b/internal/db/insert.go @@ -47,9 +47,9 @@ func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] { } } -// OnConflict adds conflict handling for upserts -// Example: .OnConflict("(discord_id) DO UPDATE") -func (i *inserter[T]) OnConflict(query string) *inserter[T] { +// On adds .On handling for upserts +// Example: .On("(discord_id) DO UPDATE") +func (i *inserter[T]) On(query string) *inserter[T] { i.q = i.q.On(query) return i } diff --git a/internal/db/permission.go b/internal/db/permission.go index 15c1697..2512f72 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -34,20 +34,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis if name == "" { return nil, errors.New("name cannot be empty") } - - perm := new(Permission) - err := tx.NewSelect(). - Model(perm). - Where("name = ?", name). - Limit(1). - Scan(ctx) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return perm, nil + return GetByField[Permission](tx, "name", name).GetFirst(ctx) } // GetPermissionByID queries the database for a permission matching the given ID @@ -56,20 +43,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err if id <= 0 { return nil, errors.New("id must be positive") } - - perm := new(Permission) - err := tx.NewSelect(). - Model(perm). - Where("id = ?", id). - Limit(1). - Scan(ctx) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, errors.Wrap(err, "tx.NewSelect") - } - return perm, nil + return GetByID[Permission](tx, id).GetFirst(ctx) } // GetPermissionsByResource queries for all permissions for a given resource @@ -77,34 +51,8 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ( if resource == "" { return nil, errors.New("resource cannot be empty") } - - perms := []*Permission{} - err := tx.NewSelect(). - Model(&perms). - Where("resource = ?", resource). - Order("action ASC"). - Scan(ctx) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return perms, nil -} - -// GetPermissionsByIDs queries for permissions matching the given IDs -func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permission, error) { - if len(ids) == 0 { - return []*Permission{}, nil - } - - var perms []*Permission - err := tx.NewSelect(). - Model(&perms). - Where("id IN (?)", bun.In(ids)). - Scan(ctx) - if err != nil && errors.Is(err, sql.ErrNoRows) { - return nil, errors.Wrap(err, "tx.NewSelect") - } - return perms, nil + perms, err := GetByField[[]*Permission](tx, "resource", resource).GetAll(ctx) + return *perms, err } // ListAllPermissions returns all permissions @@ -138,12 +86,11 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error { return errors.New("action cannot be empty") } - _, err := tx.NewInsert(). - Model(perm). + err := Insert(tx, perm). Returning("id"). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } return nil diff --git a/internal/db/role.go b/internal/db/role.go index 22473b2..c2c9c26 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -77,12 +77,11 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { } role.CreatedAt = time.Now().Unix() - _, err := tx.NewInsert(). - Model(role). + err := Insert(tx, role). Returning("id"). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } return nil @@ -97,12 +96,11 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error { return errors.New("role id must be positive") } - _, err := tx.NewUpdate(). - Model(role). + err := Update(tx, role). WherePK(). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewUpdate") + return errors.Wrap(err, "db.Update") } return nil @@ -128,12 +126,11 @@ func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID in RoleID: roleID, PermissionID: permissionID, } - _, err := tx.NewInsert(). - Model(rolePerm). + err := Insert(tx, rolePerm). On("CONFLICT (role_id, permission_id) DO NOTHING"). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } return nil diff --git a/internal/db/user.go b/internal/db/user.go index 9a7f5b3..64b7a18 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -40,12 +40,11 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di DiscordID: discorduser.ID, } - _, err := tx.NewInsert(). - Model(user). + err := Insert(tx, user). Returning("id"). Exec(ctx) if err != nil { - return nil, errors.Wrap(err, "tx.NewInsert") + return nil, errors.Wrap(err, "db.Insert") } return user, nil diff --git a/internal/db/userrole.go b/internal/db/userrole.go index 984e1b2..1ad6e31 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -29,12 +29,11 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { UserID: userID, RoleID: roleID, } - _, err := tx.NewInsert(). - Model(userRole). + err := Insert(tx, userRole). On("CONFLICT (user_id, role_id) DO NOTHING"). Exec(ctx) if err != nil { - return errors.Wrap(err, "tx.NewInsert") + return errors.Wrap(err, "db.Insert") } return nil