refactored for maintainability
This commit is contained in:
@@ -101,17 +101,17 @@ func addRoutes(
|
||||
{
|
||||
Path: "/htmx/isusernameunique",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handlers.IsUsernameUnique(s, conn, cfg, store),
|
||||
Handler: handlers.IsUnique(s, conn, (*db.User)(nil), "username"),
|
||||
},
|
||||
{
|
||||
Path: "/htmx/isseasonnameunique",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handlers.IsSeasonNameUnique(s, conn),
|
||||
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "name"),
|
||||
},
|
||||
{
|
||||
Path: "/htmx/isseasonshortnameunique",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handlers.IsSeasonShortNameUnique(s, conn),
|
||||
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "short_name"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
3
go.mod
3
go.mod
@@ -1,6 +1,6 @@
|
||||
module git.haelnorr.com/h/oslstats
|
||||
|
||||
go 1.25.5
|
||||
go 1.25.6
|
||||
|
||||
require (
|
||||
git.haelnorr.com/h/golib/env v0.9.1
|
||||
@@ -27,6 +27,7 @@ require (
|
||||
require (
|
||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
|
||||
git.haelnorr.com/h/timefmt v0.1.0
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -16,6 +16,8 @@ git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7Kv
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||
git.haelnorr.com/h/timefmt v0.1.0 h1:ULDkWEtFIV+FkkoV0q9n62Spj+HDdtFL9QeAdGIEp+o=
|
||||
git.haelnorr.com/h/timefmt v0.1.0/go.mod h1:12gXXYLP4w9Fa9ZkbZWdvKV6RyZEzwAm9mN+WB3oXpw=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
||||
|
||||
@@ -105,35 +105,33 @@ func (l *Logger) log(
|
||||
|
||||
// GetRecentLogs retrieves recent audit logs with pagination
|
||||
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||
tx, err := l.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.BeginTx")
|
||||
var logs *db.AuditLogs
|
||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.GetAuditLogs")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = tx.Commit() // read only transaction
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// GetLogsByUser retrieves audit logs for a specific user
|
||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||
tx, err := l.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.BeginTx")
|
||||
var logs *db.AuditLogs
|
||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.GetAuditLogsByUser")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = tx.Commit() // read only transaction
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
@@ -145,20 +143,16 @@ func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error
|
||||
|
||||
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
|
||||
|
||||
tx, err := l.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "conn.BeginTx")
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "tx.Commit")
|
||||
var count int
|
||||
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.CleanupOldAuditLogs")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return 0, errors.Wrap(err, "db.WithTxFailSilently")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
19
internal/db/isunique.go
Normal file
19
internal/db/isunique.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func IsUnique(ctx context.Context, tx bun.Tx, model any, field, value string) (bool, error) {
|
||||
count, err := tx.NewSelect().
|
||||
Model(model).
|
||||
Where("? = ?", bun.Ident(field), value).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
@@ -87,25 +87,3 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
|
||||
}
|
||||
return season, nil
|
||||
}
|
||||
|
||||
func IsSeasonNameUnique(ctx context.Context, tx bun.Tx, name string) (bool, error) {
|
||||
count, err := tx.NewSelect().
|
||||
Model((*Season)(nil)).
|
||||
Where("name = ?", name).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
func IsSeasonShortNameUnique(ctx context.Context, tx bun.Tx, shortname string) (bool, error) {
|
||||
count, err := tx.NewSelect().
|
||||
Model((*Season)(nil)).
|
||||
Where("short_name = ?", shortname).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
118
internal/db/txhelpers.go
Normal file
118
internal/db/txhelpers.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// TxFunc is a function that runs within a database transaction
|
||||
type (
|
||||
TxFunc func(ctx context.Context, tx bun.Tx) (bool, error)
|
||||
TxFuncSilent func(ctx context.Context, tx bun.Tx) error
|
||||
)
|
||||
|
||||
var timeout = 15 * time.Second
|
||||
|
||||
// WithReadTx executes a read-only transaction with automatic rollback
|
||||
// Returns true if successful, false if error was thrown to client
|
||||
func WithReadTx(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
conn *bun.DB,
|
||||
fn TxFunc,
|
||||
) bool {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
ok, err := withTx(ctx, conn, fn, false)
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// WithTxFailSilently executes a transaction with automatic rollback
|
||||
// Returns true if successful, false if error occured.
|
||||
// Does not throw any errors to the client.
|
||||
func WithTxFailSilently(
|
||||
ctx context.Context,
|
||||
conn *bun.DB,
|
||||
fn TxFuncSilent,
|
||||
) error {
|
||||
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
err := fn(ctx, tx)
|
||||
return err != nil, err
|
||||
}
|
||||
_, err := withTx(ctx, conn, fnc, true)
|
||||
return err
|
||||
}
|
||||
|
||||
// WithWriteTx executes a write transaction with automatic rollback on error
|
||||
// Commits only if fn returns nil. Returns true if successful.
|
||||
func WithWriteTx(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
conn *bun.DB,
|
||||
fn TxFunc,
|
||||
) bool {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
ok, err := withTx(ctx, conn, fn, true)
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// WithNotifyTx executes a transaction with notification-based error handling
|
||||
// Uses notifyInternalServiceError instead of throwInternalServiceError
|
||||
func WithNotifyTx(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
conn *bun.DB,
|
||||
fn TxFunc,
|
||||
) bool {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
ok, err := withTx(ctx, conn, fn, true)
|
||||
if err != nil {
|
||||
notify.InternalServiceError(s, w, r, "Database error", err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// withTx executes a transaction with automatic rollback on error
|
||||
func withTx(
|
||||
ctx context.Context,
|
||||
conn *bun.DB,
|
||||
fn TxFunc,
|
||||
write bool,
|
||||
) (bool, error) {
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.BeginTx")
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
ok, err := fn(ctx, tx)
|
||||
if err != nil || !ok {
|
||||
return false, err
|
||||
}
|
||||
if write {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Commit")
|
||||
}
|
||||
} else {
|
||||
_ = tx.Commit()
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -112,19 +112,6 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// IsUsernameUnique checks if the given username is unique (not already taken)
|
||||
// Returns true if the username is available, false if it's taken
|
||||
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
|
||||
count, err := tx.NewSelect().
|
||||
Model((*User)(nil)).
|
||||
Where("username = ?", username).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.NewSelect")
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
// GetRoles loads all the roles for this user
|
||||
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
||||
if u == nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
@@ -14,23 +13,17 @@ import (
|
||||
|
||||
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
var users *db.Users
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
users, err = db.GetUsers(ctx, tx, nil)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.GetUsers")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
users, err := db.GetUsers(ctx, tx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers"))
|
||||
return
|
||||
}
|
||||
_ = tx.Commit()
|
||||
|
||||
renderSafely(page.AdminDashboard(users), s, r, w)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
@@ -15,30 +14,21 @@ import (
|
||||
// AdminUsersList shows all users
|
||||
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx"))
|
||||
var users *db.Users
|
||||
pageOpts := pageOptsFromForm(s, w, r)
|
||||
if pageOpts == nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Get all users
|
||||
pageOpts, err := pageOptsFromForm(r)
|
||||
if err != nil {
|
||||
throwBadRequest(s, w, r, "invalid form data", err)
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
users, err = db.GetUsers(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.GetUsers")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
users, err := db.GetUsers(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers"))
|
||||
return
|
||||
}
|
||||
|
||||
_ = tx.Commit()
|
||||
|
||||
renderSafely(admin.UserList(users), s, r, w)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
@@ -15,11 +14,12 @@ import (
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
|
||||
func Callback(
|
||||
server *hws.Server,
|
||||
s *hws.Server,
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
@@ -31,26 +31,9 @@ func Callback(
|
||||
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
err := track.Error(attempts)
|
||||
store.ClearRedirectTrack(r, "/callback")
|
||||
|
||||
throwError(
|
||||
server,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,12 +47,12 @@ func Callback(
|
||||
if err != nil {
|
||||
if vsErr, ok := err.(*verifyStateError); ok {
|
||||
if vsErr.IsCookieError() {
|
||||
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
||||
throw.Unauthorized(s, w, r, "OAuth session not found or expired", err)
|
||||
} else {
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
} else {
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -77,20 +60,17 @@ func Callback(
|
||||
|
||||
switch data {
|
||||
case "login":
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
|
||||
var redirect func()
|
||||
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "OAuth login failed", err)
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
redirect()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
|
||||
// Get technical details if applicable
|
||||
var details string
|
||||
if showDetails && hwsError.Error != nil {
|
||||
details = formatErrorDetails(hwsError.Error)
|
||||
details = notify.FormatErrorDetails(hwsError.Error)
|
||||
}
|
||||
|
||||
// Render appropriate template
|
||||
|
||||
@@ -2,152 +2,15 @@ package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"github.com/a-h/templ"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// throwError is a generic helper that all throw* functions use internally
|
||||
func throwError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
statusCode int,
|
||||
msg string,
|
||||
err error,
|
||||
level hws.ErrorLevel,
|
||||
) {
|
||||
s.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
Error: err,
|
||||
Level: level,
|
||||
RenderErrorPage: true, // throw* family always renders error pages
|
||||
})
|
||||
}
|
||||
|
||||
// throwInternalServiceError handles 500 errors (server failures)
|
||||
func throwInternalServiceError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
|
||||
}
|
||||
|
||||
// throwServiceUnavailable handles 503 errors
|
||||
func throwServiceUnavailable(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
|
||||
}
|
||||
|
||||
// throwBadRequest handles 400 errors (malformed requests)
|
||||
func throwBadRequest(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// throwForbidden handles 403 errors (normal permission denials)
|
||||
func throwForbidden(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
||||
func throwForbiddenSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
|
||||
}
|
||||
|
||||
// throwUnauthorized handles 401 errors (not authenticated)
|
||||
func throwUnauthorized(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
||||
func throwUnauthorizedSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
|
||||
}
|
||||
|
||||
// throwNotFound handles 404 errors
|
||||
func throwNotFound(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
) {
|
||||
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
||||
err := errors.New("Resource not found")
|
||||
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// ErrorDetails contains structured error information for WebSocket error modals
|
||||
type ErrorDetails struct {
|
||||
Code int `json:"code"`
|
||||
Stacktrace string `json:"stacktrace"`
|
||||
}
|
||||
|
||||
// formatErrorDetails extracts and formats error details from wrapped errors
|
||||
func formatErrorDetails(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
// Use %+v format to get stack trace from github.com/pkg/errors
|
||||
return fmt.Sprintf("%+v", err)
|
||||
}
|
||||
|
||||
// SerializeErrorDetails creates a JSON string with code and stacktrace
|
||||
// This is exported so it can be used when creating error notifications
|
||||
func SerializeErrorDetails(code int, err error) string {
|
||||
details := ErrorDetails{
|
||||
Code: code,
|
||||
Stacktrace: formatErrorDetails(err),
|
||||
}
|
||||
jsonData, jsonErr := json.Marshal(details)
|
||||
if jsonErr != nil {
|
||||
// Fallback if JSON encoding fails
|
||||
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
|
||||
}
|
||||
return string(jsonData)
|
||||
}
|
||||
|
||||
// parseErrorDetails extracts code and stacktrace from JSON Details field
|
||||
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
|
||||
func parseErrorDetails(details string) (int, string) {
|
||||
@@ -155,7 +18,7 @@ func parseErrorDetails(details string) (int, string) {
|
||||
return 500, ""
|
||||
}
|
||||
|
||||
var errDetails ErrorDetails
|
||||
var errDetails notify.ErrorDetails
|
||||
err := json.Unmarshal([]byte(details), &errDetails)
|
||||
if err != nil {
|
||||
// Not JSON or malformed - treat as plain stacktrace with default code
|
||||
@@ -168,6 +31,6 @@ func parseErrorDetails(details string) (int, string) {
|
||||
func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) {
|
||||
err := page.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
|
||||
throw.InternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
@@ -14,7 +15,7 @@ func Index(s *hws.Server) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
throwNotFound(s, w, r, r.URL.Path)
|
||||
throw.NotFound(s, w, r, r.URL.Path)
|
||||
}
|
||||
renderSafely(page.Index(), s, r, w)
|
||||
},
|
||||
|
||||
49
internal/handlers/isunique.go
Normal file
49
internal/handlers/isunique.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// IsUnique creates a handler that checks field uniqueness
|
||||
// Returns 200 OK if unique, 409 Conflict if not unique
|
||||
func IsUnique(
|
||||
s *hws.Server,
|
||||
conn *bun.DB,
|
||||
model any,
|
||||
field string,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
getter, err := validation.ParseForm(r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
value := getter.String(field).TrimSpace().Required().Value
|
||||
if !getter.Validate() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
unique := false
|
||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
unique, err = db.IsUnique(ctx, tx, model, field, value)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.IsUnique")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
if unique {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func IsUsernameUnique(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *store.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
|
||||
@@ -31,7 +33,7 @@ func Login(
|
||||
|
||||
if r.Method == "POST" {
|
||||
if err != nil {
|
||||
notifyServiceUnavailable(s, r, "Login currently unavailable", err)
|
||||
notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -40,46 +42,29 @@ func Login(
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
throwServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||
throw.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||
return
|
||||
}
|
||||
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
||||
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
|
||||
err = track.Error(attempts)
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
throwError(
|
||||
s,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
throw.BadRequest(s, w, r, "Too many redirects. Please clear your browser cookies and try again", err)
|
||||
return
|
||||
}
|
||||
|
||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Failed to generate state token", err)
|
||||
throw.InternalServiceError(s, w, r, "Failed to generate state token", err)
|
||||
return
|
||||
}
|
||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||
|
||||
link, err := discordAPI.GetOAuthLink(state)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
|
||||
throw.InternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
|
||||
return
|
||||
}
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
@@ -3,12 +3,12 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
@@ -21,42 +21,31 @@ func Logout(
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
user := db.CurrentUser(r.Context())
|
||||
if user == nil {
|
||||
// JIC - should be impossible to get here if route is protected by LoginReq
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
return
|
||||
}
|
||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
||||
return
|
||||
}
|
||||
if token != nil {
|
||||
err = discordAPI.RevokeToken(token.Convert())
|
||||
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||
return
|
||||
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
|
||||
}
|
||||
}
|
||||
err = auth.Logout(tx, w, r)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
|
||||
return
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit"))
|
||||
if token != nil {
|
||||
err = discordAPI.RevokeToken(token.Convert())
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
err = auth.Logout(tx, w, r)
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
@@ -31,248 +32,62 @@ func NewSeasonSubmit(
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse form data
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
err = notifyWarn(s, r, "Invalid Form", "Please check your input and try again.", nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
name := getter.String("name").
|
||||
TrimSpace().Required().
|
||||
MaxLength(20).MinLength(5).Value
|
||||
shortName := getter.String("short_name").
|
||||
TrimSpace().ToUpper().Required().
|
||||
MaxLength(6).MinLength(2).Value
|
||||
format := timefmt.NewBuilder().
|
||||
DayNumeric2().Slash().
|
||||
MonthNumeric2().Slash().
|
||||
Year4().Build()
|
||||
startDate := getter.Time("start_date", format).Required().Value
|
||||
if !getter.ValidateAndNotify(s, w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get form values
|
||||
name := strings.TrimSpace(r.FormValue("name"))
|
||||
shortName := strings.TrimSpace(strings.ToUpper(r.FormValue("short_name")))
|
||||
startDateStr := r.FormValue("start_date")
|
||||
|
||||
// Validate required fields
|
||||
if name == "" || shortName == "" || startDateStr == "" {
|
||||
err = notifyWarn(s, r, "Missing Fields", "All fields are required.", nil)
|
||||
nameUnique := false
|
||||
shortNameUnique := false
|
||||
var season *db.Season
|
||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
return false, errors.Wrap(err, "db.IsSeasonNameUnique")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate field lengths
|
||||
if len(name) > 20 {
|
||||
err = notifyWarn(s, r, "Invalid Name", "Season name must be 20 characters or less.", nil)
|
||||
shortNameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "short_name", shortName)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
return false, errors.Wrap(err, "db.IsSeasonShortNameUnique")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(shortName) > 6 {
|
||||
err = notifyWarn(s, r, "Invalid Short Name", "Short name must be 6 characters or less.", nil)
|
||||
if !nameUnique || !shortNameUnique {
|
||||
return true, nil
|
||||
}
|
||||
season, err = db.NewSeason(ctx, tx, name, shortName, startDate)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate short name is alphanumeric only
|
||||
if !isAlphanumeric(shortName) {
|
||||
err = notifyWarn(s, r, "Invalid Short Name", "Short name must contain only letters and numbers.", nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parse start date (DD/MM/YYYY format)
|
||||
startDate, err := time.Parse("02/01/2006", startDateStr)
|
||||
if err != nil {
|
||||
err = notifyWarn(s, r, "Invalid Date", "Please provide a valid start date in DD/MM/YYYY format.", nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Begin database transaction
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Double-check uniqueness (race condition protection)
|
||||
nameUnique, err := db.IsSeasonNameUnique(ctx, tx, name)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
return false, errors.Wrap(err, "db.NewSeason")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !nameUnique {
|
||||
err = notifyWarn(s, r, "Duplicate Name", "This season name is already taken.", nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
shortNameUnique, err := db.IsSeasonShortNameUnique(ctx, tx, shortName)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
notify.Warn(s, w, r, "Duplicate Name", "This season name is already taken.", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if !shortNameUnique {
|
||||
err = notifyWarn(s, r, "Duplicate Short Name", "This short name is already taken.", nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Create the season
|
||||
season, err := db.NewSeason(ctx, tx, name, shortName, startDate)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Failed to create season", errors.Wrap(err, "db.NewSeason"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "tx.Commit"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Send success notification
|
||||
err = notifySuccess(s, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
|
||||
if err != nil {
|
||||
// Log but don't fail the request
|
||||
s.LogError(hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Failed to send success notification",
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
|
||||
// Redirect to the season detail page
|
||||
notify.Success(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
|
||||
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to validate alphanumeric strings
|
||||
func isAlphanumeric(s string) bool {
|
||||
for _, r := range s {
|
||||
if ((r < 'A') || (r > 'Z')) && ((r < '0') || (r > '9')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func IsSeasonNameUnique(
|
||||
s *hws.Server,
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Trim whitespace for consistency
|
||||
name := strings.TrimSpace(r.FormValue("name"))
|
||||
|
||||
unique, err := db.IsSeasonNameUnique(ctx, tx, name)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func IsSeasonShortNameUnique(
|
||||
s *hws.Server,
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Get short name and convert to uppercase for consistency
|
||||
shortname := strings.ToUpper(strings.TrimSpace(r.FormValue("short_name")))
|
||||
|
||||
unique, err := db.IsSeasonShortNameUnique(ctx, tx, shortname)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/notify"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func notifyClient(
|
||||
s *hws.Server,
|
||||
r *http.Request,
|
||||
level notify.Level,
|
||||
title, message, details string,
|
||||
action any,
|
||||
) error {
|
||||
subCookie, err := r.Cookie("ws_sub_id")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "r.Cookie")
|
||||
}
|
||||
subID := notify.Target(subCookie.Value)
|
||||
nt := notify.Notification{
|
||||
Target: subID,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Action: action,
|
||||
Level: level,
|
||||
}
|
||||
s.NotifySub(nt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err error) error {
|
||||
return notifyClient(s, r, notify.LevelError, "Internal Service Error", msg,
|
||||
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
|
||||
}
|
||||
|
||||
func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error {
|
||||
return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg,
|
||||
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
|
||||
}
|
||||
|
||||
func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
||||
return notifyClient(s, r, notify.LevelWarn, title, msg, "", action)
|
||||
}
|
||||
|
||||
func notifyInfo(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
||||
return notifyClient(s, r, notify.LevelInfo, title, msg, "", action)
|
||||
}
|
||||
|
||||
func notifySuccess(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
||||
return notifyClient(s, r, notify.LevelSuccess, title, msg, "", action)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"git.haelnorr.com/h/golib/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/component/popup"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
@@ -22,7 +23,7 @@ func NotificationWS(
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") != "websocket" {
|
||||
throwNotFound(s, w, r, r.URL.Path)
|
||||
throw.NotFound(s, w, r, r.URL.Path)
|
||||
return
|
||||
}
|
||||
nc, err := setupClient(s, w, r)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
)
|
||||
|
||||
@@ -16,12 +17,6 @@ func NotifyTester(s *hws.Server) http.Handler {
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks")
|
||||
if r.Method == "GET" {
|
||||
// page, _ := ErrorPage(hws.HWSError{
|
||||
// StatusCode: http.StatusTeapot,
|
||||
// Message: "This error has been rendered as a test",
|
||||
// Error: testErr,
|
||||
// })
|
||||
// page.Render(r.Context(), w)
|
||||
renderSafely(page.Test(), s, r, w)
|
||||
} else {
|
||||
_ = r.ParseForm()
|
||||
@@ -30,19 +25,15 @@ func NotifyTester(s *hws.Server) http.Handler {
|
||||
level := r.Form.Get("type")
|
||||
message := r.Form.Get("message")
|
||||
|
||||
var err error
|
||||
switch level {
|
||||
case "success":
|
||||
err = notifySuccess(s, r, title, message, nil)
|
||||
notify.Success(s, w, r, title, message, nil)
|
||||
case "info":
|
||||
err = notifyInfo(s, r, title, message, nil)
|
||||
notify.Info(s, w, r, title, message, nil)
|
||||
case "warn":
|
||||
err = notifyWarn(s, r, title, message, nil)
|
||||
notify.Warn(s, w, r, title, message, nil)
|
||||
case "error":
|
||||
err = notifyInternalServiceError(s, r, message, testErr)
|
||||
}
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
||||
notify.InternalServiceError(s, w, r, message, testErr)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) {
|
||||
var pageNum, perPage int
|
||||
var order bun.Order
|
||||
var orderBy string
|
||||
var err error
|
||||
|
||||
if pageStr := r.FormValue("page"); pageStr != "" {
|
||||
pageNum, err = strconv.Atoi(pageStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid page number")
|
||||
}
|
||||
}
|
||||
if perPageStr := r.FormValue("per_page"); perPageStr != "" {
|
||||
perPage, err = strconv.Atoi(perPageStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid per_page number")
|
||||
}
|
||||
}
|
||||
order = bun.Order(r.FormValue("order"))
|
||||
orderBy = r.FormValue("order_by")
|
||||
|
||||
pageOpts := &db.PageOpts{
|
||||
Page: pageNum,
|
||||
PerPage: perPage,
|
||||
Order: order,
|
||||
OrderBy: orderBy,
|
||||
}
|
||||
return pageOpts, nil
|
||||
}
|
||||
|
||||
func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) {
|
||||
var pageNum, perPage int
|
||||
var order bun.Order
|
||||
var orderBy string
|
||||
var err error
|
||||
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
|
||||
pageNum, err = strconv.Atoi(pageStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid page number")
|
||||
}
|
||||
}
|
||||
if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" {
|
||||
perPage, err = strconv.Atoi(perPageStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid per_page number")
|
||||
}
|
||||
}
|
||||
order = bun.Order(r.URL.Query().Get("order"))
|
||||
orderBy = r.URL.Query().Get("order_by")
|
||||
pageOpts := &db.PageOpts{
|
||||
Page: pageNum,
|
||||
PerPage: perPage,
|
||||
Order: order,
|
||||
OrderBy: orderBy,
|
||||
}
|
||||
return pageOpts, nil
|
||||
}
|
||||
45
internal/handlers/pageopt_helpers.go
Normal file
45
internal/handlers/pageopt_helpers.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata.
|
||||
// It renders a Bad Request error page on fail
|
||||
// PageOpts will be nil on fail
|
||||
func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
||||
getter, ok := validation.ParseFormOrError(s, w, r)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return getPageOpts(s, w, r, getter)
|
||||
}
|
||||
|
||||
// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail
|
||||
// PageOpts will be nil on fail
|
||||
func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
||||
return getPageOpts(s, w, r, validation.NewQueryGetter(r))
|
||||
}
|
||||
|
||||
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts {
|
||||
page := g.Int("page").Min(1).Value
|
||||
perPage := g.Int("per_page").Min(1).Max(100).Value
|
||||
order := g.String("order").TrimSpace().ToUpper().AllowedValues([]string{"ASC", "DESC"}).Value
|
||||
orderBy := g.String("order_by").TrimSpace().ToLower().Value
|
||||
valid := g.ValidateAndError(s, w, r)
|
||||
if !valid {
|
||||
return nil
|
||||
}
|
||||
pageOpts := &db.PageOpts{
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
Order: bun.Order(order),
|
||||
OrderBy: orderBy,
|
||||
}
|
||||
return pageOpts
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
@@ -17,7 +18,7 @@ func PermTester(s *hws.Server, conn *bun.DB) http.Handler {
|
||||
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
|
||||
tx.Rollback()
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Error", err)
|
||||
throw.InternalServiceError(s, w, r, "Error", err)
|
||||
}
|
||||
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
|
||||
})
|
||||
|
||||
@@ -3,7 +3,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
)
|
||||
|
||||
@@ -29,27 +29,9 @@ func Register(
|
||||
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
|
||||
|
||||
if exceeded {
|
||||
err := errors.Errorf(
|
||||
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
|
||||
attempts,
|
||||
track.IP,
|
||||
track.UserAgent,
|
||||
track.Path,
|
||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
cfg.HWSAuth.SSL,
|
||||
)
|
||||
|
||||
err := track.Error(attempts)
|
||||
store.ClearRedirectTrack(r, "/register")
|
||||
|
||||
throwError(
|
||||
s,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
|
||||
err,
|
||||
"warn",
|
||||
)
|
||||
throw.BadRequest(s, w, r, "Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,68 +47,45 @@ func Register(
|
||||
}
|
||||
|
||||
store.ClearRedirectTrack(r, "/register")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
method := r.Method
|
||||
if method == "GET" {
|
||||
tx.Commit()
|
||||
|
||||
if r.Method == "GET" {
|
||||
renderSafely(page.Register(details.DiscordUser.Username), s, r, w)
|
||||
return
|
||||
}
|
||||
if method == "POST" {
|
||||
username := r.FormValue("username")
|
||||
user, err := registerUser(ctx, tx, username, details)
|
||||
username := r.FormValue("username")
|
||||
unique := false
|
||||
var user *db.User
|
||||
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
|
||||
if err != nil {
|
||||
err = notifyInternalServiceError(s, r, "Registration failed", err)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Registration failed", err)
|
||||
}
|
||||
return false, errors.Wrap(err, "db.IsUsernameUnique")
|
||||
}
|
||||
if !unique {
|
||||
return true, nil
|
||||
}
|
||||
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.CreateUser")
|
||||
}
|
||||
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.UpdateDiscordToken")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
err = auth.Login(w, r, user, true)
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Login failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if user == nil {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
err = auth.Login(w, r, user, true)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Login failed", err)
|
||||
return
|
||||
}
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
}
|
||||
return
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func registerUser(
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
username string,
|
||||
details *store.RegistrationSession,
|
||||
) (*db.User, error) {
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.IsUsernameUnique")
|
||||
}
|
||||
if !unique {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CreateUser")
|
||||
}
|
||||
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -17,23 +17,20 @@ func SeasonPage(
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
seasonStr := r.PathValue("season_short_name")
|
||||
season, err := db.GetSeason(ctx, tx, seasonStr)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetSeason"))
|
||||
var season *db.Season
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.GetSeason")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if season == nil {
|
||||
throwNotFound(s, w, r, r.URL.Path)
|
||||
throw.NotFound(s, w, r, r.URL.Path)
|
||||
return
|
||||
}
|
||||
renderSafely(page.SeasonPage(season), s, r, w)
|
||||
|
||||
@@ -3,7 +3,6 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
@@ -17,25 +16,21 @@ func SeasonsPage(
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
pageOpts := pageOptsFromQuery(s, w, r)
|
||||
if pageOpts == nil {
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
pageOpts, err := pageOptsFromQuery(r)
|
||||
if err != nil {
|
||||
throwBadRequest(s, w, r, "invalid query", err)
|
||||
var seasons *db.SeasonList
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.ListSeasons")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
renderSafely(page.SeasonsPage(seasons), s, r, w)
|
||||
})
|
||||
}
|
||||
@@ -45,37 +40,21 @@ func SeasonsList(
|
||||
conn *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Parse form values
|
||||
if err := r.ParseForm(); err != nil {
|
||||
throwBadRequest(s, w, r, "Invalid form data", err)
|
||||
pageOpts := pageOptsFromForm(s, w, r)
|
||||
if pageOpts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pageOpts, err := pageOptsFromForm(r)
|
||||
if err != nil {
|
||||
throwBadRequest(s, w, r, "invalid form data", err)
|
||||
var seasons *db.SeasonList
|
||||
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||
var err error
|
||||
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "db.ListSeasons")
|
||||
}
|
||||
return true, nil
|
||||
}); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Database query
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
|
||||
if err != nil {
|
||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
// Return only the list component (hx-push-url handles URL update client-side)
|
||||
renderSafely(page.SeasonsList(seasons), s, r, w)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
)
|
||||
|
||||
// StaticFS handles requests for static files, without allowing access to the
|
||||
@@ -16,7 +17,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
|
||||
if err != nil {
|
||||
// If we can't create the file server, return a handler that always errors
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
||||
throw.InternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
63
internal/notify/notify.go
Normal file
63
internal/notify/notify.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Package notify provides utility functions for sending notifications to clients
|
||||
package notify
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func notifyClient(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
level notify.Level,
|
||||
title, message, details string,
|
||||
action any,
|
||||
) {
|
||||
subCookie, err := r.Cookie("ws_sub_id")
|
||||
if err != nil {
|
||||
throw.InternalServiceError(s, w, r, "Notification failed. Do you have cookies enabled?", errors.Wrap(err, "r.Cookie"))
|
||||
return
|
||||
}
|
||||
subID := notify.Target(subCookie.Value)
|
||||
nt := notify.Notification{
|
||||
Target: subID,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Details: details,
|
||||
Action: action,
|
||||
Level: level,
|
||||
}
|
||||
s.NotifySub(nt)
|
||||
}
|
||||
|
||||
// InternalServiceError notifies with error level
|
||||
func InternalServiceError(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
|
||||
notifyClient(s, w, r, notify.LevelError, "Internal Service Error", msg,
|
||||
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
|
||||
}
|
||||
|
||||
// ServiceUnavailable notifies with error level
|
||||
func ServiceUnavailable(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
|
||||
notifyClient(s, w, r, notify.LevelError, "Service Unavailable", msg,
|
||||
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
|
||||
}
|
||||
|
||||
// Warn notifies with warn level
|
||||
func Warn(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||
notifyClient(s, w, r, notify.LevelWarn, title, msg, "", action)
|
||||
}
|
||||
|
||||
// Info notifies with info level
|
||||
func Info(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||
notifyClient(s, w, r, notify.LevelInfo, title, msg, "", action)
|
||||
}
|
||||
|
||||
// Success notifies with success level
|
||||
func Success(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||
notifyClient(s, w, r, notify.LevelSuccess, title, msg, "", action)
|
||||
}
|
||||
53
internal/notify/util.go
Normal file
53
internal/notify/util.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ErrorDetails contains structured error information for WebSocket error modals
|
||||
type ErrorDetails struct {
|
||||
Code int `json:"code"`
|
||||
Stacktrace string `json:"stacktrace"`
|
||||
}
|
||||
|
||||
// SerializeErrorDetails creates a JSON string with code and stacktrace
|
||||
// This is exported so it can be used when creating error notifications
|
||||
func SerializeErrorDetails(code int, err error) string {
|
||||
details := ErrorDetails{
|
||||
Code: code,
|
||||
Stacktrace: FormatErrorDetails(err),
|
||||
}
|
||||
jsonData, jsonErr := json.Marshal(details)
|
||||
if jsonErr != nil {
|
||||
// Fallback if JSON encoding fails
|
||||
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
|
||||
}
|
||||
return string(jsonData)
|
||||
}
|
||||
|
||||
// ParseErrorDetails extracts code and stacktrace from JSON Details field
|
||||
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
|
||||
func ParseErrorDetails(details string) (int, string) {
|
||||
if details == "" {
|
||||
return 500, ""
|
||||
}
|
||||
|
||||
var errDetails ErrorDetails
|
||||
err := json.Unmarshal([]byte(details), &errDetails)
|
||||
if err != nil {
|
||||
// Not JSON or malformed - treat as plain stacktrace with default code
|
||||
return 500, details
|
||||
}
|
||||
|
||||
return errDetails.Code, errDetails.Stacktrace
|
||||
}
|
||||
|
||||
// FormatErrorDetails extracts and formats error details from wrapped errors
|
||||
func FormatErrorDetails(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
// Use %+v format to get stack trace from github.com/pkg/errors
|
||||
return fmt.Sprintf("%+v", err)
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package rbac
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/contexts"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"git.haelnorr.com/h/oslstats/internal/permissions"
|
||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// LoadPermissionsMiddleware loads user permissions into context after authentication
|
||||
@@ -26,46 +26,28 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
|
||||
return
|
||||
}
|
||||
|
||||
// Start transaction for loading permissions
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
tx, err := c.conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
// Log but don't block - permission checks will fail gracefully
|
||||
var roles_ []*db.Role
|
||||
var perms []*db.Permission
|
||||
if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||
var err error
|
||||
roles_, err = user.GetRoles(ctx, tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "user.GetRoles")
|
||||
}
|
||||
perms, err = user.GetPermissions(ctx, tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "user.GetPermissions")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
c.s.LogError(hws.HWSError{
|
||||
Message: "Failed to start database transaction",
|
||||
Error: errors.Wrap(err, "c.conn.BeginTx"),
|
||||
Message: "Database error",
|
||||
Error: err,
|
||||
Level: hws.ErrorERROR,
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Load user's roles_ and permissions
|
||||
roles_, err := user.GetRoles(ctx, tx)
|
||||
if err != nil {
|
||||
c.s.LogError(hws.HWSError{
|
||||
Message: "Failed to get user roles",
|
||||
Error: errors.Wrap(err, "user.GetRoles"),
|
||||
Level: hws.ErrorERROR,
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
perms, err := user.GetPermissions(ctx, tx)
|
||||
if err != nil {
|
||||
c.s.LogError(hws.HWSError{
|
||||
Message: "Failed to get user permissions",
|
||||
Error: errors.Wrap(err, "user.GetPermissions"),
|
||||
Level: hws.ErrorERROR,
|
||||
})
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
_ = tx.Commit() // read only transaction
|
||||
|
||||
// Build permission cache
|
||||
cache := &contexts.PermissionCache{
|
||||
@@ -88,7 +70,7 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
|
||||
}
|
||||
|
||||
// Add cache to context (type-safe)
|
||||
ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache)
|
||||
ctx := context.WithValue(r.Context(), contexts.PermissionCacheKey, cache)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
||||
@@ -93,3 +95,14 @@ func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
|
||||
key := redirectKey(ip, userAgent, path)
|
||||
s.redirectTracks.Delete(key)
|
||||
}
|
||||
|
||||
func (t *RedirectTrack) Error(attempts int) error {
|
||||
return errors.Errorf(
|
||||
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||
attempts,
|
||||
t.IP,
|
||||
t.UserAgent,
|
||||
t.Path,
|
||||
t.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||
)
|
||||
}
|
||||
|
||||
118
internal/throw/throw.go
Normal file
118
internal/throw/throw.go
Normal file
@@ -0,0 +1,118 @@
|
||||
// Package throw provides utility functions for throwing HTTP errors that render an error page
|
||||
package throw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// throwError is a generic helper that all throw* functions use internally
|
||||
func throwError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
statusCode int,
|
||||
msg string,
|
||||
err error,
|
||||
level hws.ErrorLevel,
|
||||
) {
|
||||
s.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
Error: err,
|
||||
Level: level,
|
||||
RenderErrorPage: true, // throw* family always renders error pages
|
||||
})
|
||||
}
|
||||
|
||||
// InternalServiceError handles 500 errors (server failures)
|
||||
func InternalServiceError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
|
||||
}
|
||||
|
||||
// ServiceUnavailable handles 503 errors
|
||||
func ServiceUnavailable(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
|
||||
}
|
||||
|
||||
// BadRequest handles 400 errors (malformed requests)
|
||||
func BadRequest(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// Forbidden handles 403 errors (normal permission denials)
|
||||
func Forbidden(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// ForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
||||
func ForbiddenSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
|
||||
}
|
||||
|
||||
// Unauthorized handles 401 errors (not authenticated)
|
||||
func Unauthorized(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
|
||||
// UnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
||||
func UnauthorizedSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
|
||||
}
|
||||
|
||||
// NotFound handles 404 errors
|
||||
func NotFound(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
) {
|
||||
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
||||
err := errors.New("Resource not found")
|
||||
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
|
||||
}
|
||||
101
internal/validation/forms.go
Normal file
101
internal/validation/forms.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// FormGetter wraps http.Request to get form values
|
||||
type FormGetter struct {
|
||||
r *http.Request
|
||||
checks []*ValidationRule
|
||||
}
|
||||
|
||||
func NewFormGetter(r *http.Request) *FormGetter {
|
||||
return &FormGetter{r: r, checks: []*ValidationRule{}}
|
||||
}
|
||||
|
||||
func (f *FormGetter) Get(key string) string {
|
||||
return f.r.FormValue(key)
|
||||
}
|
||||
|
||||
func (f *FormGetter) getChecks() []*ValidationRule {
|
||||
return f.checks
|
||||
}
|
||||
|
||||
func (f *FormGetter) AddCheck(check *ValidationRule) {
|
||||
f.checks = append(f.checks, check)
|
||||
}
|
||||
|
||||
func (f *FormGetter) ValidateChecks() []*ValidationRule {
|
||||
return validate(f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) String(key string) *StringField {
|
||||
return newStringField(key, f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) Int(key string) *IntField {
|
||||
return newIntField(key, f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||
return newTimeField(key, format, f)
|
||||
}
|
||||
|
||||
func ParseForm(r *http.Request) (*FormGetter, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "r.ParseForm")
|
||||
}
|
||||
return NewFormGetter(r), nil
|
||||
}
|
||||
|
||||
// ParseFormOrNotify attempts to parse the form data and notifies the user on fail
|
||||
func ParseFormOrNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||
getter, err := ParseForm(r)
|
||||
if err != nil {
|
||||
notify.Warn(s, w, r, "Invalid Form", "Please check your input and try again.", nil)
|
||||
return nil, false
|
||||
}
|
||||
return getter, true
|
||||
}
|
||||
|
||||
// ParseFormOrError attempts to parse the form data and renders an error page on fail
|
||||
func ParseFormOrError(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||
getter, err := ParseForm(r)
|
||||
if err != nil {
|
||||
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||
return nil, false
|
||||
}
|
||||
return getter, true
|
||||
}
|
||||
|
||||
func (f *FormGetter) Validate() bool {
|
||||
return len(validate(f)) == 0
|
||||
}
|
||||
|
||||
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||
// Returns true if all checks passed
|
||||
func (f *FormGetter) ValidateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndNotify(s, w, r, f)
|
||||
}
|
||||
|
||||
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||
// Returns true if all checks passed
|
||||
func (f *FormGetter) ValidateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndError(s, w, r, f)
|
||||
}
|
||||
71
internal/validation/integerfield.go
Normal file
71
internal/validation/integerfield.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type IntField struct {
|
||||
Field
|
||||
Value int
|
||||
}
|
||||
|
||||
func newIntField(key string, g Getter) *IntField {
|
||||
raw := g.Get(key)
|
||||
var val int
|
||||
if raw != "" {
|
||||
var err error
|
||||
val, err = strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
g.AddCheck(newFailedCheck(
|
||||
"Value is not a number",
|
||||
fmt.Sprintf("%s must be an integer: %s provided", key, raw),
|
||||
))
|
||||
}
|
||||
}
|
||||
return &IntField{
|
||||
Value: val,
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
// Required enforces a non-zero value
|
||||
func (i *IntField) Required() *IntField {
|
||||
if i.Value == 0 {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value cannot be 0",
|
||||
fmt.Sprintf("%s is required", i.Key),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (i *IntField) Optional() *IntField {
|
||||
if i.Value == 0 {
|
||||
i.optional = true
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Max enforces a maxmium value
|
||||
func (i *IntField) Max(max int) *IntField {
|
||||
if i.Value > max && !i.optional {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value too large",
|
||||
fmt.Sprintf("%s is too large, max %v", i.Key, max),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Min enforces a minimum value
|
||||
func (i *IntField) Min(min int) *IntField {
|
||||
if i.Value < min && !i.optional {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value too small",
|
||||
fmt.Sprintf("%s is too small, min %v", i.Key, min),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
70
internal/validation/querys.go
Normal file
70
internal/validation/querys.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
// QueryGetter wraps http.Request to get query values
|
||||
type QueryGetter struct {
|
||||
r *http.Request
|
||||
checks []*ValidationRule
|
||||
}
|
||||
|
||||
func NewQueryGetter(r *http.Request) *QueryGetter {
|
||||
return &QueryGetter{r: r, checks: []*ValidationRule{}}
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Get(key string) string {
|
||||
return q.r.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) getChecks() []*ValidationRule {
|
||||
return q.checks
|
||||
}
|
||||
|
||||
func (q *QueryGetter) AddCheck(check *ValidationRule) {
|
||||
q.checks = append(q.checks, check)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) ValidateChecks() []*ValidationRule {
|
||||
return validate(q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) String(key string) *StringField {
|
||||
return newStringField(key, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Int(key string) *IntField {
|
||||
return newIntField(key, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||
return newTimeField(key, format, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Validate() bool {
|
||||
return len(validate(q)) == 0
|
||||
}
|
||||
|
||||
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||
// Returns true if all checks passed
|
||||
func (q *QueryGetter) ValidateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndNotify(s, w, r, q)
|
||||
}
|
||||
|
||||
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||
// Returns true if all checks passed
|
||||
func (q *QueryGetter) ValidateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndError(s, w, r, q)
|
||||
}
|
||||
106
internal/validation/stringfield.go
Normal file
106
internal/validation/stringfield.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type StringField struct {
|
||||
Field
|
||||
Value string
|
||||
}
|
||||
|
||||
func newStringField(key string, g Getter) *StringField {
|
||||
return &StringField{
|
||||
Value: g.Get(key),
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
// Required enforces a non empty string
|
||||
func (s *StringField) Required() *StringField {
|
||||
if s.Value == "" {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Field not provided",
|
||||
fmt.Sprintf("%s is required", s.Key),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (s *StringField) Optional() *StringField {
|
||||
if s.Value == "" {
|
||||
s.optional = true
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// MaxLength enforces a maximum string length
|
||||
func (s *StringField) MaxLength(length int) *StringField {
|
||||
if len(s.Value) > length && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Input too long",
|
||||
fmt.Sprintf("%s is too long, max %v chars", s.Key, length),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// MinLength enforces a minimum string length
|
||||
func (s *StringField) MinLength(length int) *StringField {
|
||||
if len(s.Value) < length && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Input too short",
|
||||
fmt.Sprintf("%s is too short, min %v chars", s.Key, length),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// AlphaNumeric enforces the string contains only letters and numbers
|
||||
func (s *StringField) AlphaNumeric() *StringField {
|
||||
if s.optional {
|
||||
return s
|
||||
}
|
||||
for _, r := range s.Value {
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Invalid characters",
|
||||
fmt.Sprintf("%s must contain only letters and numbers.", s.Key),
|
||||
))
|
||||
return s
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *StringField) AllowedValues(allowed []string) *StringField {
|
||||
if !slices.Contains(allowed, s.Value) && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Value not allowed",
|
||||
fmt.Sprintf("%s must be one of: %s", s.Key, strings.Join(allowed, ",")),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ToUpper transforms the string to uppercase
|
||||
func (s *StringField) ToUpper() *StringField {
|
||||
s.Value = strings.ToUpper(s.Value)
|
||||
return s
|
||||
}
|
||||
|
||||
// ToLower transforms the string to lowercase
|
||||
func (s *StringField) ToLower() *StringField {
|
||||
s.Value = strings.ToLower(s.Value)
|
||||
return s
|
||||
}
|
||||
|
||||
// TrimSpace removes leading and trailing whitespace
|
||||
func (s *StringField) TrimSpace() *StringField {
|
||||
s.Value = strings.TrimSpace(s.Value)
|
||||
return s
|
||||
}
|
||||
50
internal/validation/timefield.go
Normal file
50
internal/validation/timefield.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
type TimeField struct {
|
||||
Field
|
||||
Value time.Time
|
||||
}
|
||||
|
||||
func newTimeField(key string, format *timefmt.Format, g Getter) *TimeField {
|
||||
raw := g.Get(key)
|
||||
var startDate time.Time
|
||||
if raw != "" {
|
||||
var err error
|
||||
startDate, err = format.Parse(raw)
|
||||
if err != nil {
|
||||
g.AddCheck(newFailedCheck(
|
||||
"Invalid date/time format",
|
||||
fmt.Sprintf("%s should be in format %s", key, format.LDML()),
|
||||
))
|
||||
}
|
||||
}
|
||||
return &TimeField{
|
||||
Value: startDate,
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TimeField) Required() *TimeField {
|
||||
if t.Value.IsZero() {
|
||||
t.getter.AddCheck(newFailedCheck(
|
||||
"Date/Time not provided",
|
||||
fmt.Sprintf("%s must be provided", t.Key),
|
||||
))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (t *TimeField) Optional() *TimeField {
|
||||
if t.Value.IsZero() {
|
||||
t.optional = true
|
||||
}
|
||||
return t
|
||||
}
|
||||
93
internal/validation/validation.go
Normal file
93
internal/validation/validation.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package validation provides utilities for parsing and validating request data
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
type ValidationRule struct {
|
||||
Condition bool // Condition on which to fail validation
|
||||
Title string // Title for warning message
|
||||
Message string // Warning message
|
||||
}
|
||||
|
||||
// Getter abstracts getting values from either form or query
|
||||
type Getter interface {
|
||||
Get(key string) string
|
||||
AddCheck(check *ValidationRule)
|
||||
String(key string) *StringField
|
||||
Int(key string) *IntField
|
||||
Time(key string, format *timefmt.Format) *TimeField
|
||||
ValidateChecks() []*ValidationRule
|
||||
Validate() bool
|
||||
ValidateAndNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||
ValidateAndError(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||
|
||||
getChecks() []*ValidationRule
|
||||
}
|
||||
type Field struct {
|
||||
Key string
|
||||
optional bool
|
||||
getter Getter
|
||||
}
|
||||
|
||||
func newField(key string, g Getter) Field {
|
||||
return Field{
|
||||
Key: key,
|
||||
getter: g,
|
||||
}
|
||||
}
|
||||
|
||||
func newFailedCheck(title, message string) *ValidationRule {
|
||||
return &ValidationRule{
|
||||
Condition: true,
|
||||
Title: title,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func validate(g Getter) []*ValidationRule {
|
||||
failed := []*ValidationRule{}
|
||||
for _, check := range g.getChecks() {
|
||||
if check != nil && check.Condition {
|
||||
failed = append(failed, check)
|
||||
}
|
||||
}
|
||||
return failed
|
||||
}
|
||||
|
||||
func validateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
g Getter,
|
||||
) bool {
|
||||
failedChecks := g.ValidateChecks()
|
||||
for _, check := range failedChecks {
|
||||
notify.Warn(s, w, r, check.Title, check.Message, nil)
|
||||
}
|
||||
return len(failedChecks) == 0
|
||||
}
|
||||
|
||||
func validateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
g Getter,
|
||||
) bool {
|
||||
failedChecks := g.ValidateChecks()
|
||||
var err error
|
||||
for _, check := range failedChecks {
|
||||
err_ := fmt.Errorf("%s: %s", check.Title, check.Message)
|
||||
err = errors.Join(err, err_)
|
||||
}
|
||||
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||
return len(failedChecks) == 0
|
||||
}
|
||||
Reference in New Issue
Block a user