everybody loves a refactor

This commit is contained in:
2026-02-15 12:27:36 +11:00
parent 61890ae20b
commit ef8c022e60
44 changed files with 278 additions and 234 deletions

View File

@@ -3,6 +3,7 @@ package db
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -77,6 +78,7 @@ func (a *AuditLogFilter) UserIDs(ids []int) *AuditLogFilter {
} }
func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter { func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter {
fmt.Println(actions)
if len(actions) > 0 { if len(actions) > 0 {
a.In("al.action", actions) a.In("al.action", actions)
} }

View File

@@ -2,7 +2,6 @@ package db
import ( import (
"context" "context"
"database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -46,13 +45,18 @@ func (d *deleter[T]) WithAudit(meta *AuditMeta, info *AuditInfo) *deleter[T] {
} }
func (d *deleter[T]) Delete(ctx context.Context) error { func (d *deleter[T]) Delete(ctx context.Context) error {
_, err := d.q.Exec(ctx) result, err := d.q.Exec(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return errors.Wrap(err, "bun.DeleteQuery.Exec") return errors.Wrap(err, "bun.DeleteQuery.Exec")
} }
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", d.resourceID)
}
// Handle audit logging if enabled // Handle audit logging if enabled
if d.audit != nil { if d.audit != nil {
@@ -88,9 +92,6 @@ func DeleteWithProtection[T systemType](ctx context.Context, tx bun.Tx, id int,
if err != nil { if err != nil {
return errors.Wrap(err, "GetByID") return errors.Wrap(err, "GetByID")
} }
if item == nil {
return errors.New("record not found")
}
if (*item).isSystem() { if (*item).isSystem() {
return errors.New("record is system protected") return errors.New("record is system protected")
} }

View File

@@ -51,11 +51,11 @@ func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord
func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) { func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
token, err := u.GetDiscordToken(ctx, tx) token, err := u.GetDiscordToken(ctx, tx)
if err != nil { if err != nil {
if IsBadRequest(err) {
return nil, nil // Token doesn't exist - not an error
}
return nil, errors.Wrap(err, "user.GetDiscordToken") return nil, errors.Wrap(err, "user.GetDiscordToken")
} }
if token == nil {
return nil, nil
}
_, err = tx.NewDelete(). _, err = tx.NewDelete().
Model((*DiscordToken)(nil)). Model((*DiscordToken)(nil)).
Where("discord_id = ?", u.DiscordID). Where("discord_id = ?", u.DiscordID).

31
internal/db/errors.go Normal file
View File

@@ -0,0 +1,31 @@
package db
import (
"fmt"
"strings"
)
func IsBadRequest(err error) bool {
return strings.HasPrefix(err.Error(), "bad request:")
}
func BadRequest(err string) error {
return fmt.Errorf("bad request: %s", err)
}
func BadRequestNotFound(resource, field string, value any) error {
errStr := fmt.Sprintf("%s with %s=%v not found", resource, field, value)
return BadRequest(errStr)
}
func BadRequestNotAssociated(parent, child string, parentID, childID any) error {
errStr := fmt.Sprintf("%s (ID: %v) not associated with %s (ID: %v)",
child, childID, parent, parentID)
return BadRequest(errStr)
}
func BadRequestAssociated(parent, child string, parentID, childID any) error {
errStr := fmt.Sprintf("%s (ID: %v) already associated with %s (ID: %v)",
child, childID, parent, parentID)
return BadRequest(errStr)
}

View File

@@ -24,7 +24,8 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, nil resource := extractResourceType(extractTableName[T]())
return nil, BadRequestNotFound(resource, g.field, g.value)
} }
return nil, errors.Wrap(err, "bun.SelectQuery.Scan") return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
} }

View File

@@ -3,6 +3,7 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -104,6 +105,7 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value) l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
} }
} }
fmt.Println(l.q.String())
return l return l
} }

View File

@@ -28,7 +28,7 @@ func (p Permission) isSystem() bool {
} }
// GetPermissionByName queries the database for a permission matching the given name // GetPermissionByName queries the database for a permission matching the given name
// Returns nil, nil if no permission is found // Returns a BadRequestNotFound error if no permission is found
func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) { func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permission) (*Permission, error) {
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
@@ -37,7 +37,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
} }
// GetPermissionByID queries the database for a permission matching the given ID // GetPermissionByID queries the database for a permission matching the given ID
// Returns nil, nil if no permission is found // Returns a BadRequestNotFound error if no permission is found
func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) { func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, error) {
if id <= 0 { if id <= 0 {
return nil, errors.New("id must be positive") return nil, errors.New("id must be positive")

View File

@@ -30,7 +30,7 @@ func (r Role) isSystem() bool {
} }
// GetRoleByName queries the database for a role matching the given name // GetRoleByName queries the database for a role matching the given name
// Returns nil, nil if no role is found // Returns a BadRequestNotFound error if no role is found
func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) { func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, error) {
if name == "" { if name == "" {
return nil, errors.New("name cannot be empty") return nil, errors.New("name cannot be empty")
@@ -39,7 +39,7 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
} }
// GetRoleByID queries the database for a role matching the given ID // GetRoleByID queries the database for a role matching the given ID
// Returns nil, nil if no role is found // Returns a BadRequestNotFound error if no role is found
func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
return GetByID[Role](tx, id).Relation("Permissions").Get(ctx) return GetByID[Role](tx, id).Relation("Permissions").Get(ctx)
} }
@@ -110,9 +110,6 @@ func DeleteRole(ctx context.Context, tx bun.Tx, id int, audit *AuditMeta) error
if err != nil { if err != nil {
return errors.Wrap(err, "GetRoleByID") return errors.Wrap(err, "GetRoleByID")
} }
if role == nil {
return errors.New("role not found")
}
if role.IsSystem { if role.IsSystem {
return errors.New("cannot delete system roles") return errors.New("cannot delete system roles")
} }

View File

@@ -35,8 +35,8 @@ func GetSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
if err != nil { if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague") return nil, nil, nil, errors.Wrap(err, "GetLeague")
} }
if season == nil || league == nil || !season.HasLeague(league.ID) { if !season.HasLeague(league.ID) {
return nil, nil, nil, nil return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
} }
// Get all teams participating in this season+league // Get all teams participating in this season+league
@@ -59,18 +59,12 @@ func NewSeasonLeague(ctx context.Context, tx bun.Tx, seasonShortName, leagueShor
if err != nil { if err != nil {
return errors.Wrap(err, "GetSeason") return errors.Wrap(err, "GetSeason")
} }
if season == nil {
return errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName) league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil { if err != nil {
return errors.Wrap(err, "GetLeague") return errors.Wrap(err, "GetLeague")
} }
if league == nil {
return errors.New("league not found")
}
if season.HasLeague(league.ID) { if season.HasLeague(league.ID) {
return errors.New("league already added to season") return BadRequestAssociated("season", "league", seasonShortName, leagueShortName)
} }
seasonLeague := &SeasonLeague{ seasonLeague := &SeasonLeague{
SeasonID: season.ID, SeasonID: season.ID,
@@ -94,9 +88,6 @@ func (s *Season) RemoveLeague(ctx context.Context, tx bun.Tx, leagueShortName st
if err != nil { if err != nil {
return errors.Wrap(err, "GetLeague") return errors.Wrap(err, "GetLeague")
} }
if league == nil {
return errors.New("league not found")
}
if !s.HasLeague(league.ID) { if !s.HasLeague(league.ID) {
return errors.New("league not in season") return errors.New("league not in season")
} }

View File

@@ -32,6 +32,7 @@ func (db *DB) RegisterModels() []any {
(*Role)(nil), (*Role)(nil),
(*Permission)(nil), (*Permission)(nil),
(*AuditLog)(nil), (*AuditLog)(nil),
(*Fixture)(nil),
} }
db.RegisterModel(models...) db.RegisterModel(models...)
return models return models

View File

@@ -23,28 +23,19 @@ func NewTeamParticipation(ctx context.Context, tx bun.Tx,
if err != nil { if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetSeason") return nil, nil, nil, errors.Wrap(err, "GetSeason")
} }
if season == nil {
return nil, nil, nil, errors.New("season not found")
}
league, err := GetLeague(ctx, tx, leagueShortName) league, err := GetLeague(ctx, tx, leagueShortName)
if err != nil { if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetLeague") return nil, nil, nil, errors.Wrap(err, "GetLeague")
} }
if league == nil {
return nil, nil, nil, errors.New("league not found")
}
if !season.HasLeague(league.ID) { if !season.HasLeague(league.ID) {
return nil, nil, nil, errors.New("league is not assigned to the season") return nil, nil, nil, BadRequestNotAssociated("season", "league", seasonShortName, leagueShortName)
} }
team, err := GetTeam(ctx, tx, teamID) team, err := GetTeam(ctx, tx, teamID)
if err != nil { if err != nil {
return nil, nil, nil, errors.Wrap(err, "GetTeam") return nil, nil, nil, errors.Wrap(err, "GetTeam")
} }
if team == nil {
return nil, nil, nil, errors.New("team not found")
}
if team.InSeason(season.ID) { if team.InSeason(season.ID) {
return nil, nil, nil, errors.New("team already in season") return nil, nil, nil, BadRequestAssociated("season", "team", seasonShortName, teamID)
} }
participation := &TeamParticipation{ participation := &TeamParticipation{
SeasonID: season.ID, SeasonID: season.ID,

View File

@@ -85,10 +85,18 @@ func (u *updater[T]) Exec(ctx context.Context) error {
} }
// Execute update // Execute update
_, err := u.q.Exec(ctx) result, err := u.q.Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "bun.UpdateQuery.Exec") return errors.Wrap(err, "bun.UpdateQuery.Exec")
} }
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "result.RowsAffected")
}
if rows == 0 {
resource := extractResourceType(extractTableName[T]())
return BadRequestNotFound(resource, "id", extractPrimaryKey(u.model))
}
// Handle audit logging if enabled // Handle audit logging if enabled
if u.audit != nil { if u.audit != nil {

View File

@@ -53,13 +53,13 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
} }
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
return GetByID[User](tx, id).Get(ctx) return GetByID[User](tx, id).Get(ctx)
} }
// GetUserByUsername queries the database for a user matching the given username // GetUserByUsername queries the database for a user matching the given username
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) { func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, error) {
if username == "" { if username == "" {
return nil, errors.New("username not provided") return nil, errors.New("username not provided")
@@ -68,7 +68,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
} }
// GetUserByDiscordID queries the database for a user matching the given discord id // GetUserByDiscordID queries the database for a user matching the given discord id
// Returns nil, nil if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) {
if discordID == "" { if discordID == "" {
return nil, errors.New("discord_id not provided") return nil, errors.New("discord_id not provided")

View File

@@ -94,9 +94,6 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
if err != nil { if err != nil {
return false, errors.Wrap(err, "GetByID") return false, errors.Wrap(err, "GetByID")
} }
if user == nil {
return false, nil
}
for _, role := range user.Roles { for _, role := range user.Roles {
if role.Name == roleName { if role.Name == roleName {
return true, nil return true, nil

View File

@@ -185,12 +185,12 @@ func AdminAuditLogDetail(s *hws.Server, conn *db.DB) http.Handler {
var err error var err error
log, err = db.GetAuditLogByID(ctx, tx, id) log, err = db.GetAuditLogByID(ctx, tx, id)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetAuditLogByID") return false, errors.Wrap(err, "db.GetAuditLogByID")
} }
if log == nil {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return true, nil return true, nil
}); !ok { }); !ok {
return return

View File

@@ -31,12 +31,12 @@ func AdminPreviewRoleStart(s *hws.Server, conn *db.DB, ssl bool) http.Handler {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, "Role not found")
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
throw.NotFound(s, w, r, "Role not found")
return false, nil
}
// Cannot preview admin role // Cannot preview admin role
if role.Name == roles.Admin { if role.Name == roles.Admin {
throw.BadRequest(s, w, r, "Cannot preview admin role", nil) throw.BadRequest(s, w, r, "Cannot preview admin role", nil)

View File

@@ -9,6 +9,7 @@ import (
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
adminview "git.haelnorr.com/h/oslstats/internal/view/adminview" adminview "git.haelnorr.com/h/oslstats/internal/view/adminview"
@@ -108,7 +109,7 @@ func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
@@ -117,11 +118,12 @@ func AdminRoleManage(s *hws.Server, conn *db.DB) http.Handler {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
return false, errors.New("role not found")
}
return true, nil return true, nil
}); !ok { }); !ok {
return return
@@ -146,11 +148,12 @@ func AdminRoleDeleteConfirm(s *hws.Server, conn *db.DB) http.Handler {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
return false, errors.New("role not found")
}
return true, nil return true, nil
}); !ok { }); !ok {
return return
@@ -166,7 +169,7 @@ func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
@@ -180,11 +183,12 @@ func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
// First check if role exists and get its details // First check if role exists and get its details
role, err := db.GetRoleByID(ctx, tx, roleID) role, err := db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
return false, errors.New("role not found")
}
// Check if it's a system role // Check if it's a system role
if role.IsSystem { if role.IsSystem {
@@ -194,6 +198,10 @@ func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
// Delete the role with audit logging // Delete the role with audit logging
err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil)) err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil))
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.DeleteRole") return false, errors.Wrap(err, "db.DeleteRole")
} }
@@ -218,7 +226,7 @@ func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
@@ -232,11 +240,12 @@ func AdminRolePermissionsModal(s *hws.Server, conn *db.DB) http.Handler {
var err error var err error
role, err = db.GetRoleByID(ctx, tx, roleID) role, err = db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
return false, errors.New("role not found")
}
// Load all permissions // Load all permissions
allPermissions, err = db.ListAllPermissions(ctx, tx) allPermissions, err = db.ListAllPermissions(ctx, tx)
@@ -281,7 +290,7 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
roleIDStr := r.PathValue("id") roleIDStr := r.PathValue("id")
roleID, err := strconv.Atoi(roleIDStr) roleID, err := strconv.Atoi(roleIDStr)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
@@ -305,12 +314,12 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
role, err := db.GetRoleByID(ctx, tx, roleID) role, err := db.GetRoleByID(ctx, tx, roleID)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
if role == nil {
w.WriteHeader(http.StatusBadRequest)
return false, nil
}
err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil)) err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil))
if err != nil { if err != nil {
return false, errors.Wrap(err, "role.UpdatePermissions") return false, errors.Wrap(err, "role.UpdatePermissions")

View File

@@ -42,9 +42,6 @@ func ensureUserHasAdminRole(ctx context.Context, tx bun.Tx, user *db.User) error
if err != nil { if err != nil {
return errors.Wrap(err, "db.GetRoleByName") return errors.Wrap(err, "db.GetRoleByName")
} }
if adminRole == nil {
return errors.New("admin role not found in database")
}
// Grant admin role // Grant admin role
err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil) err = db.AssignRole(ctx, tx, user.ID, adminRole.ID, nil)

View File

@@ -158,7 +158,7 @@ func login(
} }
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID) user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
if err != nil { if err != nil && !db.IsBadRequest(err) {
return nil, errors.Wrap(err, "db.GetUserByDiscordID") return nil, errors.Wrap(err, "db.GetUserByDiscordID")
} }
var redirect string var redirect string

View File

@@ -2,10 +2,12 @@ package handlers
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -22,12 +24,12 @@ func IsUnique(
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, err := validation.ParseForm(r) getter, err := validation.ParseForm(r)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
value := getter.String(field).TrimSpace().Required().Value value := getter.String(field).TrimSpace().Required().Value
if !getter.Validate() { if !getter.Validate() {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
unique := false unique := false
@@ -41,9 +43,10 @@ func IsUnique(
return return
} }
if unique { if unique {
w.WriteHeader(http.StatusOK) respond.OK(w)
} else { } else {
w.WriteHeader(http.StatusConflict) err := fmt.Errorf("'%s' is not unique for field '%s'", value, field)
respond.Conflict(w, err)
} }
}) })
} }

View File

@@ -11,6 +11,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
leaguesview "git.haelnorr.com/h/oslstats/internal/view/leaguesview" leaguesview "git.haelnorr.com/h/oslstats/internal/view/leaguesview"
) )
@@ -78,8 +79,7 @@ func NewLeagueSubmit(
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil) notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
return return
} }
w.Header().Set("HX-Redirect", fmt.Sprintf("/leagues/%s", league.ShortName)) respond.HXRedirect(w, "/leagues/%s", league.ShortName)
w.WriteHeader(http.StatusOK)
notify.SuccessWithDelay(s, w, r, "League Created", fmt.Sprintf("Successfully created league: %s", name), nil) notify.SuccessWithDelay(s, w, r, "League Created", fmt.Sprintf("Successfully created league: %s", name), nil)
}) })
} }

View File

@@ -12,6 +12,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/pkg/oauth" "git.haelnorr.com/h/oslstats/pkg/oauth"
@@ -34,10 +35,10 @@ func Login(
if r.Method == "POST" { if r.Method == "POST" {
if err != nil { if err != nil {
notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err) notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
w.WriteHeader(http.StatusOK) respond.OK(w)
return return
} }
w.Header().Set("HX-Redirect", "/login") respond.HXRedirect(w, "/login")
return return
} }

View File

@@ -8,6 +8,7 @@ import (
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -22,11 +23,6 @@ func Logout(
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
user := db.CurrentUser(r.Context()) user := db.CurrentUser(r.Context())
if user == nil {
// JIC - should be impossible to get here if route is protected by LoginReq
w.Header().Set("HX-Redirect", "/")
return
}
if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
token, err := user.DeleteDiscordTokens(ctx, tx) token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil { if err != nil {
@@ -48,7 +44,7 @@ func Logout(
}); !ok { }); !ok {
return return
} }
w.Header().Set("HX-Redirect", "/") respond.HXRedirect(w, "/")
}, },
) )
} }

View File

@@ -12,6 +12,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
authview "git.haelnorr.com/h/oslstats/internal/view/authview" authview "git.haelnorr.com/h/oslstats/internal/view/authview"
@@ -82,7 +83,7 @@ func Register(
return return
} }
if !unique { if !unique {
w.WriteHeader(http.StatusConflict) respond.Conflict(w, errors.New("username is taken"))
} else { } else {
err = auth.Login(w, r, user, true) err = auth.Login(w, r, user, true)
if err != nil { if err != nil {
@@ -90,7 +91,7 @@ func Register(
return return
} }
pageFrom := cookies.CheckPageFrom(w, r) pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom) respond.HXRedirect(w, "%s", pageFrom)
} }
}, },
) )

View File

@@ -25,11 +25,12 @@ func SeasonPage(
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
if season == nil {
return true, nil
}
leaguesWithTeams, err = season.MapTeamsToLeagues(ctx, tx) leaguesWithTeams, err = season.MapTeamsToLeagues(ctx, tx)
if err != nil { if err != nil {
@@ -40,10 +41,6 @@ func SeasonPage(
}); !ok { }); !ok {
return return
} }
if season == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
renderSafely(seasonsview.DetailPage(season, leaguesWithTeams), s, r, w) renderSafely(seasonsview.DetailPage(season, leaguesWithTeams), s, r, w)
}) })
} }

View File

@@ -8,6 +8,7 @@ import (
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
"git.haelnorr.com/h/oslstats/internal/view/seasonsview" "git.haelnorr.com/h/oslstats/internal/view/seasonsview"
@@ -28,6 +29,10 @@ func SeasonEditPage(
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
allLeagues, err = db.GetLeagues(ctx, tx) allLeagues, err = db.GetLeagues(ctx, tx)
@@ -38,10 +43,6 @@ func SeasonEditPage(
}); !ok { }); !ok {
return return
} }
if season == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
renderSafely(seasonsview.EditPage(season, allLeagues), s, r, w) renderSafely(seasonsview.EditPage(season, allLeagues), s, r, w)
}) })
} }
@@ -79,11 +80,12 @@ func SeasonEditSubmit(
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
if season == nil {
return false, errors.New("season does not exist")
}
err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil)) err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil))
if err != nil { if err != nil {
return false, errors.Wrap(err, "season.Update") return false, errors.Wrap(err, "season.Update")
@@ -93,13 +95,7 @@ func SeasonEditSubmit(
return return
} }
if season == nil { respond.HXRedirect(w, "/seasons/%s", season.ShortName)
throw.NotFound(s, w, r, r.URL.Path)
return
}
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName))
w.WriteHeader(http.StatusOK)
notify.SuccessWithDelay(s, w, r, "Season Updated", fmt.Sprintf("Successfully updated season: %s", season.Name), nil) notify.SuccessWithDelay(s, w, r, "Season Updated", fmt.Sprintf("Successfully updated season: %s", season.Name), nil)
}) })
} }

View File

@@ -38,6 +38,10 @@ func SeasonLeagueAddTeam(
var err error var err error
team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil)) team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil))
if err != nil { if err != nil {
if db.IsBadRequest(err) {
w.WriteHeader(http.StatusBadRequest)
return false, nil
}
return false, errors.Wrap(err, "db.NewTeamParticipation") return false, errors.Wrap(err, "db.NewTeamParticipation")
} }
return true, nil return true, nil

View File

@@ -22,12 +22,15 @@ func SeasonLeaguePage(
leagueStr := r.PathValue("league_short_name") leagueStr := r.PathValue("league_short_name")
var season *db.Season var season *db.Season
var league *db.League
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, _, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
return true, nil return true, nil
@@ -35,11 +38,6 @@ func SeasonLeaguePage(
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
defaultTab := season.GetDefaultTab() defaultTab := season.GetDefaultTab()
redirectURL := fmt.Sprintf( redirectURL := fmt.Sprintf(
"/seasons/%s/leagues/%s/%s", "/seasons/%s/leagues/%s/%s",

View File

@@ -7,7 +7,7 @@ import (
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
seasonsview "git.haelnorr.com/h/oslstats/internal/view/seasonsview" "git.haelnorr.com/h/oslstats/internal/view/seasonsview"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@@ -28,6 +28,10 @@ func SeasonLeagueFinalsPage(
var err error var err error
season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
return true, nil return true, nil
@@ -35,11 +39,6 @@ func SeasonLeagueFinalsPage(
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
if r.Method == "GET" { if r.Method == "GET" {
renderSafely(seasonsview.SeasonLeagueFinalsPage(season, league), s, r, w) renderSafely(seasonsview.SeasonLeagueFinalsPage(season, league), s, r, w)
} else { } else {

View File

@@ -28,6 +28,10 @@ func SeasonLeagueFixturesPage(
var err error var err error
season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
return true, nil return true, nil
@@ -35,11 +39,6 @@ func SeasonLeagueFixturesPage(
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
if r.Method == "GET" { if r.Method == "GET" {
renderSafely(seasonsview.SeasonLeagueFixturesPage(season, league), s, r, w) renderSafely(seasonsview.SeasonLeagueFixturesPage(season, league), s, r, w)
} else { } else {

View File

@@ -28,6 +28,10 @@ func SeasonLeagueStatsPage(
var err error var err error
season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
return true, nil return true, nil
@@ -35,11 +39,6 @@ func SeasonLeagueStatsPage(
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
if r.Method == "GET" { if r.Method == "GET" {
renderSafely(seasonsview.SeasonLeagueStatsPage(season, league), s, r, w) renderSafely(seasonsview.SeasonLeagueStatsPage(season, league), s, r, w)
} else { } else {

View File

@@ -28,18 +28,16 @@ func SeasonLeagueTablePage(
var err error var err error
season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, _, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
return true, nil return true, nil
}); !ok { }); !ok {
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
if r.Method == "GET" { if r.Method == "GET" {
renderSafely(seasonsview.SeasonLeagueTablePage(season, league), s, r, w) renderSafely(seasonsview.SeasonLeagueTablePage(season, league), s, r, w)
} else { } else {

View File

@@ -30,6 +30,10 @@ func SeasonLeagueTeamsPage(
var err error var err error
season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr) season, league, teams, err = db.GetSeasonLeague(ctx, tx, seasonStr, leagueStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
throw.NotFound(s, w, r, r.URL.Path)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeasonLeague") return false, errors.Wrap(err, "db.GetSeasonLeague")
} }
@@ -46,11 +50,6 @@ func SeasonLeagueTeamsPage(
return return
} }
if season == nil || league == nil {
throw.NotFound(s, w, r, r.URL.Path)
return
}
if r.Method == "GET" { if r.Method == "GET" {
renderSafely(seasonsview.SeasonLeagueTeamsPage(season, league, teams, available), s, r, w) renderSafely(seasonsview.SeasonLeagueTeamsPage(season, league, teams, available), s, r, w)
} else { } else {

View File

@@ -10,6 +10,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/view/seasonsview" "git.haelnorr.com/h/oslstats/internal/view/seasonsview"
) )
@@ -26,6 +27,10 @@ func SeasonAddLeague(
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil)) err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil))
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.BadRequest(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.NewSeasonLeague") return false, errors.Wrap(err, "db.NewSeasonLeague")
} }
@@ -64,13 +69,17 @@ func SeasonRemoveLeague(
var err error var err error
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.NotFound(w, err)
return false, nil
}
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
if season == nil {
return false, errors.New("season not found")
}
err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil)) err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil))
if err != nil { if err != nil {
if db.IsBadRequest(err) {
respond.BadRequest(w, err)
}
return false, errors.Wrap(err, "season.RemoveLeague") return false, errors.Wrap(err, "season.RemoveLeague")
} }

View File

@@ -8,6 +8,7 @@ import (
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
seasonsview "git.haelnorr.com/h/oslstats/internal/view/seasonsview" seasonsview "git.haelnorr.com/h/oslstats/internal/view/seasonsview"
"git.haelnorr.com/h/timefmt" "git.haelnorr.com/h/timefmt"
@@ -83,8 +84,7 @@ func NewSeasonSubmit(
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil) notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
return return
} }
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName)) respond.HXRedirect(w, "/seasons/%s", season.ShortName)
w.WriteHeader(http.StatusOK)
notify.SuccessWithDelay(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil) notify.SuccessWithDelay(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
}) })
} }

View File

@@ -2,10 +2,9 @@ package handlers
import ( import (
"net/http" "net/http"
"path/filepath"
"strings"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
) )
@@ -23,41 +22,8 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
// Explicitly set Content-Type for CSS files respond.ContentType(w, r)
if strings.HasSuffix(r.URL.Path, ".css") {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if strings.HasSuffix(r.URL.Path, ".js") {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if strings.HasSuffix(r.URL.Path, ".ico") {
w.Header().Set("Content-Type", "image/x-icon")
} else {
// Let Go detect the content type for other files
ext := filepath.Ext(r.URL.Path)
if contentType := mimeTypes[ext]; contentType != "" {
w.Header().Set("Content-Type", contentType)
}
}
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
}, },
) )
} }
// Common MIME types for static files
var mimeTypes = map[string]string{
".html": "text/html; charset=utf-8",
".css": "text/css; charset=utf-8",
".js": "application/javascript; charset=utf-8",
".json": "application/json; charset=utf-8",
".xml": "application/xml; charset=utf-8",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".svg": "image/svg+xml",
".ico": "image/x-icon",
".webp": "image/webp",
".woff": "font/woff",
".woff2": "font/woff2",
".ttf": "font/ttf",
".eot": "application/vnd.ms-fontobject",
}

View File

@@ -6,6 +6,7 @@ import (
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -20,7 +21,7 @@ func IsTeamShortNamesUnique(
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, err := validation.ParseForm(r) getter, err := validation.ParseForm(r)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) respond.BadRequest(w, err)
return return
} }
@@ -28,12 +29,12 @@ func IsTeamShortNamesUnique(
altShortName := getter.String("alt_short_name").TrimSpace().ToUpper().MaxLength(3).Value altShortName := getter.String("alt_short_name").TrimSpace().ToUpper().MaxLength(3).Value
if shortName == "" || altShortName == "" { if shortName == "" || altShortName == "" {
w.WriteHeader(http.StatusOK) respond.OK(w)
return return
} }
if shortName == altShortName { if shortName == altShortName {
w.WriteHeader(http.StatusConflict) respond.Conflict(w, errors.New("short names cannot be the same"))
return return
} }
@@ -49,9 +50,9 @@ func IsTeamShortNamesUnique(
} }
if isUnique { if isUnique {
w.WriteHeader(http.StatusOK) respond.OK(w)
} else { } else {
w.WriteHeader(http.StatusConflict) respond.Conflict(w, errors.New("short name combination is taken"))
} }
}) })
} }

View File

@@ -32,31 +32,10 @@ func TeamsPage(
}); !ok { }); !ok {
return return
} }
renderSafely(teamsview.ListPage(teams), s, r, w) if r.Method == "GET" {
}) renderSafely(teamsview.ListPage(teams), s, r, w)
} } else {
renderSafely(teamsview.TeamsList(teams), s, r, w)
// TeamsList renders just the teams list, for use with POST requests and HTMX }
func TeamsList(
s *hws.Server,
conn *db.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pageOpts, ok := db.GetPageOpts(s, w, r)
if !ok {
return
}
var teams *db.List[db.Team]
if ok := conn.WithReadTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
teams, err = db.ListTeams(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.ListTeams")
}
return true, nil
}); !ok {
return
}
renderSafely(teamsview.TeamsList(teams), s, r, w)
}) })
} }

View File

@@ -11,6 +11,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/validation"
teamsview "git.haelnorr.com/h/oslstats/internal/view/teamsview" teamsview "git.haelnorr.com/h/oslstats/internal/view/teamsview"
) )
@@ -88,8 +89,7 @@ func NewTeamSubmit(
notify.Warn(s, w, r, "Duplicate Short Names", "This combination of short names is already taken.", nil) notify.Warn(s, w, r, "Duplicate Short Names", "This combination of short names is already taken.", nil)
return return
} }
w.Header().Set("HX-Redirect", "/teams") respond.HXRedirect(w, "teams")
w.WriteHeader(http.StatusOK)
notify.SuccessWithDelay(s, w, r, "Team Created", fmt.Sprintf("Successfully created team: %s", name), nil) notify.SuccessWithDelay(s, w, r, "Team Created", fmt.Sprintf("Successfully created team: %s", name), nil)
}) })
} }

36
internal/respond/error.go Normal file
View File

@@ -0,0 +1,36 @@
// Package respond provides utilities for raw HTTP responses that don't fit into the throw or notify categories
package respond
import (
"encoding/json"
"fmt"
"net/http"
)
func OK(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK)
}
func BadRequest(w http.ResponseWriter, err error) {
respondError(w, http.StatusBadRequest, err)
}
func NotFound(w http.ResponseWriter, err error) {
respondError(w, http.StatusNotFound, err)
}
func Conflict(w http.ResponseWriter, err error) {
respondError(w, http.StatusConflict, err)
}
func respondError(w http.ResponseWriter, statusCode int, err error) {
details := map[string]any{
"error": statusCode,
"details": fmt.Sprintf("%s", err),
}
resp, err := json.Marshal(details)
w.WriteHeader(statusCode)
if err == nil {
_, _ = w.Write(resp)
}
}

View File

@@ -0,0 +1,39 @@
package respond
import (
"fmt"
"net/http"
"path/filepath"
)
func HXRedirect(w http.ResponseWriter, format string, a ...any) {
w.Header().Set("HX-Redirect", fmt.Sprintf(format, a...))
w.WriteHeader(http.StatusOK)
}
func ContentType(w http.ResponseWriter, r *http.Request) {
ext := filepath.Ext(r.URL.Path)
if contentType := mimeTypes[ext]; contentType != "" {
w.Header().Set("Content-Type", contentType)
}
}
// Common MIME types for static files
var mimeTypes = map[string]string{
".html": "text/html; charset=utf-8",
".css": "text/css; charset=utf-8",
".js": "application/javascript; charset=utf-8",
".json": "application/json; charset=utf-8",
".xml": "application/xml; charset=utf-8",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".svg": "image/svg+xml",
".ico": "image/x-icon",
".webp": "image/webp",
".woff": "font/woff",
".woff2": "font/woff2",
".ttf": "font/ttf",
".eot": "application/vnd.ms-fontobject",
}

View File

@@ -163,14 +163,9 @@ func addRoutes(
teamRoutes := []hws.Route{ teamRoutes := []hws.Route{
{ {
Path: "/teams", Path: "/teams",
Method: hws.MethodGET, Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: handlers.TeamsPage(s, conn), Handler: handlers.TeamsPage(s, conn),
}, },
{
Path: "/teams",
Method: hws.MethodPOST,
Handler: handlers.TeamsList(s, conn),
},
{ {
Path: "/teams/new", Path: "/teams/new",
Method: hws.MethodGET, Method: hws.MethodGET,

View File

@@ -2,6 +2,7 @@ package validation
import ( import (
"net/http" "net/http"
"strings"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/notify"
@@ -25,7 +26,7 @@ func (f *FormGetter) Get(key string) string {
} }
func (f *FormGetter) GetList(key string) []string { func (f *FormGetter) GetList(key string) []string {
return f.r.Form[key] return strings.Split(f.Get(key), ",")
} }
func (f *FormGetter) getChecks() []*ValidationRule { func (f *FormGetter) getChecks() []*ValidationRule {

View File

@@ -114,7 +114,7 @@ _migrate-status:
{{bin}}/{{entrypoint}} --migrate-status --envfile $ENVFILE {{bin}}/{{entrypoint}} --migrate-status --envfile $ENVFILE
[private] [private]
_migrate-new name: && _migrate-status _migrate-new name: && _build _migrate-status
{{bin}}/{{entrypoint}} --migrate-create {{name}} {{bin}}/{{entrypoint}} --migrate-create {{name}}
# Hard reset the database # Hard reset the database