more db refactors

This commit is contained in:
2026-02-09 22:28:52 +11:00
parent fada95a7e4
commit ead70bcfc4
4 changed files with 18 additions and 22 deletions

View File

@@ -38,10 +38,7 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
} }
err := Insert(tx, discordToken). err := Insert(tx, discordToken).
On("CONFLICT (discord_id) DO UPDATE"). ConflictUpdate([]string{"discord_id"}, "access_token", "refresh_token", "expires_at").
Set("access_token = EXCLUDED.access_token").
Set("refresh_token = EXCLUDED.refresh_token").
Set("expires_at = EXCLUDED.expires_at").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Insert") return errors.Wrap(err, "db.Insert")

View File

@@ -2,6 +2,7 @@ package db
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"strings" "strings"
@@ -47,17 +48,18 @@ func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] {
} }
} }
// On adds .On handling for upserts func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] {
// Example: .On("(discord_id) DO UPDATE") fieldstr := strings.Join(conflicts, ", ")
func (i *inserter[T]) On(query string) *inserter[T] { i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", fieldstr))
i.q = i.q.On(query)
return i return i
} }
// Set adds a SET clause for upserts (use with OnConflict) func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] {
// Example: .Set("access_token = EXCLUDED.access_token") fieldstr := strings.Join(conflicts, ", ")
func (i *inserter[T]) Set(query string, args ...any) *inserter[T] { i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", fieldstr))
i.q = i.q.Set(query, args...) for _, column := range columns {
i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", column, column))
}
return i return i
} }

View File

@@ -127,7 +127,7 @@ func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID in
PermissionID: permissionID, PermissionID: permissionID,
} }
err := Insert(tx, rolePerm). err := Insert(tx, rolePerm).
On("CONFLICT (role_id, permission_id) DO NOTHING"). ConflictNothing("role_id", "permission_id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Insert") return errors.Wrap(err, "db.Insert")
@@ -145,13 +145,12 @@ func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permission
return errors.New("permissionID must be positive") return errors.New("permissionID must be positive")
} }
_, err := tx.NewDelete(). err := DeleteItem[RolePermission](tx).
Model((*RolePermission)(nil)).
Where("role_id = ?", roleID). Where("role_id = ?", roleID).
Where("permission_id = ?", permissionID). Where("permission_id = ?", permissionID).
Exec(ctx) Delete(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewDelete") return errors.Wrap(err, "DeleteItem")
} }
return nil return nil

View File

@@ -30,8 +30,7 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
RoleID: roleID, RoleID: roleID,
} }
err := Insert(tx, userRole). err := Insert(tx, userRole).
On("CONFLICT (user_id, role_id) DO NOTHING"). ConflictNothing("user_id", "role_id").Exec(ctx)
Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "db.Insert") return errors.Wrap(err, "db.Insert")
} }
@@ -48,11 +47,10 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error {
return errors.New("roleID must be positive") return errors.New("roleID must be positive")
} }
_, err := tx.NewDelete(). err := DeleteItem[UserRole](tx).
Model((*UserRole)(nil)).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Where("role_id = ?", roleID). Where("role_id = ?", roleID).
Exec(ctx) Delete(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewDelete") return errors.Wrap(err, "tx.NewDelete")
} }