refactored db code

This commit is contained in:
2026-02-09 22:14:38 +11:00
parent b89ee75ca7
commit a4b4f4f4af
7 changed files with 24 additions and 100 deletions

View File

@@ -32,14 +32,10 @@ func CreateAuditLog(ctx context.Context, tx bun.Tx, log *AuditLog) error {
if log == nil { if log == nil {
return errors.New("log cannot be nil") return errors.New("log cannot be nil")
} }
err := Insert(tx, log).Exec(ctx)
_, err := tx.NewInsert().
Model(log).
Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil
} }

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"time" "time"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
@@ -38,15 +37,14 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
TokenType: token.TokenType, TokenType: token.TokenType,
} }
_, err := tx.NewInsert(). err := Insert(tx, discordToken).
Model(discordToken).
On("CONFLICT (discord_id) DO UPDATE"). On("CONFLICT (discord_id) DO UPDATE").
Set("access_token = EXCLUDED.access_token"). Set("access_token = EXCLUDED.access_token").
Set("refresh_token = EXCLUDED.refresh_token"). Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at"). Set("expires_at = EXCLUDED.expires_at").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil
} }
@@ -73,19 +71,7 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke
// GetDiscordToken retrieves the users discord token from the database // GetDiscordToken retrieves the users discord token from the database
func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token := new(DiscordToken) return GetByField[DiscordToken](tx, "discord_id", u.DiscordID).GetFirst(ctx)
err := tx.NewSelect().
Model(token).
Where("discord_id = ?", u.DiscordID).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return token, nil
} }
// Convert reverts the token back into a *discord.Token // Convert reverts the token back into a *discord.Token

View File

@@ -47,9 +47,9 @@ func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
} }
} }
// OnConflict adds conflict handling for upserts // On adds .On handling for upserts
// Example: .OnConflict("(discord_id) DO UPDATE") // Example: .On("(discord_id) DO UPDATE")
func (i *inserter[T]) OnConflict(query string) *inserter[T] { func (i *inserter[T]) On(query string) *inserter[T] {
i.q = i.q.On(query) i.q = i.q.On(query)
return i return i
} }

View File

@@ -34,20 +34,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
} }
return GetByField[Permission](tx, "name", name).GetFirst(ctx)
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
} }
// GetPermissionByID queries the database for a permission matching the given ID // GetPermissionByID queries the database for a permission matching the given ID
@@ -56,20 +43,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
if id <= 0 { if id <= 0 {
return nil, errors.New("id must be positive") return nil, errors.New("id must be positive")
} }
return GetByID[Permission](tx, id).GetFirst(ctx)
perm := new(Permission)
err := tx.NewSelect().
Model(perm).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perm, nil
} }
// GetPermissionsByResource queries for all permissions for a given resource // GetPermissionsByResource queries for all permissions for a given resource
@@ -77,34 +51,8 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) (
if resource == "" { if resource == "" {
return nil, errors.New("resource cannot be empty") return nil, errors.New("resource cannot be empty")
} }
perms, err := GetByField[[]*Permission](tx, "resource", resource).GetAll(ctx)
perms := []*Permission{} return *perms, err
err := tx.NewSelect().
Model(&perms).
Where("resource = ?", resource).
Order("action ASC").
Scan(ctx)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
}
// GetPermissionsByIDs queries for permissions matching the given IDs
func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permission, error) {
if len(ids) == 0 {
return []*Permission{}, nil
}
var perms []*Permission
err := tx.NewSelect().
Model(&perms).
Where("id IN (?)", bun.In(ids)).
Scan(ctx)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Wrap(err, "tx.NewSelect")
}
return perms, nil
} }
// ListAllPermissions returns all permissions // ListAllPermissions returns all permissions
@@ -138,12 +86,11 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
return errors.New("action cannot be empty") return errors.New("action cannot be empty")
} }
_, err := tx.NewInsert(). err := Insert(tx, perm).
Model(perm).
Returning("id"). Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil

View File

@@ -77,12 +77,11 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
} }
role.CreatedAt = time.Now().Unix() role.CreatedAt = time.Now().Unix()
_, err := tx.NewInsert(). err := Insert(tx, role).
Model(role).
Returning("id"). Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil
@@ -97,12 +96,11 @@ func UpdateRole(ctx context.Context, tx bun.Tx, role *Role) error {
return errors.New("role id must be positive") return errors.New("role id must be positive")
} }
_, err := tx.NewUpdate(). err := Update(tx, role).
Model(role).
WherePK(). WherePK().
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewUpdate") return errors.Wrap(err, "db.Update")
} }
return nil return nil
@@ -128,12 +126,11 @@ func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID in
RoleID: roleID, RoleID: roleID,
PermissionID: permissionID, PermissionID: permissionID,
} }
_, err := tx.NewInsert(). err := Insert(tx, rolePerm).
Model(rolePerm).
On("CONFLICT (role_id, permission_id) DO NOTHING"). On("CONFLICT (role_id, permission_id) DO NOTHING").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil

View File

@@ -40,12 +40,11 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
DiscordID: discorduser.ID, DiscordID: discorduser.ID,
} }
_, err := tx.NewInsert(). err := Insert(tx, user).
Model(user).
Returning("id"). Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewInsert") return nil, errors.Wrap(err, "db.Insert")
} }
return user, nil return user, nil

View File

@@ -29,12 +29,11 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
UserID: userID, UserID: userID,
RoleID: roleID, RoleID: roleID,
} }
_, err := tx.NewInsert(). err := Insert(tx, userRole).
Model(userRole).
On("CONFLICT (user_id, role_id) DO NOTHING"). On("CONFLICT (user_id, role_id) DO NOTHING").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "db.Insert")
} }
return nil return nil