everybody loves a refactor

This commit is contained in:
2026-02-15 12:27:36 +11:00
parent 61890ae20b
commit ef8c022e60
44 changed files with 278 additions and 234 deletions

View File

@@ -3,6 +3,7 @@ package db
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
@@ -77,6 +78,7 @@ func (a *AuditLogFilter) UserIDs(ids []int) *AuditLogFilter {
}
func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter {
fmt.Println(actions)
if len(actions) > 0 {
a.In("al.action", actions)
}

View File

@@ -2,7 +2,6 @@ package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/uptrace/bun"
@@ -46,13 +45,18 @@ func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
}
func (d *deleter[T]) Delete(ctx context.Context) error {
_, err := d.q.Exec(ctx)
result, err := d.q.Exec(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return errors.Wrap(err, "bun.DeleteQuery.Exec")
}
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", d.resourceID)
}
// Handle audit logging if enabled
if d.audit != nil {
@@ -88,9 +92,6 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int,
if err != nil {
return errors.Wrap(err, "GetByID")
}
if item == nil {
return errors.New("record not found")
}
if (*item).isSystem() {
return errors.New("record is system protected")
}

View File

@@ -51,11 +51,11 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := u.GetDiscordToken(ctx, tx)
if err != nil {
if IsBadRequest(err) {
return nil, nil // Token doesn't exist - not an error
}
return nil, errors.Wrap(err, "user.GetDiscordToken")
}
if token == nil {
return nil, nil
}
_, err = tx.NewDelete().
Model((*DiscordToken)(nil)).
Where("discord_id = ?", u.DiscordID).

31
internal/db/errors.go Normal file
View File

@@ -0,0 +1,31 @@
package db
import (
"fmt"
"strings"
)
func IsBadRequest(err error) bool {
return strings.HasPrefix(err.Error(), "bad request:")
}
func BadRequest(err string) error {
return fmt.Errorf("bad request: %s", err)
}
func BadRequestNotFound(resource, field string, value any) error {
errStr := fmt.Sprintf("%s with %s=%v not found", resource, field, value)
return BadRequest(errStr)
}
func BadRequestNotAssociated(parent, child string, parentID, childID any) error {
errStr := fmt.Sprintf("%s (ID: %v) not associated with %s (ID: %v)",
child, childID, parent, parentID)
return BadRequest(errStr)
}
func BadRequestAssociated(parent, child string, parentID, childID any) error {
errStr := fmt.Sprintf("%s (ID: %v) already associated with %s (ID: %v)",
child, childID, parent, parentID)
return BadRequest(errStr)
}

View File

@@ -24,7 +24,8 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
Scan(ctx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
resource := extractResourceType(extractTableName[T]())
return nil, BadRequestNotFound(resource, g.field, g.value)
}
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
}

View File

@@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"fmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
@@ -104,6 +105,7 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
}
}
fmt.Println(l.q.String())
return l
}

View File

@@ -28,7 +28,7 @@ func (p Permission) isSystem() bool {
}
// GetPermissionByName queries the database for a permission matching the given name
// Returns nil, nil if no permission is found
// Returns a BadRequestNotFound error if no permission is found
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
@@ -37,7 +37,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
}
// GetPermissionByID queries the database for a permission matching the given ID
// Returns nil, nil if no permission is found
// Returns a BadRequestNotFound error if no permission is found
func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) {
if id <= 0 {
return nil, errors.New("id must be positive")

View File

@@ -30,7 +30,7 @@ func (r Role) isSystem() bool {
}
// GetRoleByName queries the database for a role matching the given name
// Returns nil, nil if no role is found
// Returns a BadRequestNotFound error if no role is found
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
@@ -39,7 +39,7 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
}
// GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found
// Returns a BadRequestNotFound error if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
}
@@ -110,9 +110,6 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error
if err != nil {
return errors.Wrap(err, "GetRoleByID")
}
if role == nil {
return errors.New("role not found")
}
if role.IsSystem {
return errors.New("cannot delete system roles")
}

View File

@@ -35,8 +35,8 @@ func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague")
}
if season == nil || league == nil || !season.HasLeague(league.ID) {
return nil, nil, nil, nil
if !season.HasLeague(league.ID) {
return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
}
// Get all teams participating in this season+league
@@ -59,18 +59,12 @@ func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
if err != nil {
return errors.Wrap(err, "GetSeason")
}
if season == nil {
return errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return errors.Wrap(err, "GetLeague")
}
if league == nil {
return errors.New("league not found")
}
if season.HasLeague(league.ID) {
return errors.New("league already added to season")
return BadRequestAssociated("season", "league", seasonShortName, leagueShortName)
}
seasonLeague := &SeasonLeague{
SeasonID: season.ID,
@@ -94,9 +88,6 @@ func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName st
if err != nil {
return errors.Wrap(err, "GetLeague")
}
if league == nil {
return errors.New("league not found")
}
if !s.HasLeague(league.ID) {
return errors.New("league not in season")
}

View File

@@ -32,6 +32,7 @@ func (db *DB) RegisterModels() []any {
(*Role)(nil),
(*Permission)(nil),
(*AuditLog)(nil),
(*Fixture)(nil),
}
db.RegisterModel(models...)
return models

View File

@@ -23,28 +23,19 @@ func NewTeamParticipation(ctx context.Context, tx bun.Tx,
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason")
}
if season == nil {
return nil, nil, nil, errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague")
}
if league == nil {
return nil, nil, nil, errors.New("league not found")
}
if !season.HasLeague(league.ID) {
return nil, nil, nil, errors.New("league is not assigned to the season")
return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
}
team, err := GetTeam(ctx, tx, teamID)
if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetTeam")
}
if team == nil {
return nil, nil, nil, errors.New("team not found")
}
if team.InSeason(season.ID) {
return nil, nil, nil, errors.New("team already in season")
return nil, nil, nil, BadRequestAssociated("season", "team", seasonShortName, teamID)
}
participation := &TeamParticipation{
SeasonID: season.ID,

View File

@@ -85,10 +85,18 @@ func (u *updater[T]) Exec(ctx context.Context) error {
}
// Execute update
_, err := u.q.Exec(ctx)
result, err := u.q.Exec(ctx)
if err != nil {
return errors.Wrap(err, "bun.UpdateQuery.Exec")
}
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", extractPrimaryKey(u.model))
}
// Handle audit logging if enabled
if u.audit != nil {

View File

@@ -53,13 +53,13 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
}
// GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found
// Returns a BadRequestNotFound error if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
return GetByID[User](tx, id).Get(ctx)
}
// GetUserByUsername queries the database for a user matching the given username
// Returns nil, nil if no user is found
// Returns a BadRequestNotFound error if no user is found
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
if username == "" {
return nil, errors.New("username not provided")
@@ -68,7 +68,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
}
// GetUserByDiscordID queries the database for a user matching the given discord id
// Returns nil, nil if no user is found
// Returns a BadRequestNotFound error if no user is found
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
if discordID == "" {
return nil, errors.New("discord_id not provided")

View File

@@ -94,9 +94,6 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
if err != nil {
return false, errors.Wrap(err, "GetByID")
}
if user == nil {
return false, nil
}
for _, role := range user.Roles {
if role.Name == roleName {
return true, nil