everybody loves a refactor
This commit is contained in:
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -77,6 +78,7 @@ func (a *AuditLogFilter) UserIDs(ids []int) *AuditLogFilter {
|
||||
}
|
||||
|
||||
func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter {
|
||||
fmt.Println(actions)
|
||||
if len(actions) > 0 {
|
||||
a.In("al.action", actions)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -46,13 +45,18 @@ func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
|
||||
}
|
||||
|
||||
func (d *deleter[T]) Delete(ctx context.Context) error {
|
||||
_, err := d.q.Exec(ctx)
|
||||
result, err := d.q.Exec(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err, "bun.DeleteQuery.Exec")
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "result.RowsAffected")
|
||||
}
|
||||
if rows == 0 {
|
||||
resource := extractResourceType(extractTableName[T]())
|
||||
return BadRequestNotFound(resource, "id", d.resourceID)
|
||||
}
|
||||
|
||||
// Handle audit logging if enabled
|
||||
if d.audit != nil {
|
||||
@@ -88,9 +92,6 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int,
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetByID")
|
||||
}
|
||||
if item == nil {
|
||||
return errors.New("record not found")
|
||||
}
|
||||
if (*item).isSystem() {
|
||||
return errors.New("record is system protected")
|
||||
}
|
||||
|
||||
@@ -51,11 +51,11 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
|
||||
func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
|
||||
token, err := u.GetDiscordToken(ctx, tx)
|
||||
if err != nil {
|
||||
if IsBadRequest(err) {
|
||||
return nil, nil // Token doesn't exist - not an error
|
||||
}
|
||||
return nil, errors.Wrap(err, "user.GetDiscordToken")
|
||||
}
|
||||
if token == nil {
|
||||
return nil, nil
|
||||
}
|
||||
_, err = tx.NewDelete().
|
||||
Model((*DiscordToken)(nil)).
|
||||
Where("discord_id = ?", u.DiscordID).
|
||||
|
||||
31
internal/db/errors.go
Normal file
31
internal/db/errors.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func IsBadRequest(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "bad request:")
|
||||
}
|
||||
|
||||
func BadRequest(err string) error {
|
||||
return fmt.Errorf("bad request: %s", err)
|
||||
}
|
||||
|
||||
func BadRequestNotFound(resource, field string, value any) error {
|
||||
errStr := fmt.Sprintf("%s with %s=%v not found", resource, field, value)
|
||||
return BadRequest(errStr)
|
||||
}
|
||||
|
||||
func BadRequestNotAssociated(parent, child string, parentID, childID any) error {
|
||||
errStr := fmt.Sprintf("%s (ID: %v) not associated with %s (ID: %v)",
|
||||
child, childID, parent, parentID)
|
||||
return BadRequest(errStr)
|
||||
}
|
||||
|
||||
func BadRequestAssociated(parent, child string, parentID, childID any) error {
|
||||
errStr := fmt.Sprintf("%s (ID: %v) already associated with %s (ID: %v)",
|
||||
child, childID, parent, parentID)
|
||||
return BadRequest(errStr)
|
||||
}
|
||||
@@ -24,7 +24,8 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
resource := extractResourceType(extractTableName[T]())
|
||||
return nil, BadRequestNotFound(resource, g.field, g.value)
|
||||
}
|
||||
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -104,6 +105,7 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
|
||||
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
|
||||
}
|
||||
}
|
||||
fmt.Println(l.q.String())
|
||||
return l
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func (p Permission) isSystem() bool {
|
||||
}
|
||||
|
||||
// GetPermissionByName queries the database for a permission matching the given name
|
||||
// Returns nil, nil if no permission is found
|
||||
// Returns a BadRequestNotFound error if no permission is found
|
||||
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name cannot be empty")
|
||||
@@ -37,7 +37,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
|
||||
}
|
||||
|
||||
// GetPermissionByID queries the database for a permission matching the given ID
|
||||
// Returns nil, nil if no permission is found
|
||||
// Returns a BadRequestNotFound error if no permission is found
|
||||
func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id must be positive")
|
||||
|
||||
@@ -30,7 +30,7 @@ func (r Role) isSystem() bool {
|
||||
}
|
||||
|
||||
// GetRoleByName queries the database for a role matching the given name
|
||||
// Returns nil, nil if no role is found
|
||||
// Returns a BadRequestNotFound error if no role is found
|
||||
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name cannot be empty")
|
||||
@@ -39,7 +39,7 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
|
||||
}
|
||||
|
||||
// GetRoleByID queries the database for a role matching the given ID
|
||||
// Returns nil, nil if no role is found
|
||||
// Returns a BadRequestNotFound error if no role is found
|
||||
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
||||
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
|
||||
}
|
||||
@@ -110,9 +110,6 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetRoleByID")
|
||||
}
|
||||
if role == nil {
|
||||
return errors.New("role not found")
|
||||
}
|
||||
if role.IsSystem {
|
||||
return errors.New("cannot delete system roles")
|
||||
}
|
||||
|
||||
@@ -35,8 +35,8 @@ func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "GetLeague")
|
||||
}
|
||||
if season == nil || league == nil || !season.HasLeague(league.ID) {
|
||||
return nil, nil, nil, nil
|
||||
if !season.HasLeague(league.ID) {
|
||||
return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
|
||||
}
|
||||
|
||||
// Get all teams participating in this season+league
|
||||
@@ -59,18 +59,12 @@ func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetSeason")
|
||||
}
|
||||
if season == nil {
|
||||
return errors.New("season not found")
|
||||
}
|
||||
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetLeague")
|
||||
}
|
||||
if league == nil {
|
||||
return errors.New("league not found")
|
||||
}
|
||||
if season.HasLeague(league.ID) {
|
||||
return errors.New("league already added to season")
|
||||
return BadRequestAssociated("season", "league", seasonShortName, leagueShortName)
|
||||
}
|
||||
seasonLeague := &SeasonLeague{
|
||||
SeasonID: season.ID,
|
||||
@@ -94,9 +88,6 @@ func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName st
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "GetLeague")
|
||||
}
|
||||
if league == nil {
|
||||
return errors.New("league not found")
|
||||
}
|
||||
if !s.HasLeague(league.ID) {
|
||||
return errors.New("league not in season")
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ func (db *DB) RegisterModels() []any {
|
||||
(*Role)(nil),
|
||||
(*Permission)(nil),
|
||||
(*AuditLog)(nil),
|
||||
(*Fixture)(nil),
|
||||
}
|
||||
db.RegisterModel(models...)
|
||||
return models
|
||||
|
||||
@@ -23,28 +23,19 @@ func NewTeamParticipation(ctx context.Context, tx bun.Tx,
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "GetSeason")
|
||||
}
|
||||
if season == nil {
|
||||
return nil, nil, nil, errors.New("season not found")
|
||||
}
|
||||
league, err := GetLeague(ctx, tx, leagueShortName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "GetLeague")
|
||||
}
|
||||
if league == nil {
|
||||
return nil, nil, nil, errors.New("league not found")
|
||||
}
|
||||
if !season.HasLeague(league.ID) {
|
||||
return nil, nil, nil, errors.New("league is not assigned to the season")
|
||||
return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
|
||||
}
|
||||
team, err := GetTeam(ctx, tx, teamID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.Wrap(err, "GetTeam")
|
||||
}
|
||||
if team == nil {
|
||||
return nil, nil, nil, errors.New("team not found")
|
||||
}
|
||||
if team.InSeason(season.ID) {
|
||||
return nil, nil, nil, errors.New("team already in season")
|
||||
return nil, nil, nil, BadRequestAssociated("season", "team", seasonShortName, teamID)
|
||||
}
|
||||
participation := &TeamParticipation{
|
||||
SeasonID: season.ID,
|
||||
|
||||
@@ -85,10 +85,18 @@ func (u *updater[T]) Exec(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Execute update
|
||||
_, err := u.q.Exec(ctx)
|
||||
result, err := u.q.Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bun.UpdateQuery.Exec")
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "result.RowsAffected")
|
||||
}
|
||||
if rows == 0 {
|
||||
resource := extractResourceType(extractTableName[T]())
|
||||
return BadRequestNotFound(resource, "id", extractPrimaryKey(u.model))
|
||||
}
|
||||
|
||||
// Handle audit logging if enabled
|
||||
if u.audit != nil {
|
||||
|
||||
@@ -53,13 +53,13 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
|
||||
}
|
||||
|
||||
// GetUserByID queries the database for a user matching the given ID
|
||||
// Returns nil, nil if no user is found
|
||||
// Returns a BadRequestNotFound error if no user is found
|
||||
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
||||
return GetByID[User](tx, id).Get(ctx)
|
||||
}
|
||||
|
||||
// GetUserByUsername queries the database for a user matching the given username
|
||||
// Returns nil, nil if no user is found
|
||||
// Returns a BadRequestNotFound error if no user is found
|
||||
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
|
||||
if username == "" {
|
||||
return nil, errors.New("username not provided")
|
||||
@@ -68,7 +68,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
|
||||
}
|
||||
|
||||
// GetUserByDiscordID queries the database for a user matching the given discord id
|
||||
// Returns nil, nil if no user is found
|
||||
// Returns a BadRequestNotFound error if no user is found
|
||||
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
|
||||
if discordID == "" {
|
||||
return nil, errors.New("discord_id not provided")
|
||||
|
||||
@@ -94,9 +94,6 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "GetByID")
|
||||
}
|
||||
if user == nil {
|
||||
return false, nil
|
||||
}
|
||||
for _, role := range user.Roles {
|
||||
if role.Name == roleName {
|
||||
return true, nil
|
||||
|
||||
Reference in New Issue
Block a user