league #1

Merged
h merged 41 commits from league into master 2026-02-15 19:59:31 +11:00
40 changed files with 1211 additions and 920 deletions
Showing only changes of commit ac38025b77 - Show all commits

View File

@@ -101,17 +101,17 @@ func addRoutes(
{
Path: "/htmx/isusernameunique",
Method: hws.MethodPOST,
Handler: handlers.IsUsernameUnique(s, conn, cfg, store),
Handler: handlers.IsUnique(s, conn, (*db.User)(nil), "username"),
},
{
Path: "/htmx/isseasonnameunique",
Method: hws.MethodPOST,
Handler: handlers.IsSeasonNameUnique(s, conn),
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "name"),
},
{
Path: "/htmx/isseasonshortnameunique",
Method: hws.MethodPOST,
Handler: handlers.IsSeasonShortNameUnique(s, conn),
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "short_name"),
},
}

3
go.mod
View File

@@ -1,6 +1,6 @@
module git.haelnorr.com/h/oslstats
go 1.25.5
go 1.25.6
require (
git.haelnorr.com/h/golib/env v0.9.1
@@ -27,6 +27,7 @@ require (
require (
git.haelnorr.com/h/golib/cookies v0.9.0
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
git.haelnorr.com/h/timefmt v0.1.0
github.com/bwmarrin/discordgo v0.29.0
github.com/go-logr/logr v1.4.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect

2
go.sum
View File

@@ -16,6 +16,8 @@ git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7Kv
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
git.haelnorr.com/h/timefmt v0.1.0 h1:ULDkWEtFIV+FkkoV0q9n62Spj+HDdtFL9QeAdGIEp+o=
git.haelnorr.com/h/timefmt v0.1.0/go.mod h1:12gXXYLP4w9Fa9ZkbZWdvKV6RyZEzwAm9mN+WB3oXpw=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=

View File

@@ -105,35 +105,33 @@ func (l *Logger) log(
// GetRecentLogs retrieves recent audit logs with pagination
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return nil, errors.Wrap(err, "conn.BeginTx")
var logs *db.AuditLogs
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
if err != nil {
return errors.Wrap(err, "db.GetAuditLogs")
}
return nil
}); err != nil {
return nil, errors.Wrap(err, "db.WithTxFailSilently")
}
defer func() { _ = tx.Rollback() }()
logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil)
if err != nil {
return nil, err
}
_ = tx.Commit() // read only transaction
return logs, nil
}
// GetLogsByUser retrieves audit logs for a specific user
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return nil, errors.Wrap(err, "conn.BeginTx")
var logs *db.AuditLogs
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
if err != nil {
return errors.Wrap(err, "db.GetAuditLogsByUser")
}
return nil
}); err != nil {
return nil, errors.Wrap(err, "db.WithTxFailSilently")
}
defer func() { _ = tx.Rollback() }()
logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
if err != nil {
return nil, err
}
_ = tx.Commit() // read only transaction
return logs, nil
}
@@ -145,20 +143,16 @@ func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
tx, err := l.conn.BeginTx(ctx, nil)
if err != nil {
return 0, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
if err != nil {
return 0, err
}
err = tx.Commit()
if err != nil {
return 0, errors.Wrap(err, "tx.Commit")
var count int
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
if err != nil {
return errors.Wrap(err, "db.CleanupOldAuditLogs")
}
return nil
}); err != nil {
return 0, errors.Wrap(err, "db.WithTxFailSilently")
}
return count, nil
}

19
internal/db/isunique.go Normal file
View File

@@ -0,0 +1,19 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func IsUnique(ctx context.Context, tx bun.Tx, model any, field, value string) (bool, error) {
count, err := tx.NewSelect().
Model(model).
Where("? = ?", bun.Ident(field), value).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count == 0, nil
}

View File

@@ -87,25 +87,3 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
}
return season, nil
}
func IsSeasonNameUnique(ctx context.Context, tx bun.Tx, name string) (bool, error) {
count, err := tx.NewSelect().
Model((*Season)(nil)).
Where("name = ?", name).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count == 0, nil
}
func IsSeasonShortNameUnique(ctx context.Context, tx bun.Tx, shortname string) (bool, error) {
count, err := tx.NewSelect().
Model((*Season)(nil)).
Where("short_name = ?", shortname).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count == 0, nil
}

118
internal/db/txhelpers.go Normal file
View File

@@ -0,0 +1,118 @@
package db
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// TxFunc is a function that runs within a database transaction
type (
TxFunc func(ctx context.Context, tx bun.Tx) (bool, error)
TxFuncSilent func(ctx context.Context, tx bun.Tx) error
)
var timeout = 15 * time.Second
// WithReadTx executes a read-only transaction with automatic rollback
// Returns true if successful, false if error was thrown to client
func WithReadTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
conn *bun.DB,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := withTx(ctx, conn, fn, false)
if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// WithTxFailSilently executes a transaction with automatic rollback
// Returns true if successful, false if error occured.
// Does not throw any errors to the client.
func WithTxFailSilently(
ctx context.Context,
conn *bun.DB,
fn TxFuncSilent,
) error {
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
err := fn(ctx, tx)
return err != nil, err
}
_, err := withTx(ctx, conn, fnc, true)
return err
}
// WithWriteTx executes a write transaction with automatic rollback on error
// Commits only if fn returns nil. Returns true if successful.
func WithWriteTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
conn *bun.DB,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := withTx(ctx, conn, fn, true)
if err != nil {
throw.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// WithNotifyTx executes a transaction with notification-based error handling
// Uses notifyInternalServiceError instead of throwInternalServiceError
func WithNotifyTx(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
conn *bun.DB,
fn TxFunc,
) bool {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
ok, err := withTx(ctx, conn, fn, true)
if err != nil {
notify.InternalServiceError(s, w, r, "Database error", err)
}
return ok
}
// withTx executes a transaction with automatic rollback on error
func withTx(
ctx context.Context,
conn *bun.DB,
fn TxFunc,
write bool,
) (bool, error) {
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "conn.BeginTx")
}
defer func() { _ = tx.Rollback() }()
ok, err := fn(ctx, tx)
if err != nil || !ok {
return false, err
}
if write {
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
} else {
_ = tx.Commit()
}
return true, nil
}

View File

@@ -112,19 +112,6 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
return user, nil
}
// IsUsernameUnique checks if the given username is unique (not already taken)
// Returns true if the username is available, false if it's taken
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
count, err := tx.NewSelect().
Model((*User)(nil)).
Where("username = ?", username).
Count(ctx)
if err != nil {
return false, errors.Wrap(err, "tx.NewSelect")
}
return count == 0, nil
}
// GetRoles loads all the roles for this user
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
if u == nil {

View File

@@ -3,7 +3,6 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
@@ -14,23 +13,17 @@ import (
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
var users *db.Users
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
users, err = db.GetUsers(ctx, tx, nil)
if err != nil {
return false, errors.Wrap(err, "db.GetUsers")
}
return true, nil
}); !ok {
return
}
defer func() { _ = tx.Rollback() }()
users, err := db.GetUsers(ctx, tx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers"))
return
}
_ = tx.Commit()
renderSafely(page.AdminDashboard(users), s, r, w)
})
}

View File

@@ -3,7 +3,6 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
@@ -15,30 +14,21 @@ import (
// AdminUsersList shows all users
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx"))
var users *db.Users
pageOpts := pageOptsFromForm(s, w, r)
if pageOpts == nil {
return
}
defer func() { _ = tx.Rollback() }()
// Get all users
pageOpts, err := pageOptsFromForm(r)
if err != nil {
throwBadRequest(s, w, r, "invalid form data", err)
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
users, err = db.GetUsers(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.GetUsers")
}
return true, nil
}); !ok {
return
}
users, err := db.GetUsers(ctx, tx, pageOpts)
if err != nil {
throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers"))
return
}
_ = tx.Commit()
renderSafely(admin.UserList(users), s, r, w)
})
}

View File

@@ -3,7 +3,6 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
@@ -15,11 +14,12 @@ import (
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
func Callback(
server *hws.Server,
s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB,
cfg *config.Config,
@@ -31,26 +31,9 @@ func Callback(
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
if exceeded {
err := errors.Errorf(
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
err := track.Error(attempts)
store.ClearRedirectTrack(r, "/callback")
throwError(
server,
w,
r,
http.StatusBadRequest,
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
err,
"warn",
)
throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err)
return
}
@@ -64,12 +47,12 @@ func Callback(
if err != nil {
if vsErr, ok := err.(*verifyStateError); ok {
if vsErr.IsCookieError() {
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
throw.Unauthorized(s, w, r, "OAuth session not found or expired", err)
} else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
}
} else {
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
}
return
}
@@ -77,20 +60,17 @@ func Callback(
switch data {
case "login":
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
var redirect func()
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil {
throw.InternalServiceError(s, w, r, "OAuth login failed", err)
return false, nil
}
return true, nil
}); !ok {
return
}
defer tx.Rollback()
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
if err != nil {
throwInternalServiceError(server, w, r, "OAuth login failed", err)
return
}
tx.Commit()
redirect()
return
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/view/page"
)
@@ -21,7 +22,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
// Get technical details if applicable
var details string
if showDetails && hwsError.Error != nil {
details = formatErrorDetails(hwsError.Error)
details = notify.FormatErrorDetails(hwsError.Error)
}
// Render appropriate template

View File

@@ -2,152 +2,15 @@ package handlers
import (
"encoding/json"
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/a-h/templ"
"github.com/pkg/errors"
)
// throwError is a generic helper that all throw* functions use internally
func throwError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
statusCode int,
msg string,
err error,
level hws.ErrorLevel,
) {
s.ThrowError(w, r, hws.HWSError{
StatusCode: statusCode,
Message: msg,
Error: err,
Level: level,
RenderErrorPage: true, // throw* family always renders error pages
})
}
// throwInternalServiceError handles 500 errors (server failures)
func throwInternalServiceError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
}
// throwServiceUnavailable handles 503 errors
func throwServiceUnavailable(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
}
// throwBadRequest handles 400 errors (malformed requests)
func throwBadRequest(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
}
// throwForbidden handles 403 errors (normal permission denials)
func throwForbidden(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
}
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
func throwForbiddenSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
}
// throwUnauthorized handles 401 errors (not authenticated)
func throwUnauthorized(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
}
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
func throwUnauthorizedSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
}
// throwNotFound handles 404 errors
func throwNotFound(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
path string,
) {
msg := fmt.Sprintf("The requested resource was not found: %s", path)
err := errors.New("Resource not found")
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
}
// ErrorDetails contains structured error information for WebSocket error modals
type ErrorDetails struct {
Code int `json:"code"`
Stacktrace string `json:"stacktrace"`
}
// formatErrorDetails extracts and formats error details from wrapped errors
func formatErrorDetails(err error) string {
if err == nil {
return ""
}
// Use %+v format to get stack trace from github.com/pkg/errors
return fmt.Sprintf("%+v", err)
}
// SerializeErrorDetails creates a JSON string with code and stacktrace
// This is exported so it can be used when creating error notifications
func SerializeErrorDetails(code int, err error) string {
details := ErrorDetails{
Code: code,
Stacktrace: formatErrorDetails(err),
}
jsonData, jsonErr := json.Marshal(details)
if jsonErr != nil {
// Fallback if JSON encoding fails
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
}
return string(jsonData)
}
// parseErrorDetails extracts code and stacktrace from JSON Details field
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
func parseErrorDetails(details string) (int, string) {
@@ -155,7 +18,7 @@ func parseErrorDetails(details string) (int, string) {
return 500, ""
}
var errDetails ErrorDetails
var errDetails notify.ErrorDetails
err := json.Unmarshal([]byte(details), &errDetails)
if err != nil {
// Not JSON or malformed - treat as plain stacktrace with default code
@@ -168,6 +31,6 @@ func parseErrorDetails(details string) (int, string) {
func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) {
err := page.Render(r.Context(), w)
if err != nil {
throwInternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
throw.InternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
}
}

View File

@@ -3,6 +3,7 @@ package handlers
import (
"net/http"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/view/page"
"git.haelnorr.com/h/golib/hws"
@@ -14,7 +15,7 @@ func Index(s *hws.Server) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
throwNotFound(s, w, r, r.URL.Path)
throw.NotFound(s, w, r, r.URL.Path)
}
renderSafely(page.Index(), s, r, w)
},

View File

@@ -0,0 +1,49 @@
package handlers
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// IsUnique creates a handler that checks field uniqueness
// Returns 200 OK if unique, 409 Conflict if not unique
func IsUnique(
s *hws.Server,
conn *bun.DB,
model any,
field string,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
getter, err := validation.ParseForm(r)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
value := getter.String(field).TrimSpace().Required().Value
if !getter.Validate() {
w.WriteHeader(http.StatusBadRequest)
return
}
unique := false
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, model, field, value)
if err != nil {
return false, errors.Wrap(err, "db.IsUnique")
}
return true, nil
}); !ok {
return
}
if unique {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusConflict)
}
})
}

View File

@@ -1,45 +0,0 @@
package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/store"
"github.com/uptrace/bun"
)
func IsUsernameUnique(
server *hws.Server,
conn *bun.DB,
cfg *config.Config,
store *store.Store,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("username")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(server, w, r, "Database transaction failed", err)
return
}
defer tx.Rollback()
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
throwInternalServiceError(server, w, r, "Database query failed", err)
return
}
tx.Commit()
if !unique {
w.WriteHeader(http.StatusConflict)
} else {
w.WriteHeader(http.StatusOK)
}
},
)
}

View File

@@ -11,7 +11,9 @@ import (
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
@@ -31,7 +33,7 @@ func Login(
if r.Method == "POST" {
if err != nil {
notifyServiceUnavailable(s, r, "Login currently unavailable", err)
notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
w.WriteHeader(http.StatusOK)
return
}
@@ -40,46 +42,29 @@ func Login(
}
if err != nil {
throwServiceUnavailable(s, w, r, "Login currently unavailable", err)
throw.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
return
}
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
if exceeded {
err := errors.Errorf(
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
err = track.Error(attempts)
st.ClearRedirectTrack(r, "/login")
throwError(
s,
w,
r,
http.StatusBadRequest,
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
err,
"warn",
)
throw.BadRequest(s, w, r, "Too many redirects. Please clear your browser cookies and try again", err)
return
}
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
if err != nil {
throwInternalServiceError(s, w, r, "Failed to generate state token", err)
throw.InternalServiceError(s, w, r, "Failed to generate state token", err)
return
}
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
link, err := discordAPI.GetOAuthLink(state)
if err != nil {
throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
throw.InternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
return
}
st.ClearRedirectTrack(r, "/login")

View File

@@ -3,12 +3,12 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
@@ -21,42 +21,31 @@ func Logout(
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer func() { _ = tx.Rollback() }()
user := db.CurrentUser(r.Context())
if user == nil {
// JIC - should be impossible to get here if route is protected by LoginReq
w.Header().Set("HX-Redirect", "/")
return
}
token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
return
}
if token != nil {
err = discordAPI.RevokeToken(token.Convert())
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil {
throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
}
}
err = auth.Logout(tx, w, r)
if err != nil {
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
return
}
err = tx.Commit()
if err != nil {
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit"))
if token != nil {
err = discordAPI.RevokeToken(token.Convert())
if err != nil {
throw.InternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return false, nil
}
}
err = auth.Logout(tx, w, r)
if err != nil {
throw.InternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
return false, nil
}
return true, nil
}); !ok {
return
}
w.Header().Set("HX-Redirect", "/")

View File

@@ -4,12 +4,13 @@ import (
"context"
"fmt"
"net/http"
"strings"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/validation"
"git.haelnorr.com/h/oslstats/internal/view/page"
"git.haelnorr.com/h/timefmt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
@@ -31,248 +32,62 @@ func NewSeasonSubmit(
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Parse form data
err := r.ParseForm()
if err != nil {
err = notifyWarn(s, r, "Invalid Form", "Please check your input and try again.", nil)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok {
return
}
name := getter.String("name").
TrimSpace().Required().
MaxLength(20).MinLength(5).Value
shortName := getter.String("short_name").
TrimSpace().ToUpper().Required().
MaxLength(6).MinLength(2).Value
format := timefmt.NewBuilder().
DayNumeric2().Slash().
MonthNumeric2().Slash().
Year4().Build()
startDate := getter.Time("start_date", format).Required().Value
if !getter.ValidateAndNotify(s, w, r) {
return
}
// Get form values
name := strings.TrimSpace(r.FormValue("name"))
shortName := strings.TrimSpace(strings.ToUpper(r.FormValue("short_name")))
startDateStr := r.FormValue("start_date")
// Validate required fields
if name == "" || shortName == "" || startDateStr == "" {
err = notifyWarn(s, r, "Missing Fields", "All fields are required.", nil)
nameUnique := false
shortNameUnique := false
var season *db.Season
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
return false, errors.Wrap(err, "db.IsSeasonNameUnique")
}
return
}
// Validate field lengths
if len(name) > 20 {
err = notifyWarn(s, r, "Invalid Name", "Season name must be 20 characters or less.", nil)
shortNameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "short_name", shortName)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
return false, errors.Wrap(err, "db.IsSeasonShortNameUnique")
}
return
}
if len(shortName) > 6 {
err = notifyWarn(s, r, "Invalid Short Name", "Short name must be 6 characters or less.", nil)
if !nameUnique || !shortNameUnique {
return true, nil
}
season, err = db.NewSeason(ctx, tx, name, shortName, startDate)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
// Validate short name is alphanumeric only
if !isAlphanumeric(shortName) {
err = notifyWarn(s, r, "Invalid Short Name", "Short name must contain only letters and numbers.", nil)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
// Parse start date (DD/MM/YYYY format)
startDate, err := time.Parse("02/01/2006", startDateStr)
if err != nil {
err = notifyWarn(s, r, "Invalid Date", "Please provide a valid start date in DD/MM/YYYY format.", nil)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
// Begin database transaction
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
defer tx.Rollback()
// Double-check uniqueness (race condition protection)
nameUnique, err := db.IsSeasonNameUnique(ctx, tx, name)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
return false, errors.Wrap(err, "db.NewSeason")
}
return true, nil
}); !ok {
return
}
if !nameUnique {
err = notifyWarn(s, r, "Duplicate Name", "This season name is already taken.", nil)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
shortNameUnique, err := db.IsSeasonShortNameUnique(ctx, tx, shortName)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
notify.Warn(s, w, r, "Duplicate Name", "This season name is already taken.", nil)
return
}
if !shortNameUnique {
err = notifyWarn(s, r, "Duplicate Short Name", "This short name is already taken.", nil)
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
return
}
// Create the season
season, err := db.NewSeason(ctx, tx, name, shortName, startDate)
if err != nil {
err = notifyInternalServiceError(s, r, "Failed to create season", errors.Wrap(err, "db.NewSeason"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
// Commit transaction
err = tx.Commit()
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "tx.Commit"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
// Send success notification
err = notifySuccess(s, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
if err != nil {
// Log but don't fail the request
s.LogError(hws.HWSError{
StatusCode: http.StatusInternalServerError,
Message: "Failed to send success notification",
Error: err,
})
}
// Redirect to the season detail page
notify.Success(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName))
w.WriteHeader(http.StatusOK)
})
}
// Helper function to validate alphanumeric strings
func isAlphanumeric(s string) bool {
for _, r := range s {
if ((r < 'A') || (r > 'Z')) && ((r < '0') || (r > '9')) {
return false
}
}
return true
}
func IsSeasonNameUnique(
s *hws.Server,
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
defer tx.Rollback()
// Trim whitespace for consistency
name := strings.TrimSpace(r.FormValue("name"))
unique, err := db.IsSeasonNameUnique(ctx, tx, name)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
tx.Commit()
if !unique {
w.WriteHeader(http.StatusConflict)
return
}
w.WriteHeader(http.StatusOK)
})
}
func IsSeasonShortNameUnique(
s *hws.Server,
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
defer tx.Rollback()
// Get short name and convert to uppercase for consistency
shortname := strings.ToUpper(strings.TrimSpace(r.FormValue("short_name")))
unique, err := db.IsSeasonShortNameUnique(ctx, tx, shortname)
if err != nil {
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
}
return
}
tx.Commit()
if !unique {
w.WriteHeader(http.StatusConflict)
return
}
w.WriteHeader(http.StatusOK)
})
}

View File

@@ -1,55 +0,0 @@
package handlers
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/notify"
"github.com/pkg/errors"
)
func notifyClient(
s *hws.Server,
r *http.Request,
level notify.Level,
title, message, details string,
action any,
) error {
subCookie, err := r.Cookie("ws_sub_id")
if err != nil {
return errors.Wrap(err, "r.Cookie")
}
subID := notify.Target(subCookie.Value)
nt := notify.Notification{
Target: subID,
Title: title,
Message: message,
Details: details,
Action: action,
Level: level,
}
s.NotifySub(nt)
return nil
}
func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err error) error {
return notifyClient(s, r, notify.LevelError, "Internal Service Error", msg,
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
}
func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error {
return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg,
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
}
func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error {
return notifyClient(s, r, notify.LevelWarn, title, msg, "", action)
}
func notifyInfo(s *hws.Server, r *http.Request, title, msg string, action any) error {
return notifyClient(s, r, notify.LevelInfo, title, msg, "", action)
}
func notifySuccess(s *hws.Server, r *http.Request, title, msg string, action any) error {
return notifyClient(s, r, notify.LevelSuccess, title, msg, "", action)
}

View File

@@ -10,6 +10,7 @@ import (
"git.haelnorr.com/h/golib/notify"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/view/component/popup"
"github.com/coder/websocket"
@@ -22,7 +23,7 @@ func NotificationWS(
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Upgrade") != "websocket" {
throwNotFound(s, w, r, r.URL.Path)
throw.NotFound(s, w, r, r.URL.Path)
return
}
nc, err := setupClient(s, w, r)

View File

@@ -6,6 +6,7 @@ import (
"github.com/pkg/errors"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/view/page"
)
@@ -16,12 +17,6 @@ func NotifyTester(s *hws.Server) http.Handler {
func(w http.ResponseWriter, r *http.Request) {
testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks")
if r.Method == "GET" {
// page, _ := ErrorPage(hws.HWSError{
// StatusCode: http.StatusTeapot,
// Message: "This error has been rendered as a test",
// Error: testErr,
// })
// page.Render(r.Context(), w)
renderSafely(page.Test(), s, r, w)
} else {
_ = r.ParseForm()
@@ -30,19 +25,15 @@ func NotifyTester(s *hws.Server) http.Handler {
level := r.Form.Get("type")
message := r.Form.Get("message")
var err error
switch level {
case "success":
err = notifySuccess(s, r, title, message, nil)
notify.Success(s, w, r, title, message, nil)
case "info":
err = notifyInfo(s, r, title, message, nil)
notify.Info(s, w, r, title, message, nil)
case "warn":
err = notifyWarn(s, r, title, message, nil)
notify.Warn(s, w, r, title, message, nil)
case "error":
err = notifyInternalServiceError(s, r, message, testErr)
}
if err != nil {
throwInternalServiceError(s, w, r, "Error notifying client", err)
notify.InternalServiceError(s, w, r, message, testErr)
}
}
},

View File

@@ -1,68 +0,0 @@
package handlers
import (
"net/http"
"strconv"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) {
var pageNum, perPage int
var order bun.Order
var orderBy string
var err error
if pageStr := r.FormValue("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid page number")
}
}
if perPageStr := r.FormValue("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid per_page number")
}
}
order = bun.Order(r.FormValue("order"))
orderBy = r.FormValue("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
}
return pageOpts, nil
}
func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) {
var pageNum, perPage int
var order bun.Order
var orderBy string
var err error
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
pageNum, err = strconv.Atoi(pageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid page number")
}
}
if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" {
perPage, err = strconv.Atoi(perPageStr)
if err != nil {
return nil, errors.Wrap(err, "invalid per_page number")
}
}
order = bun.Order(r.URL.Query().Get("order"))
orderBy = r.URL.Query().Get("order_by")
pageOpts := &db.PageOpts{
Page: pageNum,
PerPage: perPage,
Order: order,
OrderBy: orderBy,
}
return pageOpts, nil
}

View File

@@ -0,0 +1,45 @@
package handlers
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/validation"
"github.com/uptrace/bun"
)
// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata.
// It renders a Bad Request error page on fail
// PageOpts will be nil on fail
func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
getter, ok := validation.ParseFormOrError(s, w, r)
if !ok {
return nil
}
return getPageOpts(s, w, r, getter)
}
// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail
// PageOpts will be nil on fail
func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
return getPageOpts(s, w, r, validation.NewQueryGetter(r))
}
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts {
page := g.Int("page").Min(1).Value
perPage := g.Int("per_page").Min(1).Max(100).Value
order := g.String("order").TrimSpace().ToUpper().AllowedValues([]string{"ASC", "DESC"}).Value
orderBy := g.String("order_by").TrimSpace().ToLower().Value
valid := g.ValidateAndError(s, w, r)
if !valid {
return nil
}
pageOpts := &db.PageOpts{
Page: page,
PerPage: perPage,
Order: bun.Order(order),
OrderBy: orderBy,
}
return pageOpts
}

View File

@@ -7,6 +7,7 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/roles"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/uptrace/bun"
)
@@ -17,7 +18,7 @@ func PermTester(s *hws.Server, conn *bun.DB) http.Handler {
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
tx.Rollback()
if err != nil {
throwInternalServiceError(s, w, r, "Error", err)
throw.InternalServiceError(s, w, r, "Error", err)
}
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
})

View File

@@ -3,7 +3,6 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
@@ -14,6 +13,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/view/page"
)
@@ -29,27 +29,9 @@ func Register(
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
if exceeded {
err := errors.Errorf(
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
attempts,
track.IP,
track.UserAgent,
track.Path,
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
cfg.HWSAuth.SSL,
)
err := track.Error(attempts)
store.ClearRedirectTrack(r, "/register")
throwError(
s,
w,
r,
http.StatusBadRequest,
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
err,
"warn",
)
throw.BadRequest(s, w, r, "Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again", err)
return
}
@@ -65,68 +47,45 @@ func Register(
}
store.ClearRedirectTrack(r, "/register")
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database transaction failed", err)
return
}
defer tx.Rollback()
method := r.Method
if method == "GET" {
tx.Commit()
if r.Method == "GET" {
renderSafely(page.Register(details.DiscordUser.Username), s, r, w)
return
}
if method == "POST" {
username := r.FormValue("username")
user, err := registerUser(ctx, tx, username, details)
username := r.FormValue("username")
unique := false
var user *db.User
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
if err != nil {
err = notifyInternalServiceError(s, r, "Registration failed", err)
if err != nil {
throwInternalServiceError(s, w, r, "Registration failed", err)
}
return false, errors.Wrap(err, "db.IsUsernameUnique")
}
if !unique {
return true, nil
}
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser)
if err != nil {
return false, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return false, errors.Wrap(err, "db.UpdateDiscordToken")
}
return true, nil
}); !ok {
return
}
if !unique {
w.WriteHeader(http.StatusConflict)
} else {
err = auth.Login(w, r, user, true)
if err != nil {
throw.InternalServiceError(s, w, r, "Login failed", err)
return
}
tx.Commit()
if user == nil {
w.WriteHeader(http.StatusConflict)
} else {
err = auth.Login(w, r, user, true)
if err != nil {
throwInternalServiceError(s, w, r, "Login failed", err)
return
}
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
}
return
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
}
},
)
}
func registerUser(
ctx context.Context,
tx bun.Tx,
username string,
details *store.RegistrationSession,
) (*db.User, error) {
unique, err := db.IsUsernameUnique(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "db.IsUsernameUnique")
}
if !unique {
return nil, nil
}
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
if err != nil {
return nil, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
}
return user, nil
}

View File

@@ -3,10 +3,10 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/view/page"
"github.com/pkg/errors"
"github.com/uptrace/bun"
@@ -17,23 +17,20 @@ func SeasonPage(
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer tx.Rollback()
seasonStr := r.PathValue("season_short_name")
season, err := db.GetSeason(ctx, tx, seasonStr)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetSeason"))
var season *db.Season
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
season, err = db.GetSeason(ctx, tx, seasonStr)
if err != nil {
return false, errors.Wrap(err, "db.GetSeason")
}
return true, nil
}); !ok {
return
}
tx.Commit()
if season == nil {
throwNotFound(s, w, r, r.URL.Path)
throw.NotFound(s, w, r, r.URL.Path)
return
}
renderSafely(page.SeasonPage(season), s, r, w)

View File

@@ -3,7 +3,6 @@ package handlers
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/db"
@@ -17,25 +16,21 @@ func SeasonsPage(
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
pageOpts := pageOptsFromQuery(s, w, r)
if pageOpts == nil {
return
}
defer tx.Rollback()
pageOpts, err := pageOptsFromQuery(r)
if err != nil {
throwBadRequest(s, w, r, "invalid query", err)
var seasons *db.SeasonList
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.ListSeasons")
}
return true, nil
}); !ok {
return
}
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
return
}
tx.Commit()
renderSafely(page.SeasonsPage(seasons), s, r, w)
})
}
@@ -45,37 +40,21 @@ func SeasonsList(
conn *bun.DB,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Parse form values
if err := r.ParseForm(); err != nil {
throwBadRequest(s, w, r, "Invalid form data", err)
pageOpts := pageOptsFromForm(s, w, r)
if pageOpts == nil {
return
}
pageOpts, err := pageOptsFromForm(r)
if err != nil {
throwBadRequest(s, w, r, "invalid form data", err)
var seasons *db.SeasonList
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
return false, errors.Wrap(err, "db.ListSeasons")
}
return true, nil
}); !ok {
return
}
// Database query
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return
}
defer tx.Rollback()
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
if err != nil {
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
return
}
tx.Commit()
// Return only the list component (hx-push-url handles URL update client-side)
renderSafely(page.SeasonsList(seasons), s, r, w)
})
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/throw"
)
// StaticFS handles requests for static files, without allowing access to the
@@ -16,7 +17,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
if err != nil {
// If we can't create the file server, return a handler that always errors
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
throw.InternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
})
}

63
internal/notify/notify.go Normal file
View File

@@ -0,0 +1,63 @@
// Package notify provides utility functions for sending notifications to clients
package notify
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"github.com/pkg/errors"
)
func notifyClient(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
level notify.Level,
title, message, details string,
action any,
) {
subCookie, err := r.Cookie("ws_sub_id")
if err != nil {
throw.InternalServiceError(s, w, r, "Notification failed. Do you have cookies enabled?", errors.Wrap(err, "r.Cookie"))
return
}
subID := notify.Target(subCookie.Value)
nt := notify.Notification{
Target: subID,
Title: title,
Message: message,
Details: details,
Action: action,
Level: level,
}
s.NotifySub(nt)
}
// InternalServiceError notifies with error level
func InternalServiceError(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
notifyClient(s, w, r, notify.LevelError, "Internal Service Error", msg,
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
}
// ServiceUnavailable notifies with error level
func ServiceUnavailable(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
notifyClient(s, w, r, notify.LevelError, "Service Unavailable", msg,
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
}
// Warn notifies with warn level
func Warn(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
notifyClient(s, w, r, notify.LevelWarn, title, msg, "", action)
}
// Info notifies with info level
func Info(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
notifyClient(s, w, r, notify.LevelInfo, title, msg, "", action)
}
// Success notifies with success level
func Success(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
notifyClient(s, w, r, notify.LevelSuccess, title, msg, "", action)
}

53
internal/notify/util.go Normal file
View File

@@ -0,0 +1,53 @@
package notify
import (
"encoding/json"
"fmt"
)
// ErrorDetails contains structured error information for WebSocket error modals
type ErrorDetails struct {
Code int `json:"code"`
Stacktrace string `json:"stacktrace"`
}
// SerializeErrorDetails creates a JSON string with code and stacktrace
// This is exported so it can be used when creating error notifications
func SerializeErrorDetails(code int, err error) string {
details := ErrorDetails{
Code: code,
Stacktrace: FormatErrorDetails(err),
}
jsonData, jsonErr := json.Marshal(details)
if jsonErr != nil {
// Fallback if JSON encoding fails
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
}
return string(jsonData)
}
// ParseErrorDetails extracts code and stacktrace from JSON Details field
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
func ParseErrorDetails(details string) (int, string) {
if details == "" {
return 500, ""
}
var errDetails ErrorDetails
err := json.Unmarshal([]byte(details), &errDetails)
if err != nil {
// Not JSON or malformed - treat as plain stacktrace with default code
return 500, details
}
return errDetails.Code, errDetails.Stacktrace
}
// FormatErrorDetails extracts and formats error details from wrapped errors
func FormatErrorDetails(err error) string {
if err == nil {
return ""
}
// Use %+v format to get stack trace from github.com/pkg/errors
return fmt.Sprintf("%+v", err)
}

View File

@@ -3,7 +3,6 @@ package rbac
import (
"context"
"net/http"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/contexts"
@@ -11,6 +10,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// LoadPermissionsMiddleware loads user permissions into context after authentication
@@ -26,46 +26,28 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
return
}
// Start transaction for loading permissions
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
tx, err := c.conn.BeginTx(ctx, nil)
if err != nil {
// Log but don't block - permission checks will fail gracefully
var roles_ []*db.Role
var perms []*db.Permission
if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error {
var err error
roles_, err = user.GetRoles(ctx, tx)
if err != nil {
return errors.Wrap(err, "user.GetRoles")
}
perms, err = user.GetPermissions(ctx, tx)
if err != nil {
return errors.Wrap(err, "user.GetPermissions")
}
return nil
}); err != nil {
c.s.LogError(hws.HWSError{
Message: "Failed to start database transaction",
Error: errors.Wrap(err, "c.conn.BeginTx"),
Message: "Database error",
Error: err,
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
defer func() { _ = tx.Rollback() }()
// Load user's roles_ and permissions
roles_, err := user.GetRoles(ctx, tx)
if err != nil {
c.s.LogError(hws.HWSError{
Message: "Failed to get user roles",
Error: errors.Wrap(err, "user.GetRoles"),
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
perms, err := user.GetPermissions(ctx, tx)
if err != nil {
c.s.LogError(hws.HWSError{
Message: "Failed to get user permissions",
Error: errors.Wrap(err, "user.GetPermissions"),
Level: hws.ErrorERROR,
})
next.ServeHTTP(w, r)
return
}
_ = tx.Commit() // read only transaction
// Build permission cache
cache := &contexts.PermissionCache{
@@ -88,7 +70,7 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
}
// Add cache to context (type-safe)
ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache)
ctx := context.WithValue(r.Context(), contexts.PermissionCacheKey, cache)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -5,6 +5,8 @@ import (
"net/http"
"strings"
"time"
"github.com/pkg/errors"
)
// getClientIP extracts the client IP address, checking X-Forwarded-For first
@@ -93,3 +95,14 @@ func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
key := redirectKey(ip, userAgent, path)
s.redirectTracks.Delete(key)
}
func (t *RedirectTrack) Error(attempts int) error {
return errors.Errorf(
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
attempts,
t.IP,
t.UserAgent,
t.Path,
t.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
)
}

118
internal/throw/throw.go Normal file
View File

@@ -0,0 +1,118 @@
// Package throw provides utility functions for throwing HTTP errors that render an error page
package throw
import (
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
)
// throwError is a generic helper that all throw* functions use internally
func throwError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
statusCode int,
msg string,
err error,
level hws.ErrorLevel,
) {
s.ThrowError(w, r, hws.HWSError{
StatusCode: statusCode,
Message: msg,
Error: err,
Level: level,
RenderErrorPage: true, // throw* family always renders error pages
})
}
// InternalServiceError handles 500 errors (server failures)
func InternalServiceError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
}
// ServiceUnavailable handles 503 errors
func ServiceUnavailable(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
}
// BadRequest handles 400 errors (malformed requests)
func BadRequest(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
}
// Forbidden handles 403 errors (normal permission denials)
func Forbidden(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
}
// ForbiddenSecurity handles 403 errors for security events (uses WARN level)
func ForbiddenSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
}
// Unauthorized handles 401 errors (not authenticated)
func Unauthorized(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
}
// UnauthorizedSecurity handles 401 errors for security events (uses WARN level)
func UnauthorizedSecurity(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
msg string,
err error,
) {
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
}
// NotFound handles 404 errors
func NotFound(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
path string,
) {
msg := fmt.Sprintf("The requested resource was not found: %s", path)
err := errors.New("Resource not found")
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
}

View File

@@ -0,0 +1,101 @@
package validation
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/timefmt"
"github.com/pkg/errors"
)
// FormGetter wraps http.Request to get form values
type FormGetter struct {
r *http.Request
checks []*ValidationRule
}
func NewFormGetter(r *http.Request) *FormGetter {
return &FormGetter{r: r, checks: []*ValidationRule{}}
}
func (f *FormGetter) Get(key string) string {
return f.r.FormValue(key)
}
func (f *FormGetter) getChecks() []*ValidationRule {
return f.checks
}
func (f *FormGetter) AddCheck(check *ValidationRule) {
f.checks = append(f.checks, check)
}
func (f *FormGetter) ValidateChecks() []*ValidationRule {
return validate(f)
}
func (f *FormGetter) String(key string) *StringField {
return newStringField(key, f)
}
func (f *FormGetter) Int(key string) *IntField {
return newIntField(key, f)
}
func (f *FormGetter) Time(key string, format *timefmt.Format) *TimeField {
return newTimeField(key, format, f)
}
func ParseForm(r *http.Request) (*FormGetter, error) {
err := r.ParseForm()
if err != nil {
return nil, errors.Wrap(err, "r.ParseForm")
}
return NewFormGetter(r), nil
}
// ParseFormOrNotify attempts to parse the form data and notifies the user on fail
func ParseFormOrNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
getter, err := ParseForm(r)
if err != nil {
notify.Warn(s, w, r, "Invalid Form", "Please check your input and try again.", nil)
return nil, false
}
return getter, true
}
// ParseFormOrError attempts to parse the form data and renders an error page on fail
func ParseFormOrError(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
getter, err := ParseForm(r)
if err != nil {
throw.BadRequest(s, w, r, "Invalid form data", err)
return nil, false
}
return getter, true
}
func (f *FormGetter) Validate() bool {
return len(validate(f)) == 0
}
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
// Returns true if all checks passed
func (f *FormGetter) ValidateAndNotify(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
) bool {
return validateAndNotify(s, w, r, f)
}
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
// Returns true if all checks passed
func (f *FormGetter) ValidateAndError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
) bool {
return validateAndError(s, w, r, f)
}

View File

@@ -0,0 +1,71 @@
package validation
import (
"fmt"
"strconv"
)
type IntField struct {
Field
Value int
}
func newIntField(key string, g Getter) *IntField {
raw := g.Get(key)
var val int
if raw != "" {
var err error
val, err = strconv.Atoi(raw)
if err != nil {
g.AddCheck(newFailedCheck(
"Value is not a number",
fmt.Sprintf("%s must be an integer: %s provided", key, raw),
))
}
}
return &IntField{
Value: val,
Field: newField(key, g),
}
}
// Required enforces a non-zero value
func (i *IntField) Required() *IntField {
if i.Value == 0 {
i.getter.AddCheck(newFailedCheck(
"Value cannot be 0",
fmt.Sprintf("%s is required", i.Key),
))
}
return i
}
// Optional will skip all validations if value is empty
func (i *IntField) Optional() *IntField {
if i.Value == 0 {
i.optional = true
}
return i
}
// Max enforces a maxmium value
func (i *IntField) Max(max int) *IntField {
if i.Value > max && !i.optional {
i.getter.AddCheck(newFailedCheck(
"Value too large",
fmt.Sprintf("%s is too large, max %v", i.Key, max),
))
}
return i
}
// Min enforces a minimum value
func (i *IntField) Min(min int) *IntField {
if i.Value < min && !i.optional {
i.getter.AddCheck(newFailedCheck(
"Value too small",
fmt.Sprintf("%s is too small, min %v", i.Key, min),
))
}
return i
}

View File

@@ -0,0 +1,70 @@
package validation
import (
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/timefmt"
)
// QueryGetter wraps http.Request to get query values
type QueryGetter struct {
r *http.Request
checks []*ValidationRule
}
func NewQueryGetter(r *http.Request) *QueryGetter {
return &QueryGetter{r: r, checks: []*ValidationRule{}}
}
func (q *QueryGetter) Get(key string) string {
return q.r.URL.Query().Get(key)
}
func (q *QueryGetter) getChecks() []*ValidationRule {
return q.checks
}
func (q *QueryGetter) AddCheck(check *ValidationRule) {
q.checks = append(q.checks, check)
}
func (q *QueryGetter) ValidateChecks() []*ValidationRule {
return validate(q)
}
func (q *QueryGetter) String(key string) *StringField {
return newStringField(key, q)
}
func (q *QueryGetter) Int(key string) *IntField {
return newIntField(key, q)
}
func (q *QueryGetter) Time(key string, format *timefmt.Format) *TimeField {
return newTimeField(key, format, q)
}
func (q *QueryGetter) Validate() bool {
return len(validate(q)) == 0
}
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
// Returns true if all checks passed
func (q *QueryGetter) ValidateAndNotify(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
) bool {
return validateAndNotify(s, w, r, q)
}
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
// Returns true if all checks passed
func (q *QueryGetter) ValidateAndError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
) bool {
return validateAndError(s, w, r, q)
}

View File

@@ -0,0 +1,106 @@
package validation
import (
"fmt"
"slices"
"strings"
"unicode"
)
type StringField struct {
Field
Value string
}
func newStringField(key string, g Getter) *StringField {
return &StringField{
Value: g.Get(key),
Field: newField(key, g),
}
}
// Required enforces a non empty string
func (s *StringField) Required() *StringField {
if s.Value == "" {
s.getter.AddCheck(newFailedCheck(
"Field not provided",
fmt.Sprintf("%s is required", s.Key),
))
}
return s
}
// Optional will skip all validations if value is empty
func (s *StringField) Optional() *StringField {
if s.Value == "" {
s.optional = true
}
return s
}
// MaxLength enforces a maximum string length
func (s *StringField) MaxLength(length int) *StringField {
if len(s.Value) > length && !s.optional {
s.getter.AddCheck(newFailedCheck(
"Input too long",
fmt.Sprintf("%s is too long, max %v chars", s.Key, length),
))
}
return s
}
// MinLength enforces a minimum string length
func (s *StringField) MinLength(length int) *StringField {
if len(s.Value) < length && !s.optional {
s.getter.AddCheck(newFailedCheck(
"Input too short",
fmt.Sprintf("%s is too short, min %v chars", s.Key, length),
))
}
return s
}
// AlphaNumeric enforces the string contains only letters and numbers
func (s *StringField) AlphaNumeric() *StringField {
if s.optional {
return s
}
for _, r := range s.Value {
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
s.getter.AddCheck(newFailedCheck(
"Invalid characters",
fmt.Sprintf("%s must contain only letters and numbers.", s.Key),
))
return s
}
}
return s
}
func (s *StringField) AllowedValues(allowed []string) *StringField {
if !slices.Contains(allowed, s.Value) && !s.optional {
s.getter.AddCheck(newFailedCheck(
"Value not allowed",
fmt.Sprintf("%s must be one of: %s", s.Key, strings.Join(allowed, ",")),
))
}
return s
}
// ToUpper transforms the string to uppercase
func (s *StringField) ToUpper() *StringField {
s.Value = strings.ToUpper(s.Value)
return s
}
// ToLower transforms the string to lowercase
func (s *StringField) ToLower() *StringField {
s.Value = strings.ToLower(s.Value)
return s
}
// TrimSpace removes leading and trailing whitespace
func (s *StringField) TrimSpace() *StringField {
s.Value = strings.TrimSpace(s.Value)
return s
}

View File

@@ -0,0 +1,50 @@
package validation
import (
"fmt"
"time"
"git.haelnorr.com/h/timefmt"
)
type TimeField struct {
Field
Value time.Time
}
func newTimeField(key string, format *timefmt.Format, g Getter) *TimeField {
raw := g.Get(key)
var startDate time.Time
if raw != "" {
var err error
startDate, err = format.Parse(raw)
if err != nil {
g.AddCheck(newFailedCheck(
"Invalid date/time format",
fmt.Sprintf("%s should be in format %s", key, format.LDML()),
))
}
}
return &TimeField{
Value: startDate,
Field: newField(key, g),
}
}
func (t *TimeField) Required() *TimeField {
if t.Value.IsZero() {
t.getter.AddCheck(newFailedCheck(
"Date/Time not provided",
fmt.Sprintf("%s must be provided", t.Key),
))
}
return t
}
// Optional will skip all validations if value is empty
func (t *TimeField) Optional() *TimeField {
if t.Value.IsZero() {
t.optional = true
}
return t
}

View File

@@ -0,0 +1,93 @@
// Package validation provides utilities for parsing and validating request data
package validation
import (
"errors"
"fmt"
"net/http"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/oslstats/internal/notify"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/timefmt"
)
type ValidationRule struct {
Condition bool // Condition on which to fail validation
Title string // Title for warning message
Message string // Warning message
}
// Getter abstracts getting values from either form or query
type Getter interface {
Get(key string) string
AddCheck(check *ValidationRule)
String(key string) *StringField
Int(key string) *IntField
Time(key string, format *timefmt.Format) *TimeField
ValidateChecks() []*ValidationRule
Validate() bool
ValidateAndNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
ValidateAndError(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
getChecks() []*ValidationRule
}
type Field struct {
Key string
optional bool
getter Getter
}
func newField(key string, g Getter) Field {
return Field{
Key: key,
getter: g,
}
}
func newFailedCheck(title, message string) *ValidationRule {
return &ValidationRule{
Condition: true,
Title: title,
Message: message,
}
}
func validate(g Getter) []*ValidationRule {
failed := []*ValidationRule{}
for _, check := range g.getChecks() {
if check != nil && check.Condition {
failed = append(failed, check)
}
}
return failed
}
func validateAndNotify(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
g Getter,
) bool {
failedChecks := g.ValidateChecks()
for _, check := range failedChecks {
notify.Warn(s, w, r, check.Title, check.Message, nil)
}
return len(failedChecks) == 0
}
func validateAndError(
s *hws.Server,
w http.ResponseWriter,
r *http.Request,
g Getter,
) bool {
failedChecks := g.ValidateChecks()
var err error
for _, check := range failedChecks {
err_ := fmt.Errorf("%s: %s", check.Title, check.Message)
err = errors.Join(err, err_)
}
throw.BadRequest(s, w, r, "Invalid form data", err)
return len(failedChecks) == 0
}