refactored db code

This commit is contained in:
2026-02-09 22:14:38 +11:00
parent b89ee75ca7
commit a4b4f4f4af
7 changed files with 24 additions and 100 deletions

View File

@@ -32,14 +32,10 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
if log == nil {
return errors.New("log cannot be nil")
}
_, err := tx.NewInsert().
Model(log).
Exec(ctx)
err := Insert(tx, log).Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil
}

View File

@@ -2,7 +2,6 @@ package db
import (
"context"
"database/sql"
"time"
"git.haelnorr.com/h/oslstats/internal/discord"
@@ -38,15 +37,14 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
TokenType: token.TokenType,
}
_, err := tx.NewInsert().
Model(discordToken).
err := Insert(tx, discordToken).
On("CONFLICT (discord_id) DO UPDATE").
Set("access_token = EXCLUDED.access_token").
Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil
}
@@ -73,19 +71,7 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke
// GetDiscordToken retrieves the users discord token from the database
func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token := new(DiscordToken)
err := tx.NewSelect().
Model(token).
Where("discord_id = ?", u.DiscordID).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return token, nil
return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).GetFirst(ctx)
}
// Convert reverts the token back into a *discord.Token

View File

@@ -47,9 +47,9 @@ func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
}
}
// OnConflict adds conflict handling for upserts
// Example: .OnConflict("(discord_id) DO UPDATE")
func (i *inserter[T]) OnConflict(query string) *inserter[T] {
// On adds .On handling for upserts
// Example: .On("(discord_id) DO UPDATE")
func (i *inserter[T]) On(query string) *inserter[T] {
i.q = i.q.On(query)
return i
}

View File

@@ -34,20 +34,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
if name == "" {
return nil, errors.New("name cannot be empty")
}
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
return GetByField[Permission](tx, "name", name).GetFirst(ctx)
}
// GetPermissionByID queries the database for a permission matching the given ID
@@ -56,20 +43,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
if id <= 0 {
return nil, errors.New("id must be positive")
}
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
return GetByID[Permission](tx, id).GetFirst(ctx)
}
// GetPermissionsByResource queries for all permissions for a given resource
@@ -77,34 +51,8 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) (
if resource == "" {
return nil, errors.New("resource cannot be empty")
}
perms := []*Permission{}
err := tx.NewSelect().
Model(&perms).
Where("resource = ?", resource).
Order("action ASC").
Scan(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
}
// GetPermissionsByIDs queries for permissions matching the given IDs
func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permission, error) {
if len(ids) == 0 {
return []*Permission{}, nil
}
var perms []*Permission
err := tx.NewSelect().
Model(&perms).
Where("id IN (?)", bun.In(ids)).
Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
perms, err := GetByField[[]*Permission](tx, "resource", resource).GetAll(ctx)
return *perms, err
}
// ListAllPermissions returns all permissions
@@ -138,12 +86,11 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
return errors.New("action cannot be empty")
}
_, err := tx.NewInsert().
Model(perm).
err := Insert(tx, perm).
Returning("id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil

View File

@@ -77,12 +77,11 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
}
role.CreatedAt = time.Now().Unix()
_, err := tx.NewInsert().
Model(role).
err := Insert(tx, role).
Returning("id").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil
@@ -97,12 +96,11 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
return errors.New("role id must be positive")
}
_, err := tx.NewUpdate().
Model(role).
err := Update(tx, role).
WherePK().
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewUpdate")
return errors.Wrap(err, "db.Update")
}
return nil
@@ -128,12 +126,11 @@ func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID in
RoleID: roleID,
PermissionID: permissionID,
}
_, err := tx.NewInsert().
Model(rolePerm).
err := Insert(tx, rolePerm).
On("CONFLICT (role_id, permission_id) DO NOTHING").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil

View File

@@ -40,12 +40,11 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
DiscordID: discorduser.ID,
}
_, err := tx.NewInsert().
Model(user).
err := Insert(tx, user).
Returning("id").
Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "tx.NewInsert")
return nil, errors.Wrap(err, "db.Insert")
}
return user, nil

View File

@@ -29,12 +29,11 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
UserID: userID,
RoleID: roleID,
}
_, err := tx.NewInsert().
Model(userRole).
err := Insert(tx, userRole).
On("CONFLICT (user_id, role_id) DO NOTHING").
Exec(ctx)
if err != nil {
return errors.Wrap(err, "tx.NewInsert")
return errors.Wrap(err, "db.Insert")
}
return nil