diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index 9c6093f..c69e58a 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -38,10 +38,7 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord } err := Insert(tx, discordToken). - On("CONFLICT (discord_id) DO UPDATE"). - Set("access_token = EXCLUDED.access_token"). - Set("refresh_token = EXCLUDED.refresh_token"). - Set("expires_at = EXCLUDED.expires_at"). + ConflictUpdate([]string{"discord_id"}, "access_token", "refresh_token", "expires_at"). Exec(ctx) if err != nil { return errors.Wrap(err, "db.Insert") diff --git a/internal/db/insert.go b/internal/db/insert.go index bf00a57..a8082dd 100644 --- a/internal/db/insert.go +++ b/internal/db/insert.go @@ -2,6 +2,7 @@ package db import ( "context" + "fmt" "net/http" "strings" @@ -47,17 +48,18 @@ func InsertMultiple[T any](tx bun.Tx, models []*T) *inserter[T] { } } -// On adds .On handling for upserts -// Example: .On("(discord_id) DO UPDATE") -func (i *inserter[T]) On(query string) *inserter[T] { - i.q = i.q.On(query) +func (i *inserter[T]) ConflictNothing(conflicts ...string) *inserter[T] { + fieldstr := strings.Join(conflicts, ", ") + i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO NOTHING", fieldstr)) return i } -// Set adds a SET clause for upserts (use with OnConflict) -// Example: .Set("access_token = EXCLUDED.access_token") -func (i *inserter[T]) Set(query string, args ...any) *inserter[T] { - i.q = i.q.Set(query, args...) +func (i *inserter[T]) ConflictUpdate(conflicts []string, columns ...string) *inserter[T] { + fieldstr := strings.Join(conflicts, ", ") + i.q = i.q.On(fmt.Sprintf("CONFLICT (%s) DO UPDATE", fieldstr)) + for _, column := range columns { + i.q = i.q.Set(fmt.Sprintf("%s = EXCLUDED.%s", column, column)) + } return i } diff --git a/internal/db/role.go b/internal/db/role.go index c2c9c26..335b651 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -127,7 +127,7 @@ func AddPermissionToRole(ctx context.Context, tx bun.Tx, roleID, permissionID in PermissionID: permissionID, } err := Insert(tx, rolePerm). - On("CONFLICT (role_id, permission_id) DO NOTHING"). + ConflictNothing("role_id", "permission_id"). Exec(ctx) if err != nil { return errors.Wrap(err, "db.Insert") @@ -145,13 +145,12 @@ func RemovePermissionFromRole(ctx context.Context, tx bun.Tx, roleID, permission return errors.New("permissionID must be positive") } - _, err := tx.NewDelete(). - Model((*RolePermission)(nil)). + err := DeleteItem[RolePermission](tx). Where("role_id = ?", roleID). Where("permission_id = ?", permissionID). - Exec(ctx) + Delete(ctx) if err != nil { - return errors.Wrap(err, "tx.NewDelete") + return errors.Wrap(err, "DeleteItem") } return nil diff --git a/internal/db/userrole.go b/internal/db/userrole.go index 1ad6e31..a4c1f3c 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -30,8 +30,7 @@ func AssignRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { RoleID: roleID, } err := Insert(tx, userRole). - On("CONFLICT (user_id, role_id) DO NOTHING"). - Exec(ctx) + ConflictNothing("user_id", "role_id").Exec(ctx) if err != nil { return errors.Wrap(err, "db.Insert") } @@ -48,11 +47,10 @@ func RevokeRole(ctx context.Context, tx bun.Tx, userID, roleID int) error { return errors.New("roleID must be positive") } - _, err := tx.NewDelete(). - Model((*UserRole)(nil)). + err := DeleteItem[UserRole](tx). Where("user_id = ?", userID). Where("role_id = ?", roleID). - Exec(ctx) + Delete(ctx) if err != nil { return errors.Wrap(err, "tx.NewDelete") }