diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index 334b520..b2d53ac 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -101,17 +101,17 @@ func addRoutes( { Path: "/htmx/isusernameunique", Method: hws.MethodPOST, - Handler: handlers.IsUsernameUnique(s, conn, cfg, store), + Handler: handlers.IsUnique(s, conn, (*db.User)(nil), "username"), }, { Path: "/htmx/isseasonnameunique", Method: hws.MethodPOST, - Handler: handlers.IsSeasonNameUnique(s, conn), + Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "name"), }, { Path: "/htmx/isseasonshortnameunique", Method: hws.MethodPOST, - Handler: handlers.IsSeasonShortNameUnique(s, conn), + Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "short_name"), }, } diff --git a/go.mod b/go.mod index 57684ab..9aa5155 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.haelnorr.com/h/oslstats -go 1.25.5 +go 1.25.6 require ( git.haelnorr.com/h/golib/env v0.9.1 @@ -27,6 +27,7 @@ require ( require ( git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/jwt v0.10.1 // indirect + git.haelnorr.com/h/timefmt v0.1.0 github.com/bwmarrin/discordgo v0.29.0 github.com/go-logr/logr v1.4.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect diff --git a/go.sum b/go.sum index 35b9693..95c72f1 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7Kv git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc= +git.haelnorr.com/h/timefmt v0.1.0 h1:ULDkWEtFIV+FkkoV0q9n62Spj+HDdtFL9QeAdGIEp+o= +git.haelnorr.com/h/timefmt v0.1.0/go.mod h1:12gXXYLP4w9Fa9ZkbZWdvKV6RyZEzwAm9mN+WB3oXpw= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg= diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go index 9a06e15..044a194 100644 --- a/internal/auditlog/logger.go +++ b/internal/auditlog/logger.go @@ -105,35 +105,33 @@ func (l *Logger) log( // GetRecentLogs retrieves recent audit logs with pagination func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) { - tx, err := l.conn.BeginTx(ctx, nil) - if err != nil { - return nil, errors.Wrap(err, "conn.BeginTx") + var logs *db.AuditLogs + if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil) + if err != nil { + return errors.Wrap(err, "db.GetAuditLogs") + } + return nil + }); err != nil { + return nil, errors.Wrap(err, "db.WithTxFailSilently") } - defer func() { _ = tx.Rollback() }() - - logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil) - if err != nil { - return nil, err - } - - _ = tx.Commit() // read only transaction return logs, nil } // GetLogsByUser retrieves audit logs for a specific user func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) { - tx, err := l.conn.BeginTx(ctx, nil) - if err != nil { - return nil, errors.Wrap(err, "conn.BeginTx") + var logs *db.AuditLogs + if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) + if err != nil { + return errors.Wrap(err, "db.GetAuditLogsByUser") + } + return nil + }); err != nil { + return nil, errors.Wrap(err, "db.WithTxFailSilently") } - defer func() { _ = tx.Rollback() }() - - logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts) - if err != nil { - return nil, err - } - - _ = tx.Commit() // read only transaction return logs, nil } @@ -145,20 +143,16 @@ func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix() - tx, err := l.conn.BeginTx(ctx, nil) - if err != nil { - return 0, errors.Wrap(err, "conn.BeginTx") - } - defer func() { _ = tx.Rollback() }() - - count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime) - if err != nil { - return 0, err - } - - err = tx.Commit() - if err != nil { - return 0, errors.Wrap(err, "tx.Commit") + var count int + if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime) + if err != nil { + return errors.Wrap(err, "db.CleanupOldAuditLogs") + } + return nil + }); err != nil { + return 0, errors.Wrap(err, "db.WithTxFailSilently") } return count, nil } diff --git a/internal/db/isunique.go b/internal/db/isunique.go new file mode 100644 index 0000000..1f9ba01 --- /dev/null +++ b/internal/db/isunique.go @@ -0,0 +1,19 @@ +package db + +import ( + "context" + + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +func IsUnique(ctx context.Context, tx bun.Tx, model any, field, value string) (bool, error) { + count, err := tx.NewSelect(). + Model(model). + Where("? = ?", bun.Ident(field), value). + Count(ctx) + if err != nil { + return false, errors.Wrap(err, "tx.NewSelect") + } + return count == 0, nil +} diff --git a/internal/db/season.go b/internal/db/season.go index f08c569..4a78571 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -87,25 +87,3 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error } return season, nil } - -func IsSeasonNameUnique(ctx context.Context, tx bun.Tx, name string) (bool, error) { - count, err := tx.NewSelect(). - Model((*Season)(nil)). - Where("name = ?", name). - Count(ctx) - if err != nil { - return false, errors.Wrap(err, "tx.NewSelect") - } - return count == 0, nil -} - -func IsSeasonShortNameUnique(ctx context.Context, tx bun.Tx, shortname string) (bool, error) { - count, err := tx.NewSelect(). - Model((*Season)(nil)). - Where("short_name = ?", shortname). - Count(ctx) - if err != nil { - return false, errors.Wrap(err, "tx.NewSelect") - } - return count == 0, nil -} diff --git a/internal/db/txhelpers.go b/internal/db/txhelpers.go new file mode 100644 index 0000000..3d9a99f --- /dev/null +++ b/internal/db/txhelpers.go @@ -0,0 +1,118 @@ +package db + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" + "git.haelnorr.com/h/oslstats/internal/throw" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +// TxFunc is a function that runs within a database transaction +type ( + TxFunc func(ctx context.Context, tx bun.Tx) (bool, error) + TxFuncSilent func(ctx context.Context, tx bun.Tx) error +) + +var timeout = 15 * time.Second + +// WithReadTx executes a read-only transaction with automatic rollback +// Returns true if successful, false if error was thrown to client +func WithReadTx( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + conn *bun.DB, + fn TxFunc, +) bool { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + ok, err := withTx(ctx, conn, fn, false) + if err != nil { + throw.InternalServiceError(s, w, r, "Database error", err) + } + return ok +} + +// WithTxFailSilently executes a transaction with automatic rollback +// Returns true if successful, false if error occured. +// Does not throw any errors to the client. +func WithTxFailSilently( + ctx context.Context, + conn *bun.DB, + fn TxFuncSilent, +) error { + fnc := func(ctx context.Context, tx bun.Tx) (bool, error) { + err := fn(ctx, tx) + return err != nil, err + } + _, err := withTx(ctx, conn, fnc, true) + return err +} + +// WithWriteTx executes a write transaction with automatic rollback on error +// Commits only if fn returns nil. Returns true if successful. +func WithWriteTx( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + conn *bun.DB, + fn TxFunc, +) bool { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + ok, err := withTx(ctx, conn, fn, true) + if err != nil { + throw.InternalServiceError(s, w, r, "Database error", err) + } + return ok +} + +// WithNotifyTx executes a transaction with notification-based error handling +// Uses notifyInternalServiceError instead of throwInternalServiceError +func WithNotifyTx( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + conn *bun.DB, + fn TxFunc, +) bool { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + ok, err := withTx(ctx, conn, fn, true) + if err != nil { + notify.InternalServiceError(s, w, r, "Database error", err) + } + return ok +} + +// withTx executes a transaction with automatic rollback on error +func withTx( + ctx context.Context, + conn *bun.DB, + fn TxFunc, + write bool, +) (bool, error) { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return false, errors.Wrap(err, "conn.BeginTx") + } + defer func() { _ = tx.Rollback() }() + ok, err := fn(ctx, tx) + if err != nil || !ok { + return false, err + } + if write { + err = tx.Commit() + if err != nil { + return false, errors.Wrap(err, "tx.Commit") + } + } else { + _ = tx.Commit() + } + return true, nil +} diff --git a/internal/db/user.go b/internal/db/user.go index c8a1457..16a53c8 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -112,19 +112,6 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User return user, nil } -// IsUsernameUnique checks if the given username is unique (not already taken) -// Returns true if the username is available, false if it's taken -func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) { - count, err := tx.NewSelect(). - Model((*User)(nil)). - Where("username = ?", username). - Count(ctx) - if err != nil { - return false, errors.Wrap(err, "tx.NewSelect") - } - return count == 0, nil -} - // GetRoles loads all the roles for this user func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) { if u == nil { diff --git a/internal/handlers/admin_dashboard.go b/internal/handlers/admin_dashboard.go index 2f16041..00f3883 100644 --- a/internal/handlers/admin_dashboard.go +++ b/internal/handlers/admin_dashboard.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" @@ -14,23 +13,17 @@ import ( func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + var users *db.Users + if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + users, err = db.GetUsers(ctx, tx, nil) + if err != nil { + return false, errors.Wrap(err, "db.GetUsers") + } + return true, nil + }); !ok { return } - defer func() { _ = tx.Rollback() }() - - users, err := db.GetUsers(ctx, tx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers")) - return - } - _ = tx.Commit() - renderSafely(page.AdminDashboard(users), s, r, w) }) } diff --git a/internal/handlers/admin_users.go b/internal/handlers/admin_users.go index c15eb2f..9075302 100644 --- a/internal/handlers/admin_users.go +++ b/internal/handlers/admin_users.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" @@ -15,30 +14,21 @@ import ( // AdminUsersList shows all users func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx")) + var users *db.Users + pageOpts := pageOptsFromForm(s, w, r) + if pageOpts == nil { return } - defer func() { _ = tx.Rollback() }() - - // Get all users - pageOpts, err := pageOptsFromForm(r) - if err != nil { - throwBadRequest(s, w, r, "invalid form data", err) + if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + users, err = db.GetUsers(ctx, tx, pageOpts) + if err != nil { + return false, errors.Wrap(err, "db.GetUsers") + } + return true, nil + }); !ok { return } - users, err := db.GetUsers(ctx, tx, pageOpts) - if err != nil { - throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers")) - return - } - - _ = tx.Commit() - renderSafely(admin.UserList(users), s, r, w) }) } diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index 0e9b3d9..4607557 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" @@ -15,11 +14,12 @@ import ( "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/pkg/oauth" ) func Callback( - server *hws.Server, + s *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], conn *bun.DB, cfg *config.Config, @@ -31,26 +31,9 @@ func Callback( attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5) if exceeded { - err := errors.Errorf( - "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", - attempts, - track.IP, - track.UserAgent, - track.Path, - track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), - ) - + err := track.Error(attempts) store.ClearRedirectTrack(r, "/callback") - - throwError( - server, - w, - r, - http.StatusBadRequest, - "OAuth callback failed: Too many redirect attempts. Please try logging in again.", - err, - "warn", - ) + throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err) return } @@ -64,12 +47,12 @@ func Callback( if err != nil { if vsErr, ok := err.(*verifyStateError); ok { if vsErr.IsCookieError() { - throwUnauthorized(server, w, r, "OAuth session not found or expired", err) + throw.Unauthorized(s, w, r, "OAuth session not found or expired", err) } else { - throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err) } } else { - throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err) } return } @@ -77,20 +60,17 @@ func Callback( switch data { case "login": - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(server, w, r, "DB Transaction failed to start", err) + var redirect func() + if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) + if err != nil { + throw.InternalServiceError(s, w, r, "OAuth login failed", err) + return false, nil + } + return true, nil + }); !ok { return } - defer tx.Rollback() - redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) - if err != nil { - throwInternalServiceError(server, w, r, "OAuth login failed", err) - return - } - tx.Commit() redirect() return } diff --git a/internal/handlers/errorpage.go b/internal/handlers/errorpage.go index 53cb172..917e930 100644 --- a/internal/handlers/errorpage.go +++ b/internal/handlers/errorpage.go @@ -4,6 +4,7 @@ import ( "net/http" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/view/page" ) @@ -21,7 +22,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) { // Get technical details if applicable var details string if showDetails && hwsError.Error != nil { - details = formatErrorDetails(hwsError.Error) + details = notify.FormatErrorDetails(hwsError.Error) } // Render appropriate template diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go index 26765e6..7c602a2 100644 --- a/internal/handlers/errors.go +++ b/internal/handlers/errors.go @@ -2,152 +2,15 @@ package handlers import ( "encoding/json" - "fmt" "net/http" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" + "git.haelnorr.com/h/oslstats/internal/throw" "github.com/a-h/templ" "github.com/pkg/errors" ) -// throwError is a generic helper that all throw* functions use internally -func throwError( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - statusCode int, - msg string, - err error, - level hws.ErrorLevel, -) { - s.ThrowError(w, r, hws.HWSError{ - StatusCode: statusCode, - Message: msg, - Error: err, - Level: level, - RenderErrorPage: true, // throw* family always renders error pages - }) -} - -// throwInternalServiceError handles 500 errors (server failures) -func throwInternalServiceError( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR) -} - -// throwServiceUnavailable handles 503 errors -func throwServiceUnavailable( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR) -} - -// throwBadRequest handles 400 errors (malformed requests) -func throwBadRequest( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG) -} - -// throwForbidden handles 403 errors (normal permission denials) -func throwForbidden( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG) -} - -// throwForbiddenSecurity handles 403 errors for security events (uses WARN level) -func throwForbiddenSecurity( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN) -} - -// throwUnauthorized handles 401 errors (not authenticated) -func throwUnauthorized( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG) -} - -// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level) -func throwUnauthorizedSecurity( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - msg string, - err error, -) { - throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN) -} - -// throwNotFound handles 404 errors -func throwNotFound( - s *hws.Server, - w http.ResponseWriter, - r *http.Request, - path string, -) { - msg := fmt.Sprintf("The requested resource was not found: %s", path) - err := errors.New("Resource not found") - throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG) -} - -// ErrorDetails contains structured error information for WebSocket error modals -type ErrorDetails struct { - Code int `json:"code"` - Stacktrace string `json:"stacktrace"` -} - -// formatErrorDetails extracts and formats error details from wrapped errors -func formatErrorDetails(err error) string { - if err == nil { - return "" - } - // Use %+v format to get stack trace from github.com/pkg/errors - return fmt.Sprintf("%+v", err) -} - -// SerializeErrorDetails creates a JSON string with code and stacktrace -// This is exported so it can be used when creating error notifications -func SerializeErrorDetails(code int, err error) string { - details := ErrorDetails{ - Code: code, - Stacktrace: formatErrorDetails(err), - } - jsonData, jsonErr := json.Marshal(details) - if jsonErr != nil { - // Fallback if JSON encoding fails - return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code) - } - return string(jsonData) -} - // parseErrorDetails extracts code and stacktrace from JSON Details field // Returns (code, stacktrace). If parsing fails, returns (500, original details string) func parseErrorDetails(details string) (int, string) { @@ -155,7 +18,7 @@ func parseErrorDetails(details string) (int, string) { return 500, "" } - var errDetails ErrorDetails + var errDetails notify.ErrorDetails err := json.Unmarshal([]byte(details), &errDetails) if err != nil { // Not JSON or malformed - treat as plain stacktrace with default code @@ -168,6 +31,6 @@ func parseErrorDetails(details string) (int, string) { func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) { err := page.Render(r.Context(), w) if err != nil { - throwInternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page.")) + throw.InternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page.")) } } diff --git a/internal/handlers/index.go b/internal/handlers/index.go index ec91c5f..19e668d 100644 --- a/internal/handlers/index.go +++ b/internal/handlers/index.go @@ -3,6 +3,7 @@ package handlers import ( "net/http" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/view/page" "git.haelnorr.com/h/golib/hws" @@ -14,7 +15,7 @@ func Index(s *hws.Server) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - throwNotFound(s, w, r, r.URL.Path) + throw.NotFound(s, w, r, r.URL.Path) } renderSafely(page.Index(), s, r, w) }, diff --git a/internal/handlers/isunique.go b/internal/handlers/isunique.go new file mode 100644 index 0000000..c75daa6 --- /dev/null +++ b/internal/handlers/isunique.go @@ -0,0 +1,49 @@ +package handlers + +import ( + "context" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/validation" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +// IsUnique creates a handler that checks field uniqueness +// Returns 200 OK if unique, 409 Conflict if not unique +func IsUnique( + s *hws.Server, + conn *bun.DB, + model any, + field string, +) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + getter, err := validation.ParseForm(r) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + value := getter.String(field).TrimSpace().Required().Value + if !getter.Validate() { + w.WriteHeader(http.StatusBadRequest) + return + } + unique := false + if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + unique, err = db.IsUnique(ctx, tx, model, field, value) + if err != nil { + return false, errors.Wrap(err, "db.IsUnique") + } + return true, nil + }); !ok { + return + } + if unique { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusConflict) + } + }) +} diff --git a/internal/handlers/isusernameunique.go b/internal/handlers/isusernameunique.go deleted file mode 100644 index f441784..0000000 --- a/internal/handlers/isusernameunique.go +++ /dev/null @@ -1,45 +0,0 @@ -package handlers - -import ( - "context" - "net/http" - "time" - - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/oslstats/internal/config" - "git.haelnorr.com/h/oslstats/internal/db" - "git.haelnorr.com/h/oslstats/internal/store" - "github.com/uptrace/bun" -) - -func IsUsernameUnique( - server *hws.Server, - conn *bun.DB, - cfg *config.Config, - store *store.Store, -) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - username := r.FormValue("username") - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(server, w, r, "Database transaction failed", err) - return - } - defer tx.Rollback() - unique, err := db.IsUsernameUnique(ctx, tx, username) - if err != nil { - throwInternalServiceError(server, w, r, "Database query failed", err) - return - } - tx.Commit() - if !unique { - w.WriteHeader(http.StatusConflict) - } else { - w.WriteHeader(http.StatusOK) - } - }, - ) -} diff --git a/internal/handlers/login.go b/internal/handlers/login.go index 1a635f6..7944df0 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -11,7 +11,9 @@ import ( "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/pkg/oauth" ) @@ -31,7 +33,7 @@ func Login( if r.Method == "POST" { if err != nil { - notifyServiceUnavailable(s, r, "Login currently unavailable", err) + notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err) w.WriteHeader(http.StatusOK) return } @@ -40,46 +42,29 @@ func Login( } if err != nil { - throwServiceUnavailable(s, w, r, "Login currently unavailable", err) + throw.ServiceUnavailable(s, w, r, "Login currently unavailable", err) return } cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) if exceeded { - err := errors.Errorf( - "login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", - attempts, - track.IP, - track.UserAgent, - track.Path, - track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), - ) - + err = track.Error(attempts) st.ClearRedirectTrack(r, "/login") - - throwError( - s, - w, - r, - http.StatusBadRequest, - "Login failed: Too many redirect attempts. Please clear your browser cookies and try again.", - err, - "warn", - ) + throw.BadRequest(s, w, r, "Too many redirects. Please clear your browser cookies and try again", err) return } state, uak, err := oauth.GenerateState(cfg.OAuth, "login") if err != nil { - throwInternalServiceError(s, w, r, "Failed to generate state token", err) + throw.InternalServiceError(s, w, r, "Failed to generate state token", err) return } oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) link, err := discordAPI.GetOAuthLink(state) if err != nil { - throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err) + throw.InternalServiceError(s, w, r, "An error occurred trying to generate the login link", err) return } st.ClearRedirectTrack(r, "/login") diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go index 8db5a56..93ae282 100644 --- a/internal/handlers/logout.go +++ b/internal/handlers/logout.go @@ -3,12 +3,12 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/throw" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -21,42 +21,31 @@ func Logout( ) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - return - } - defer func() { _ = tx.Rollback() }() - user := db.CurrentUser(r.Context()) if user == nil { // JIC - should be impossible to get here if route is protected by LoginReq w.Header().Set("HX-Redirect", "/") return } - token, err := user.DeleteDiscordTokens(ctx, tx) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) - return - } - if token != nil { - err = discordAPI.RevokeToken(token.Convert()) + if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + token, err := user.DeleteDiscordTokens(ctx, tx) if err != nil { - throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) - return + return false, errors.Wrap(err, "user.DeleteDiscordTokens") } - } - err = auth.Logout(tx, w, r) - if err != nil { - throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout")) - return - } - err = tx.Commit() - if err != nil { - throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit")) + if token != nil { + err = discordAPI.RevokeToken(token.Convert()) + if err != nil { + throw.InternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) + return false, nil + } + } + err = auth.Logout(tx, w, r) + if err != nil { + throw.InternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout")) + return false, nil + } + return true, nil + }); !ok { return } w.Header().Set("HX-Redirect", "/") diff --git a/internal/handlers/newseason.go b/internal/handlers/newseason.go index 0ed80a1..615564e 100644 --- a/internal/handlers/newseason.go +++ b/internal/handlers/newseason.go @@ -4,12 +4,13 @@ import ( "context" "fmt" "net/http" - "strings" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/notify" + "git.haelnorr.com/h/oslstats/internal/validation" "git.haelnorr.com/h/oslstats/internal/view/page" + "git.haelnorr.com/h/timefmt" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -31,248 +32,62 @@ func NewSeasonSubmit( conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Parse form data - err := r.ParseForm() - if err != nil { - err = notifyWarn(s, r, "Invalid Form", "Please check your input and try again.", nil) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } + getter, ok := validation.ParseFormOrNotify(s, w, r) + if !ok { + return + } + name := getter.String("name"). + TrimSpace().Required(). + MaxLength(20).MinLength(5).Value + shortName := getter.String("short_name"). + TrimSpace().ToUpper().Required(). + MaxLength(6).MinLength(2).Value + format := timefmt.NewBuilder(). + DayNumeric2().Slash(). + MonthNumeric2().Slash(). + Year4().Build() + startDate := getter.Time("start_date", format).Required().Value + if !getter.ValidateAndNotify(s, w, r) { return } - // Get form values - name := strings.TrimSpace(r.FormValue("name")) - shortName := strings.TrimSpace(strings.ToUpper(r.FormValue("short_name"))) - startDateStr := r.FormValue("start_date") - - // Validate required fields - if name == "" || shortName == "" || startDateStr == "" { - err = notifyWarn(s, r, "Missing Fields", "All fields are required.", nil) + nameUnique := false + shortNameUnique := false + var season *db.Season + if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name) if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) + return false, errors.Wrap(err, "db.IsSeasonNameUnique") } - return - } - - // Validate field lengths - if len(name) > 20 { - err = notifyWarn(s, r, "Invalid Name", "Season name must be 20 characters or less.", nil) + shortNameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "short_name", shortName) if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) + return false, errors.Wrap(err, "db.IsSeasonShortNameUnique") } - return - } - - if len(shortName) > 6 { - err = notifyWarn(s, r, "Invalid Short Name", "Short name must be 6 characters or less.", nil) + if !nameUnique || !shortNameUnique { + return true, nil + } + season, err = db.NewSeason(ctx, tx, name, shortName, startDate) if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - // Validate short name is alphanumeric only - if !isAlphanumeric(shortName) { - err = notifyWarn(s, r, "Invalid Short Name", "Short name must contain only letters and numbers.", nil) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - // Parse start date (DD/MM/YYYY format) - startDate, err := time.Parse("02/01/2006", startDateStr) - if err != nil { - err = notifyWarn(s, r, "Invalid Date", "Please provide a valid start date in DD/MM/YYYY format.", nil) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - // Begin database transaction - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - defer tx.Rollback() - - // Double-check uniqueness (race condition protection) - nameUnique, err := db.IsSeasonNameUnique(ctx, tx, name) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) + return false, errors.Wrap(err, "db.NewSeason") } + return true, nil + }); !ok { return } if !nameUnique { - err = notifyWarn(s, r, "Duplicate Name", "This season name is already taken.", nil) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - shortNameUnique, err := db.IsSeasonShortNameUnique(ctx, tx, shortName) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } + notify.Warn(s, w, r, "Duplicate Name", "This season name is already taken.", nil) return } if !shortNameUnique { - err = notifyWarn(s, r, "Duplicate Short Name", "This short name is already taken.", nil) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } + notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil) return } - // Create the season - season, err := db.NewSeason(ctx, tx, name, shortName, startDate) - if err != nil { - err = notifyInternalServiceError(s, r, "Failed to create season", errors.Wrap(err, "db.NewSeason")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - // Commit transaction - err = tx.Commit() - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "tx.Commit")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - // Send success notification - err = notifySuccess(s, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil) - if err != nil { - // Log but don't fail the request - s.LogError(hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "Failed to send success notification", - Error: err, - }) - } - - // Redirect to the season detail page + notify.Success(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil) w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName)) w.WriteHeader(http.StatusOK) }) } - -// Helper function to validate alphanumeric strings -func isAlphanumeric(s string) bool { - for _, r := range s { - if ((r < 'A') || (r > 'Z')) && ((r < '0') || (r > '9')) { - return false - } - } - return true -} - -func IsSeasonNameUnique( - s *hws.Server, - conn *bun.DB, -) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - defer tx.Rollback() - - // Trim whitespace for consistency - name := strings.TrimSpace(r.FormValue("name")) - - unique, err := db.IsSeasonNameUnique(ctx, tx, name) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - tx.Commit() - - if !unique { - w.WriteHeader(http.StatusConflict) - return - } - - w.WriteHeader(http.StatusOK) - }) -} - -func IsSeasonShortNameUnique( - s *hws.Server, - conn *bun.DB, -) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - defer tx.Rollback() - - // Get short name and convert to uppercase for consistency - shortname := strings.ToUpper(strings.TrimSpace(r.FormValue("short_name"))) - - unique, err := db.IsSeasonShortNameUnique(ctx, tx, shortname) - if err != nil { - err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique")) - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) - } - return - } - - tx.Commit() - - if !unique { - w.WriteHeader(http.StatusConflict) - return - } - - w.WriteHeader(http.StatusOK) - }) -} diff --git a/internal/handlers/notifications.go b/internal/handlers/notifications.go deleted file mode 100644 index ee0af94..0000000 --- a/internal/handlers/notifications.go +++ /dev/null @@ -1,55 +0,0 @@ -package handlers - -import ( - "net/http" - - "git.haelnorr.com/h/golib/hws" - "git.haelnorr.com/h/golib/notify" - "github.com/pkg/errors" -) - -func notifyClient( - s *hws.Server, - r *http.Request, - level notify.Level, - title, message, details string, - action any, -) error { - subCookie, err := r.Cookie("ws_sub_id") - if err != nil { - return errors.Wrap(err, "r.Cookie") - } - subID := notify.Target(subCookie.Value) - nt := notify.Notification{ - Target: subID, - Title: title, - Message: message, - Details: details, - Action: action, - Level: level, - } - s.NotifySub(nt) - return nil -} - -func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err error) error { - return notifyClient(s, r, notify.LevelError, "Internal Service Error", msg, - SerializeErrorDetails(http.StatusInternalServerError, err), nil) -} - -func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error { - return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg, - SerializeErrorDetails(http.StatusServiceUnavailable, err), nil) -} - -func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error { - return notifyClient(s, r, notify.LevelWarn, title, msg, "", action) -} - -func notifyInfo(s *hws.Server, r *http.Request, title, msg string, action any) error { - return notifyClient(s, r, notify.LevelInfo, title, msg, "", action) -} - -func notifySuccess(s *hws.Server, r *http.Request, title, msg string, action any) error { - return notifyClient(s, r, notify.LevelSuccess, title, msg, "", action) -} diff --git a/internal/handlers/notifswebsocket.go b/internal/handlers/notifswebsocket.go index 9e5caea..072c901 100644 --- a/internal/handlers/notifswebsocket.go +++ b/internal/handlers/notifswebsocket.go @@ -10,6 +10,7 @@ import ( "git.haelnorr.com/h/golib/notify" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/view/component/popup" "github.com/coder/websocket" @@ -22,7 +23,7 @@ func NotificationWS( ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Upgrade") != "websocket" { - throwNotFound(s, w, r, r.URL.Path) + throw.NotFound(s, w, r, r.URL.Path) return } nc, err := setupClient(s, w, r) diff --git a/internal/handlers/notifytest.go b/internal/handlers/notifytest.go index 78d33ed..0aefe8a 100644 --- a/internal/handlers/notifytest.go +++ b/internal/handlers/notifytest.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/errors" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" "git.haelnorr.com/h/oslstats/internal/view/page" ) @@ -16,12 +17,6 @@ func NotifyTester(s *hws.Server) http.Handler { func(w http.ResponseWriter, r *http.Request) { testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks") if r.Method == "GET" { - // page, _ := ErrorPage(hws.HWSError{ - // StatusCode: http.StatusTeapot, - // Message: "This error has been rendered as a test", - // Error: testErr, - // }) - // page.Render(r.Context(), w) renderSafely(page.Test(), s, r, w) } else { _ = r.ParseForm() @@ -30,19 +25,15 @@ func NotifyTester(s *hws.Server) http.Handler { level := r.Form.Get("type") message := r.Form.Get("message") - var err error switch level { case "success": - err = notifySuccess(s, r, title, message, nil) + notify.Success(s, w, r, title, message, nil) case "info": - err = notifyInfo(s, r, title, message, nil) + notify.Info(s, w, r, title, message, nil) case "warn": - err = notifyWarn(s, r, title, message, nil) + notify.Warn(s, w, r, title, message, nil) case "error": - err = notifyInternalServiceError(s, r, message, testErr) - } - if err != nil { - throwInternalServiceError(s, w, r, "Error notifying client", err) + notify.InternalServiceError(s, w, r, message, testErr) } } }, diff --git a/internal/handlers/page_opt_helpers.go b/internal/handlers/page_opt_helpers.go deleted file mode 100644 index 19ce03f..0000000 --- a/internal/handlers/page_opt_helpers.go +++ /dev/null @@ -1,68 +0,0 @@ -package handlers - -import ( - "net/http" - "strconv" - - "git.haelnorr.com/h/oslstats/internal/db" - "github.com/pkg/errors" - "github.com/uptrace/bun" -) - -func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) { - var pageNum, perPage int - var order bun.Order - var orderBy string - var err error - - if pageStr := r.FormValue("page"); pageStr != "" { - pageNum, err = strconv.Atoi(pageStr) - if err != nil { - return nil, errors.Wrap(err, "invalid page number") - } - } - if perPageStr := r.FormValue("per_page"); perPageStr != "" { - perPage, err = strconv.Atoi(perPageStr) - if err != nil { - return nil, errors.Wrap(err, "invalid per_page number") - } - } - order = bun.Order(r.FormValue("order")) - orderBy = r.FormValue("order_by") - - pageOpts := &db.PageOpts{ - Page: pageNum, - PerPage: perPage, - Order: order, - OrderBy: orderBy, - } - return pageOpts, nil -} - -func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) { - var pageNum, perPage int - var order bun.Order - var orderBy string - var err error - if pageStr := r.URL.Query().Get("page"); pageStr != "" { - pageNum, err = strconv.Atoi(pageStr) - if err != nil { - return nil, errors.Wrap(err, "invalid page number") - } - } - if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" { - perPage, err = strconv.Atoi(perPageStr) - if err != nil { - return nil, errors.Wrap(err, "invalid per_page number") - } - } - order = bun.Order(r.URL.Query().Get("order")) - orderBy = r.URL.Query().Get("order_by") - pageOpts := &db.PageOpts{ - Page: pageNum, - PerPage: perPage, - Order: order, - OrderBy: orderBy, - } - return pageOpts, nil -} diff --git a/internal/handlers/pageopt_helpers.go b/internal/handlers/pageopt_helpers.go new file mode 100644 index 0000000..1e814a8 --- /dev/null +++ b/internal/handlers/pageopt_helpers.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/validation" + "github.com/uptrace/bun" +) + +// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata. +// It renders a Bad Request error page on fail +// PageOpts will be nil on fail +func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts { + getter, ok := validation.ParseFormOrError(s, w, r) + if !ok { + return nil + } + return getPageOpts(s, w, r, getter) +} + +// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail +// PageOpts will be nil on fail +func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts { + return getPageOpts(s, w, r, validation.NewQueryGetter(r)) +} + +func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts { + page := g.Int("page").Min(1).Value + perPage := g.Int("per_page").Min(1).Max(100).Value + order := g.String("order").TrimSpace().ToUpper().AllowedValues([]string{"ASC", "DESC"}).Value + orderBy := g.String("order_by").TrimSpace().ToLower().Value + valid := g.ValidateAndError(s, w, r) + if !valid { + return nil + } + pageOpts := &db.PageOpts{ + Page: page, + PerPage: perPage, + Order: bun.Order(order), + OrderBy: orderBy, + } + return pageOpts +} diff --git a/internal/handlers/permtest.go b/internal/handlers/permtest.go index 7f2d145..d020e91 100644 --- a/internal/handlers/permtest.go +++ b/internal/handlers/permtest.go @@ -7,6 +7,7 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/roles" + "git.haelnorr.com/h/oslstats/internal/throw" "github.com/uptrace/bun" ) @@ -17,7 +18,7 @@ func PermTester(s *hws.Server, conn *bun.DB) http.Handler { isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin) tx.Rollback() if err != nil { - throwInternalServiceError(s, w, r, "Error", err) + throw.InternalServiceError(s, w, r, "Error", err) } _, _ = w.Write([]byte(strconv.FormatBool(isAdmin))) }) diff --git a/internal/handlers/register.go b/internal/handlers/register.go index c0fa59a..80f973f 100644 --- a/internal/handlers/register.go +++ b/internal/handlers/register.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" @@ -14,6 +13,7 @@ import ( "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/store" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/view/page" ) @@ -29,27 +29,9 @@ func Register( attempts, exceeded, track := store.TrackRedirect(r, "/register", 3) if exceeded { - err := errors.Errorf( - "registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t", - attempts, - track.IP, - track.UserAgent, - track.Path, - track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), - cfg.HWSAuth.SSL, - ) - + err := track.Error(attempts) store.ClearRedirectTrack(r, "/register") - - throwError( - s, - w, - r, - http.StatusBadRequest, - "Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.", - err, - "warn", - ) + throw.BadRequest(s, w, r, "Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again", err) return } @@ -65,68 +47,45 @@ func Register( } store.ClearRedirectTrack(r, "/register") - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database transaction failed", err) - return - } - defer tx.Rollback() - method := r.Method - if method == "GET" { - tx.Commit() + + if r.Method == "GET" { renderSafely(page.Register(details.DiscordUser.Username), s, r, w) return } - if method == "POST" { - username := r.FormValue("username") - user, err := registerUser(ctx, tx, username, details) + username := r.FormValue("username") + unique := false + var user *db.User + if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username) if err != nil { - err = notifyInternalServiceError(s, r, "Registration failed", err) - if err != nil { - throwInternalServiceError(s, w, r, "Registration failed", err) - } + return false, errors.Wrap(err, "db.IsUsernameUnique") + } + if !unique { + return true, nil + } + user, err = db.CreateUser(ctx, tx, username, details.DiscordUser) + if err != nil { + return false, errors.Wrap(err, "db.CreateUser") + } + err = user.UpdateDiscordToken(ctx, tx, details.Token) + if err != nil { + return false, errors.Wrap(err, "db.UpdateDiscordToken") + } + return true, nil + }); !ok { + return + } + if !unique { + w.WriteHeader(http.StatusConflict) + } else { + err = auth.Login(w, r, user, true) + if err != nil { + throw.InternalServiceError(s, w, r, "Login failed", err) return } - tx.Commit() - if user == nil { - w.WriteHeader(http.StatusConflict) - } else { - err = auth.Login(w, r, user, true) - if err != nil { - throwInternalServiceError(s, w, r, "Login failed", err) - return - } - pageFrom := cookies.CheckPageFrom(w, r) - w.Header().Set("HX-Redirect", pageFrom) - } - return + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) } }, ) } - -func registerUser( - ctx context.Context, - tx bun.Tx, - username string, - details *store.RegistrationSession, -) (*db.User, error) { - unique, err := db.IsUsernameUnique(ctx, tx, username) - if err != nil { - return nil, errors.Wrap(err, "db.IsUsernameUnique") - } - if !unique { - return nil, nil - } - user, err := db.CreateUser(ctx, tx, username, details.DiscordUser) - if err != nil { - return nil, errors.Wrap(err, "db.CreateUser") - } - err = user.UpdateDiscordToken(ctx, tx, details.Token) - if err != nil { - return nil, errors.Wrap(err, "db.UpdateDiscordToken") - } - return user, nil -} diff --git a/internal/handlers/season.go b/internal/handlers/season.go index bfd7c93..4a6b22b 100644 --- a/internal/handlers/season.go +++ b/internal/handlers/season.go @@ -3,10 +3,10 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/view/page" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -17,23 +17,20 @@ func SeasonPage( conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - return - } - defer tx.Rollback() seasonStr := r.PathValue("season_short_name") - season, err := db.GetSeason(ctx, tx, seasonStr) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetSeason")) + var season *db.Season + if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + season, err = db.GetSeason(ctx, tx, seasonStr) + if err != nil { + return false, errors.Wrap(err, "db.GetSeason") + } + return true, nil + }); !ok { return } - tx.Commit() if season == nil { - throwNotFound(s, w, r, r.URL.Path) + throw.NotFound(s, w, r, r.URL.Path) return } renderSafely(page.SeasonPage(season), s, r, w) diff --git a/internal/handlers/seasons.go b/internal/handlers/seasons.go index f572103..e772c01 100644 --- a/internal/handlers/seasons.go +++ b/internal/handlers/seasons.go @@ -3,7 +3,6 @@ package handlers import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/db" @@ -17,25 +16,21 @@ func SeasonsPage( conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + pageOpts := pageOptsFromQuery(s, w, r) + if pageOpts == nil { return } - defer tx.Rollback() - pageOpts, err := pageOptsFromQuery(r) - if err != nil { - throwBadRequest(s, w, r, "invalid query", err) + var seasons *db.SeasonList + if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + seasons, err = db.ListSeasons(ctx, tx, pageOpts) + if err != nil { + return false, errors.Wrap(err, "db.ListSeasons") + } + return true, nil + }); !ok { return } - seasons, err := db.ListSeasons(ctx, tx, pageOpts) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons")) - return - } - tx.Commit() renderSafely(page.SeasonsPage(seasons), s, r, w) }) } @@ -45,37 +40,21 @@ func SeasonsList( conn *bun.DB, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - - // Parse form values - if err := r.ParseForm(); err != nil { - throwBadRequest(s, w, r, "Invalid form data", err) + pageOpts := pageOptsFromForm(s, w, r) + if pageOpts == nil { return } - - pageOpts, err := pageOptsFromForm(r) - if err != nil { - throwBadRequest(s, w, r, "invalid form data", err) + var seasons *db.SeasonList + if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) { + var err error + seasons, err = db.ListSeasons(ctx, tx, pageOpts) + if err != nil { + return false, errors.Wrap(err, "db.ListSeasons") + } + return true, nil + }); !ok { return } - - // Database query - tx, err := conn.BeginTx(ctx, nil) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) - return - } - defer tx.Rollback() - - seasons, err := db.ListSeasons(ctx, tx, pageOpts) - if err != nil { - throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons")) - return - } - tx.Commit() - - // Return only the list component (hx-push-url handles URL update client-side) renderSafely(page.SeasonsList(seasons), s, r, w) }) } diff --git a/internal/handlers/static.go b/internal/handlers/static.go index 3f4d2ff..8c35ece 100644 --- a/internal/handlers/static.go +++ b/internal/handlers/static.go @@ -6,6 +6,7 @@ import ( "strings" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/throw" ) // StaticFS handles requests for static files, without allowing access to the @@ -16,7 +17,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler { if err != nil { // If we can't create the file server, return a handler that always errors return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err) + throw.InternalServiceError(server, w, r, "An error occurred trying to load the file system", err) }) } diff --git a/internal/notify/notify.go b/internal/notify/notify.go new file mode 100644 index 0000000..be3ce60 --- /dev/null +++ b/internal/notify/notify.go @@ -0,0 +1,63 @@ +// Package notify provides utility functions for sending notifications to clients +package notify + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/notify" + "git.haelnorr.com/h/oslstats/internal/throw" + "github.com/pkg/errors" +) + +func notifyClient( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + level notify.Level, + title, message, details string, + action any, +) { + subCookie, err := r.Cookie("ws_sub_id") + if err != nil { + throw.InternalServiceError(s, w, r, "Notification failed. Do you have cookies enabled?", errors.Wrap(err, "r.Cookie")) + return + } + subID := notify.Target(subCookie.Value) + nt := notify.Notification{ + Target: subID, + Title: title, + Message: message, + Details: details, + Action: action, + Level: level, + } + s.NotifySub(nt) +} + +// InternalServiceError notifies with error level +func InternalServiceError(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) { + notifyClient(s, w, r, notify.LevelError, "Internal Service Error", msg, + SerializeErrorDetails(http.StatusInternalServerError, err), nil) +} + +// ServiceUnavailable notifies with error level +func ServiceUnavailable(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) { + notifyClient(s, w, r, notify.LevelError, "Service Unavailable", msg, + SerializeErrorDetails(http.StatusServiceUnavailable, err), nil) +} + +// Warn notifies with warn level +func Warn(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) { + notifyClient(s, w, r, notify.LevelWarn, title, msg, "", action) +} + +// Info notifies with info level +func Info(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) { + notifyClient(s, w, r, notify.LevelInfo, title, msg, "", action) +} + +// Success notifies with success level +func Success(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) { + notifyClient(s, w, r, notify.LevelSuccess, title, msg, "", action) +} diff --git a/internal/notify/util.go b/internal/notify/util.go new file mode 100644 index 0000000..6ab1129 --- /dev/null +++ b/internal/notify/util.go @@ -0,0 +1,53 @@ +package notify + +import ( + "encoding/json" + "fmt" +) + +// ErrorDetails contains structured error information for WebSocket error modals +type ErrorDetails struct { + Code int `json:"code"` + Stacktrace string `json:"stacktrace"` +} + +// SerializeErrorDetails creates a JSON string with code and stacktrace +// This is exported so it can be used when creating error notifications +func SerializeErrorDetails(code int, err error) string { + details := ErrorDetails{ + Code: code, + Stacktrace: FormatErrorDetails(err), + } + jsonData, jsonErr := json.Marshal(details) + if jsonErr != nil { + // Fallback if JSON encoding fails + return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code) + } + return string(jsonData) +} + +// ParseErrorDetails extracts code and stacktrace from JSON Details field +// Returns (code, stacktrace). If parsing fails, returns (500, original details string) +func ParseErrorDetails(details string) (int, string) { + if details == "" { + return 500, "" + } + + var errDetails ErrorDetails + err := json.Unmarshal([]byte(details), &errDetails) + if err != nil { + // Not JSON or malformed - treat as plain stacktrace with default code + return 500, details + } + + return errDetails.Code, errDetails.Stacktrace +} + +// FormatErrorDetails extracts and formats error details from wrapped errors +func FormatErrorDetails(err error) string { + if err == nil { + return "" + } + // Use %+v format to get stack trace from github.com/pkg/errors + return fmt.Sprintf("%+v", err) +} diff --git a/internal/rbac/cache_middleware.go b/internal/rbac/cache_middleware.go index 5b642d3..a834b82 100644 --- a/internal/rbac/cache_middleware.go +++ b/internal/rbac/cache_middleware.go @@ -3,7 +3,6 @@ package rbac import ( "context" "net/http" - "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/contexts" @@ -11,6 +10,7 @@ import ( "git.haelnorr.com/h/oslstats/internal/permissions" "git.haelnorr.com/h/oslstats/internal/roles" "github.com/pkg/errors" + "github.com/uptrace/bun" ) // LoadPermissionsMiddleware loads user permissions into context after authentication @@ -26,46 +26,28 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware { return } - // Start transaction for loading permissions - ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) - defer cancel() - tx, err := c.conn.BeginTx(ctx, nil) - if err != nil { - // Log but don't block - permission checks will fail gracefully + var roles_ []*db.Role + var perms []*db.Permission + if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error { + var err error + roles_, err = user.GetRoles(ctx, tx) + if err != nil { + return errors.Wrap(err, "user.GetRoles") + } + perms, err = user.GetPermissions(ctx, tx) + if err != nil { + return errors.Wrap(err, "user.GetPermissions") + } + return nil + }); err != nil { c.s.LogError(hws.HWSError{ - Message: "Failed to start database transaction", - Error: errors.Wrap(err, "c.conn.BeginTx"), + Message: "Database error", + Error: err, Level: hws.ErrorERROR, }) next.ServeHTTP(w, r) return } - defer func() { _ = tx.Rollback() }() - - // Load user's roles_ and permissions - roles_, err := user.GetRoles(ctx, tx) - if err != nil { - c.s.LogError(hws.HWSError{ - Message: "Failed to get user roles", - Error: errors.Wrap(err, "user.GetRoles"), - Level: hws.ErrorERROR, - }) - next.ServeHTTP(w, r) - return - } - - perms, err := user.GetPermissions(ctx, tx) - if err != nil { - c.s.LogError(hws.HWSError{ - Message: "Failed to get user permissions", - Error: errors.Wrap(err, "user.GetPermissions"), - Level: hws.ErrorERROR, - }) - next.ServeHTTP(w, r) - return - } - - _ = tx.Commit() // read only transaction // Build permission cache cache := &contexts.PermissionCache{ @@ -88,7 +70,7 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware { } // Add cache to context (type-safe) - ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache) + ctx := context.WithValue(r.Context(), contexts.PermissionCacheKey, cache) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/store/redirects.go b/internal/store/redirects.go index e2e7852..b296dd4 100644 --- a/internal/store/redirects.go +++ b/internal/store/redirects.go @@ -5,6 +5,8 @@ import ( "net/http" "strings" "time" + + "github.com/pkg/errors" ) // getClientIP extracts the client IP address, checking X-Forwarded-For first @@ -93,3 +95,14 @@ func (s *Store) ClearRedirectTrack(r *http.Request, path string) { key := redirectKey(ip, userAgent, path) s.redirectTracks.Delete(key) } + +func (t *RedirectTrack) Error(attempts int) error { + return errors.Errorf( + "callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s", + attempts, + t.IP, + t.UserAgent, + t.Path, + t.FirstSeen.Format("2006-01-02T15:04:05Z07:00"), + ) +} diff --git a/internal/throw/throw.go b/internal/throw/throw.go new file mode 100644 index 0000000..ad67eaf --- /dev/null +++ b/internal/throw/throw.go @@ -0,0 +1,118 @@ +// Package throw provides utility functions for throwing HTTP errors that render an error page +package throw + +import ( + "fmt" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" +) + +// throwError is a generic helper that all throw* functions use internally +func throwError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + statusCode int, + msg string, + err error, + level hws.ErrorLevel, +) { + s.ThrowError(w, r, hws.HWSError{ + StatusCode: statusCode, + Message: msg, + Error: err, + Level: level, + RenderErrorPage: true, // throw* family always renders error pages + }) +} + +// InternalServiceError handles 500 errors (server failures) +func InternalServiceError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR) +} + +// ServiceUnavailable handles 503 errors +func ServiceUnavailable( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR) +} + +// BadRequest handles 400 errors (malformed requests) +func BadRequest( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG) +} + +// Forbidden handles 403 errors (normal permission denials) +func Forbidden( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG) +} + +// ForbiddenSecurity handles 403 errors for security events (uses WARN level) +func ForbiddenSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN) +} + +// Unauthorized handles 401 errors (not authenticated) +func Unauthorized( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG) +} + +// UnauthorizedSecurity handles 401 errors for security events (uses WARN level) +func UnauthorizedSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN) +} + +// NotFound handles 404 errors +func NotFound( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + path string, +) { + msg := fmt.Sprintf("The requested resource was not found: %s", path) + err := errors.New("Resource not found") + throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG) +} diff --git a/internal/validation/forms.go b/internal/validation/forms.go new file mode 100644 index 0000000..53de7db --- /dev/null +++ b/internal/validation/forms.go @@ -0,0 +1,101 @@ +package validation + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" + "git.haelnorr.com/h/oslstats/internal/throw" + "git.haelnorr.com/h/timefmt" + "github.com/pkg/errors" +) + +// FormGetter wraps http.Request to get form values +type FormGetter struct { + r *http.Request + checks []*ValidationRule +} + +func NewFormGetter(r *http.Request) *FormGetter { + return &FormGetter{r: r, checks: []*ValidationRule{}} +} + +func (f *FormGetter) Get(key string) string { + return f.r.FormValue(key) +} + +func (f *FormGetter) getChecks() []*ValidationRule { + return f.checks +} + +func (f *FormGetter) AddCheck(check *ValidationRule) { + f.checks = append(f.checks, check) +} + +func (f *FormGetter) ValidateChecks() []*ValidationRule { + return validate(f) +} + +func (f *FormGetter) String(key string) *StringField { + return newStringField(key, f) +} + +func (f *FormGetter) Int(key string) *IntField { + return newIntField(key, f) +} + +func (f *FormGetter) Time(key string, format *timefmt.Format) *TimeField { + return newTimeField(key, format, f) +} + +func ParseForm(r *http.Request) (*FormGetter, error) { + err := r.ParseForm() + if err != nil { + return nil, errors.Wrap(err, "r.ParseForm") + } + return NewFormGetter(r), nil +} + +// ParseFormOrNotify attempts to parse the form data and notifies the user on fail +func ParseFormOrNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) { + getter, err := ParseForm(r) + if err != nil { + notify.Warn(s, w, r, "Invalid Form", "Please check your input and try again.", nil) + return nil, false + } + return getter, true +} + +// ParseFormOrError attempts to parse the form data and renders an error page on fail +func ParseFormOrError(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) { + getter, err := ParseForm(r) + if err != nil { + throw.BadRequest(s, w, r, "Invalid form data", err) + return nil, false + } + return getter, true +} + +func (f *FormGetter) Validate() bool { + return len(validate(f)) == 0 +} + +// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check +// Returns true if all checks passed +func (f *FormGetter) ValidateAndNotify( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, +) bool { + return validateAndNotify(s, w, r, f) +} + +// ValidateAndError runs the provided validation checks and renders an error page with all the error messages +// Returns true if all checks passed +func (f *FormGetter) ValidateAndError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, +) bool { + return validateAndError(s, w, r, f) +} diff --git a/internal/validation/integerfield.go b/internal/validation/integerfield.go new file mode 100644 index 0000000..25cc69a --- /dev/null +++ b/internal/validation/integerfield.go @@ -0,0 +1,71 @@ +package validation + +import ( + "fmt" + "strconv" +) + +type IntField struct { + Field + Value int +} + +func newIntField(key string, g Getter) *IntField { + raw := g.Get(key) + var val int + if raw != "" { + var err error + val, err = strconv.Atoi(raw) + if err != nil { + g.AddCheck(newFailedCheck( + "Value is not a number", + fmt.Sprintf("%s must be an integer: %s provided", key, raw), + )) + } + } + return &IntField{ + Value: val, + Field: newField(key, g), + } +} + +// Required enforces a non-zero value +func (i *IntField) Required() *IntField { + if i.Value == 0 { + i.getter.AddCheck(newFailedCheck( + "Value cannot be 0", + fmt.Sprintf("%s is required", i.Key), + )) + } + return i +} + +// Optional will skip all validations if value is empty +func (i *IntField) Optional() *IntField { + if i.Value == 0 { + i.optional = true + } + return i +} + +// Max enforces a maxmium value +func (i *IntField) Max(max int) *IntField { + if i.Value > max && !i.optional { + i.getter.AddCheck(newFailedCheck( + "Value too large", + fmt.Sprintf("%s is too large, max %v", i.Key, max), + )) + } + return i +} + +// Min enforces a minimum value +func (i *IntField) Min(min int) *IntField { + if i.Value < min && !i.optional { + i.getter.AddCheck(newFailedCheck( + "Value too small", + fmt.Sprintf("%s is too small, min %v", i.Key, min), + )) + } + return i +} diff --git a/internal/validation/querys.go b/internal/validation/querys.go new file mode 100644 index 0000000..ed32fd9 --- /dev/null +++ b/internal/validation/querys.go @@ -0,0 +1,70 @@ +package validation + +import ( + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/timefmt" +) + +// QueryGetter wraps http.Request to get query values +type QueryGetter struct { + r *http.Request + checks []*ValidationRule +} + +func NewQueryGetter(r *http.Request) *QueryGetter { + return &QueryGetter{r: r, checks: []*ValidationRule{}} +} + +func (q *QueryGetter) Get(key string) string { + return q.r.URL.Query().Get(key) +} + +func (q *QueryGetter) getChecks() []*ValidationRule { + return q.checks +} + +func (q *QueryGetter) AddCheck(check *ValidationRule) { + q.checks = append(q.checks, check) +} + +func (q *QueryGetter) ValidateChecks() []*ValidationRule { + return validate(q) +} + +func (q *QueryGetter) String(key string) *StringField { + return newStringField(key, q) +} + +func (q *QueryGetter) Int(key string) *IntField { + return newIntField(key, q) +} + +func (q *QueryGetter) Time(key string, format *timefmt.Format) *TimeField { + return newTimeField(key, format, q) +} + +func (q *QueryGetter) Validate() bool { + return len(validate(q)) == 0 +} + +// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check +// Returns true if all checks passed +func (q *QueryGetter) ValidateAndNotify( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, +) bool { + return validateAndNotify(s, w, r, q) +} + +// ValidateAndError runs the provided validation checks and renders an error page with all the error messages +// Returns true if all checks passed +func (q *QueryGetter) ValidateAndError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, +) bool { + return validateAndError(s, w, r, q) +} diff --git a/internal/validation/stringfield.go b/internal/validation/stringfield.go new file mode 100644 index 0000000..5b97e4e --- /dev/null +++ b/internal/validation/stringfield.go @@ -0,0 +1,106 @@ +package validation + +import ( + "fmt" + "slices" + "strings" + "unicode" +) + +type StringField struct { + Field + Value string +} + +func newStringField(key string, g Getter) *StringField { + return &StringField{ + Value: g.Get(key), + Field: newField(key, g), + } +} + +// Required enforces a non empty string +func (s *StringField) Required() *StringField { + if s.Value == "" { + s.getter.AddCheck(newFailedCheck( + "Field not provided", + fmt.Sprintf("%s is required", s.Key), + )) + } + return s +} + +// Optional will skip all validations if value is empty +func (s *StringField) Optional() *StringField { + if s.Value == "" { + s.optional = true + } + return s +} + +// MaxLength enforces a maximum string length +func (s *StringField) MaxLength(length int) *StringField { + if len(s.Value) > length && !s.optional { + s.getter.AddCheck(newFailedCheck( + "Input too long", + fmt.Sprintf("%s is too long, max %v chars", s.Key, length), + )) + } + return s +} + +// MinLength enforces a minimum string length +func (s *StringField) MinLength(length int) *StringField { + if len(s.Value) < length && !s.optional { + s.getter.AddCheck(newFailedCheck( + "Input too short", + fmt.Sprintf("%s is too short, min %v chars", s.Key, length), + )) + } + return s +} + +// AlphaNumeric enforces the string contains only letters and numbers +func (s *StringField) AlphaNumeric() *StringField { + if s.optional { + return s + } + for _, r := range s.Value { + if !unicode.IsLetter(r) && !unicode.IsDigit(r) { + s.getter.AddCheck(newFailedCheck( + "Invalid characters", + fmt.Sprintf("%s must contain only letters and numbers.", s.Key), + )) + return s + } + } + return s +} + +func (s *StringField) AllowedValues(allowed []string) *StringField { + if !slices.Contains(allowed, s.Value) && !s.optional { + s.getter.AddCheck(newFailedCheck( + "Value not allowed", + fmt.Sprintf("%s must be one of: %s", s.Key, strings.Join(allowed, ",")), + )) + } + return s +} + +// ToUpper transforms the string to uppercase +func (s *StringField) ToUpper() *StringField { + s.Value = strings.ToUpper(s.Value) + return s +} + +// ToLower transforms the string to lowercase +func (s *StringField) ToLower() *StringField { + s.Value = strings.ToLower(s.Value) + return s +} + +// TrimSpace removes leading and trailing whitespace +func (s *StringField) TrimSpace() *StringField { + s.Value = strings.TrimSpace(s.Value) + return s +} diff --git a/internal/validation/timefield.go b/internal/validation/timefield.go new file mode 100644 index 0000000..6cad564 --- /dev/null +++ b/internal/validation/timefield.go @@ -0,0 +1,50 @@ +package validation + +import ( + "fmt" + "time" + + "git.haelnorr.com/h/timefmt" +) + +type TimeField struct { + Field + Value time.Time +} + +func newTimeField(key string, format *timefmt.Format, g Getter) *TimeField { + raw := g.Get(key) + var startDate time.Time + if raw != "" { + var err error + startDate, err = format.Parse(raw) + if err != nil { + g.AddCheck(newFailedCheck( + "Invalid date/time format", + fmt.Sprintf("%s should be in format %s", key, format.LDML()), + )) + } + } + return &TimeField{ + Value: startDate, + Field: newField(key, g), + } +} + +func (t *TimeField) Required() *TimeField { + if t.Value.IsZero() { + t.getter.AddCheck(newFailedCheck( + "Date/Time not provided", + fmt.Sprintf("%s must be provided", t.Key), + )) + } + return t +} + +// Optional will skip all validations if value is empty +func (t *TimeField) Optional() *TimeField { + if t.Value.IsZero() { + t.optional = true + } + return t +} diff --git a/internal/validation/validation.go b/internal/validation/validation.go new file mode 100644 index 0000000..3206ecc --- /dev/null +++ b/internal/validation/validation.go @@ -0,0 +1,93 @@ +// Package validation provides utilities for parsing and validating request data +package validation + +import ( + "errors" + "fmt" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/notify" + "git.haelnorr.com/h/oslstats/internal/throw" + "git.haelnorr.com/h/timefmt" +) + +type ValidationRule struct { + Condition bool // Condition on which to fail validation + Title string // Title for warning message + Message string // Warning message +} + +// Getter abstracts getting values from either form or query +type Getter interface { + Get(key string) string + AddCheck(check *ValidationRule) + String(key string) *StringField + Int(key string) *IntField + Time(key string, format *timefmt.Format) *TimeField + ValidateChecks() []*ValidationRule + Validate() bool + ValidateAndNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) bool + ValidateAndError(s *hws.Server, w http.ResponseWriter, r *http.Request) bool + + getChecks() []*ValidationRule +} +type Field struct { + Key string + optional bool + getter Getter +} + +func newField(key string, g Getter) Field { + return Field{ + Key: key, + getter: g, + } +} + +func newFailedCheck(title, message string) *ValidationRule { + return &ValidationRule{ + Condition: true, + Title: title, + Message: message, + } +} + +func validate(g Getter) []*ValidationRule { + failed := []*ValidationRule{} + for _, check := range g.getChecks() { + if check != nil && check.Condition { + failed = append(failed, check) + } + } + return failed +} + +func validateAndNotify( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + g Getter, +) bool { + failedChecks := g.ValidateChecks() + for _, check := range failedChecks { + notify.Warn(s, w, r, check.Title, check.Message, nil) + } + return len(failedChecks) == 0 +} + +func validateAndError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + g Getter, +) bool { + failedChecks := g.ValidateChecks() + var err error + for _, check := range failedChecks { + err_ := fmt.Errorf("%s: %s", check.Title, check.Message) + err = errors.Join(err, err_) + } + throw.BadRequest(s, w, r, "Invalid form data", err) + return len(failedChecks) == 0 +}