league #1
@@ -101,17 +101,17 @@ func addRoutes(
|
|||||||
{
|
{
|
||||||
Path: "/htmx/isusernameunique",
|
Path: "/htmx/isusernameunique",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: handlers.IsUsernameUnique(s, conn, cfg, store),
|
Handler: handlers.IsUnique(s, conn, (*db.User)(nil), "username"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/htmx/isseasonnameunique",
|
Path: "/htmx/isseasonnameunique",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: handlers.IsSeasonNameUnique(s, conn),
|
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "name"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Path: "/htmx/isseasonshortnameunique",
|
Path: "/htmx/isseasonshortnameunique",
|
||||||
Method: hws.MethodPOST,
|
Method: hws.MethodPOST,
|
||||||
Handler: handlers.IsSeasonShortNameUnique(s, conn),
|
Handler: handlers.IsUnique(s, conn, (*db.Season)(nil), "short_name"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module git.haelnorr.com/h/oslstats
|
module git.haelnorr.com/h/oslstats
|
||||||
|
|
||||||
go 1.25.5
|
go 1.25.6
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.haelnorr.com/h/golib/env v0.9.1
|
git.haelnorr.com/h/golib/env v0.9.1
|
||||||
@@ -27,6 +27,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
git.haelnorr.com/h/golib/cookies v0.9.0
|
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
|
git.haelnorr.com/h/golib/jwt v0.10.1 // indirect
|
||||||
|
git.haelnorr.com/h/timefmt v0.1.0
|
||||||
github.com/bwmarrin/discordgo v0.29.0
|
github.com/bwmarrin/discordgo v0.29.0
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -16,6 +16,8 @@ git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7Kv
|
|||||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||||
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
git.haelnorr.com/h/golib/notify v0.1.0/go.mod h1:ARqaRmCYb8LMURhDM75sG+qX+YpqXmUVeAtacwjHjBc=
|
||||||
|
git.haelnorr.com/h/timefmt v0.1.0 h1:ULDkWEtFIV+FkkoV0q9n62Spj+HDdtFL9QeAdGIEp+o=
|
||||||
|
git.haelnorr.com/h/timefmt v0.1.0/go.mod h1:12gXXYLP4w9Fa9ZkbZWdvKV6RyZEzwAm9mN+WB3oXpw=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
||||||
|
|||||||
@@ -105,35 +105,33 @@ func (l *Logger) log(
|
|||||||
|
|
||||||
// GetRecentLogs retrieves recent audit logs with pagination
|
// GetRecentLogs retrieves recent audit logs with pagination
|
||||||
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
func (l *Logger) GetRecentLogs(ctx context.Context, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||||
tx, err := l.conn.BeginTx(ctx, nil)
|
var logs *db.AuditLogs
|
||||||
if err != nil {
|
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||||
return nil, errors.Wrap(err, "conn.BeginTx")
|
var err error
|
||||||
|
logs, err = db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.GetAuditLogs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
logs, err := db.GetAuditLogs(ctx, tx, pageOpts, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tx.Commit() // read only transaction
|
|
||||||
return logs, nil
|
return logs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogsByUser retrieves audit logs for a specific user
|
// GetLogsByUser retrieves audit logs for a specific user
|
||||||
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
func (l *Logger) GetLogsByUser(ctx context.Context, userID int, pageOpts *db.PageOpts) (*db.AuditLogs, error) {
|
||||||
tx, err := l.conn.BeginTx(ctx, nil)
|
var logs *db.AuditLogs
|
||||||
if err != nil {
|
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||||
return nil, errors.Wrap(err, "conn.BeginTx")
|
var err error
|
||||||
|
logs, err = db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.GetAuditLogsByUser")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "db.WithTxFailSilently")
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
logs, err := db.GetAuditLogsByUser(ctx, tx, userID, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tx.Commit() // read only transaction
|
|
||||||
return logs, nil
|
return logs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,20 +143,16 @@ func (l *Logger) CleanupOldLogs(ctx context.Context, daysToKeep int) (int, error
|
|||||||
|
|
||||||
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
|
cutoffTime := time.Now().AddDate(0, 0, -daysToKeep).Unix()
|
||||||
|
|
||||||
tx, err := l.conn.BeginTx(ctx, nil)
|
var count int
|
||||||
if err != nil {
|
if err := db.WithTxFailSilently(ctx, l.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||||
return 0, errors.Wrap(err, "conn.BeginTx")
|
var err error
|
||||||
}
|
count, err = db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
|
||||||
defer func() { _ = tx.Rollback() }()
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.CleanupOldAuditLogs")
|
||||||
count, err := db.CleanupOldAuditLogs(ctx, tx, cutoffTime)
|
}
|
||||||
if err != nil {
|
return nil
|
||||||
return 0, err
|
}); err != nil {
|
||||||
}
|
return 0, errors.Wrap(err, "db.WithTxFailSilently")
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return 0, errors.Wrap(err, "tx.Commit")
|
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|||||||
19
internal/db/isunique.go
Normal file
19
internal/db/isunique.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsUnique(ctx context.Context, tx bun.Tx, model any, field, value string) (bool, error) {
|
||||||
|
count, err := tx.NewSelect().
|
||||||
|
Model(model).
|
||||||
|
Where("? = ?", bun.Ident(field), value).
|
||||||
|
Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tx.NewSelect")
|
||||||
|
}
|
||||||
|
return count == 0, nil
|
||||||
|
}
|
||||||
@@ -87,25 +87,3 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
|
|||||||
}
|
}
|
||||||
return season, nil
|
return season, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsSeasonNameUnique(ctx context.Context, tx bun.Tx, name string) (bool, error) {
|
|
||||||
count, err := tx.NewSelect().
|
|
||||||
Model((*Season)(nil)).
|
|
||||||
Where("name = ?", name).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
return count == 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsSeasonShortNameUnique(ctx context.Context, tx bun.Tx, shortname string) (bool, error) {
|
|
||||||
count, err := tx.NewSelect().
|
|
||||||
Model((*Season)(nil)).
|
|
||||||
Where("short_name = ?", shortname).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
return count == 0, nil
|
|
||||||
}
|
|
||||||
|
|||||||
118
internal/db/txhelpers.go
Normal file
118
internal/db/txhelpers.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TxFunc is a function that runs within a database transaction
|
||||||
|
type (
|
||||||
|
TxFunc func(ctx context.Context, tx bun.Tx) (bool, error)
|
||||||
|
TxFuncSilent func(ctx context.Context, tx bun.Tx) error
|
||||||
|
)
|
||||||
|
|
||||||
|
var timeout = 15 * time.Second
|
||||||
|
|
||||||
|
// WithReadTx executes a read-only transaction with automatic rollback
|
||||||
|
// Returns true if successful, false if error was thrown to client
|
||||||
|
func WithReadTx(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
conn *bun.DB,
|
||||||
|
fn TxFunc,
|
||||||
|
) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
ok, err := withTx(ctx, conn, fn, false)
|
||||||
|
if err != nil {
|
||||||
|
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTxFailSilently executes a transaction with automatic rollback
|
||||||
|
// Returns true if successful, false if error occured.
|
||||||
|
// Does not throw any errors to the client.
|
||||||
|
func WithTxFailSilently(
|
||||||
|
ctx context.Context,
|
||||||
|
conn *bun.DB,
|
||||||
|
fn TxFuncSilent,
|
||||||
|
) error {
|
||||||
|
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
err := fn(ctx, tx)
|
||||||
|
return err != nil, err
|
||||||
|
}
|
||||||
|
_, err := withTx(ctx, conn, fnc, true)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWriteTx executes a write transaction with automatic rollback on error
|
||||||
|
// Commits only if fn returns nil. Returns true if successful.
|
||||||
|
func WithWriteTx(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
conn *bun.DB,
|
||||||
|
fn TxFunc,
|
||||||
|
) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
ok, err := withTx(ctx, conn, fn, true)
|
||||||
|
if err != nil {
|
||||||
|
throw.InternalServiceError(s, w, r, "Database error", err)
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithNotifyTx executes a transaction with notification-based error handling
|
||||||
|
// Uses notifyInternalServiceError instead of throwInternalServiceError
|
||||||
|
func WithNotifyTx(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
conn *bun.DB,
|
||||||
|
fn TxFunc,
|
||||||
|
) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
ok, err := withTx(ctx, conn, fn, true)
|
||||||
|
if err != nil {
|
||||||
|
notify.InternalServiceError(s, w, r, "Database error", err)
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTx executes a transaction with automatic rollback on error
|
||||||
|
func withTx(
|
||||||
|
ctx context.Context,
|
||||||
|
conn *bun.DB,
|
||||||
|
fn TxFunc,
|
||||||
|
write bool,
|
||||||
|
) (bool, error) {
|
||||||
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "conn.BeginTx")
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
ok, err := fn(ctx, tx)
|
||||||
|
if err != nil || !ok {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if write {
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tx.Commit")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_ = tx.Commit()
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -112,19 +112,6 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUsernameUnique checks if the given username is unique (not already taken)
|
|
||||||
// Returns true if the username is available, false if it's taken
|
|
||||||
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
|
|
||||||
count, err := tx.NewSelect().
|
|
||||||
Model((*User)(nil)).
|
|
||||||
Where("username = ?", username).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
return count == 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRoles loads all the roles for this user
|
// GetRoles loads all the roles for this user
|
||||||
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
func (u *User) GetRoles(ctx context.Context, tx bun.Tx) ([]*Role, error) {
|
||||||
if u == nil {
|
if u == nil {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
@@ -14,23 +13,17 @@ import (
|
|||||||
|
|
||||||
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminDashboard(s *hws.Server, conn *bun.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
var users *db.Users
|
||||||
defer cancel()
|
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
var err error
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
users, err = db.GetUsers(ctx, tx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
return false, errors.Wrap(err, "db.GetUsers")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
users, err := db.GetUsers(ctx, tx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetUsers"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = tx.Commit()
|
|
||||||
|
|
||||||
renderSafely(page.AdminDashboard(users), s, r, w)
|
renderSafely(page.AdminDashboard(users), s, r, w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
@@ -15,30 +14,21 @@ import (
|
|||||||
// AdminUsersList shows all users
|
// AdminUsersList shows all users
|
||||||
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
|
func AdminUsersList(s *hws.Server, conn *bun.DB) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
var users *db.Users
|
||||||
defer cancel()
|
pageOpts := pageOptsFromForm(s, w, r)
|
||||||
|
if pageOpts == nil {
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "DB Transaction failed", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
var err error
|
||||||
// Get all users
|
users, err = db.GetUsers(ctx, tx, pageOpts)
|
||||||
pageOpts, err := pageOptsFromForm(r)
|
if err != nil {
|
||||||
if err != nil {
|
return false, errors.Wrap(err, "db.GetUsers")
|
||||||
throwBadRequest(s, w, r, "invalid form data", err)
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
users, err := db.GetUsers(ctx, tx, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Failed to load users", errors.Wrap(err, "db.GetUsers"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tx.Commit()
|
|
||||||
|
|
||||||
renderSafely(admin.UserList(users), s, r, w)
|
renderSafely(admin.UserList(users), s, r, w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/cookies"
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
@@ -15,11 +14,12 @@ import (
|
|||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Callback(
|
func Callback(
|
||||||
server *hws.Server,
|
s *hws.Server,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
@@ -31,26 +31,9 @@ func Callback(
|
|||||||
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
||||||
|
|
||||||
if exceeded {
|
if exceeded {
|
||||||
err := errors.Errorf(
|
err := track.Error(attempts)
|
||||||
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
|
||||||
attempts,
|
|
||||||
track.IP,
|
|
||||||
track.UserAgent,
|
|
||||||
track.Path,
|
|
||||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
|
||||||
)
|
|
||||||
|
|
||||||
store.ClearRedirectTrack(r, "/callback")
|
store.ClearRedirectTrack(r, "/callback")
|
||||||
|
throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err)
|
||||||
throwError(
|
|
||||||
server,
|
|
||||||
w,
|
|
||||||
r,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
|
|
||||||
err,
|
|
||||||
"warn",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,12 +47,12 @@ func Callback(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if vsErr, ok := err.(*verifyStateError); ok {
|
if vsErr, ok := err.(*verifyStateError); ok {
|
||||||
if vsErr.IsCookieError() {
|
if vsErr.IsCookieError() {
|
||||||
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
throw.Unauthorized(s, w, r, "OAuth session not found or expired", err)
|
||||||
} else {
|
} else {
|
||||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -77,20 +60,17 @@ func Callback(
|
|||||||
|
|
||||||
switch data {
|
switch data {
|
||||||
case "login":
|
case "login":
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
var redirect func()
|
||||||
defer cancel()
|
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
redirect, err = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
|
throw.InternalServiceError(s, w, r, "OAuth login failed", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
|
||||||
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
redirect()
|
redirect()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
|
|||||||
// Get technical details if applicable
|
// Get technical details if applicable
|
||||||
var details string
|
var details string
|
||||||
if showDetails && hwsError.Error != nil {
|
if showDetails && hwsError.Error != nil {
|
||||||
details = formatErrorDetails(hwsError.Error)
|
details = notify.FormatErrorDetails(hwsError.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Render appropriate template
|
// Render appropriate template
|
||||||
|
|||||||
@@ -2,152 +2,15 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"github.com/a-h/templ"
|
"github.com/a-h/templ"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// throwError is a generic helper that all throw* functions use internally
|
|
||||||
func throwError(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
statusCode int,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
level hws.ErrorLevel,
|
|
||||||
) {
|
|
||||||
s.ThrowError(w, r, hws.HWSError{
|
|
||||||
StatusCode: statusCode,
|
|
||||||
Message: msg,
|
|
||||||
Error: err,
|
|
||||||
Level: level,
|
|
||||||
RenderErrorPage: true, // throw* family always renders error pages
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwInternalServiceError handles 500 errors (server failures)
|
|
||||||
func throwInternalServiceError(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwServiceUnavailable handles 503 errors
|
|
||||||
func throwServiceUnavailable(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwBadRequest handles 400 errors (malformed requests)
|
|
||||||
func throwBadRequest(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwForbidden handles 403 errors (normal permission denials)
|
|
||||||
func throwForbidden(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
|
||||||
func throwForbiddenSecurity(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwUnauthorized handles 401 errors (not authenticated)
|
|
||||||
func throwUnauthorized(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
|
||||||
func throwUnauthorizedSecurity(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
msg string,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
|
|
||||||
}
|
|
||||||
|
|
||||||
// throwNotFound handles 404 errors
|
|
||||||
func throwNotFound(
|
|
||||||
s *hws.Server,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
path string,
|
|
||||||
) {
|
|
||||||
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
|
||||||
err := errors.New("Resource not found")
|
|
||||||
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorDetails contains structured error information for WebSocket error modals
|
|
||||||
type ErrorDetails struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Stacktrace string `json:"stacktrace"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatErrorDetails extracts and formats error details from wrapped errors
|
|
||||||
func formatErrorDetails(err error) string {
|
|
||||||
if err == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Use %+v format to get stack trace from github.com/pkg/errors
|
|
||||||
return fmt.Sprintf("%+v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SerializeErrorDetails creates a JSON string with code and stacktrace
|
|
||||||
// This is exported so it can be used when creating error notifications
|
|
||||||
func SerializeErrorDetails(code int, err error) string {
|
|
||||||
details := ErrorDetails{
|
|
||||||
Code: code,
|
|
||||||
Stacktrace: formatErrorDetails(err),
|
|
||||||
}
|
|
||||||
jsonData, jsonErr := json.Marshal(details)
|
|
||||||
if jsonErr != nil {
|
|
||||||
// Fallback if JSON encoding fails
|
|
||||||
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
|
|
||||||
}
|
|
||||||
return string(jsonData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseErrorDetails extracts code and stacktrace from JSON Details field
|
// parseErrorDetails extracts code and stacktrace from JSON Details field
|
||||||
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
|
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
|
||||||
func parseErrorDetails(details string) (int, string) {
|
func parseErrorDetails(details string) (int, string) {
|
||||||
@@ -155,7 +18,7 @@ func parseErrorDetails(details string) (int, string) {
|
|||||||
return 500, ""
|
return 500, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var errDetails ErrorDetails
|
var errDetails notify.ErrorDetails
|
||||||
err := json.Unmarshal([]byte(details), &errDetails)
|
err := json.Unmarshal([]byte(details), &errDetails)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Not JSON or malformed - treat as plain stacktrace with default code
|
// Not JSON or malformed - treat as plain stacktrace with default code
|
||||||
@@ -168,6 +31,6 @@ func parseErrorDetails(details string) (int, string) {
|
|||||||
func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) {
|
func renderSafely(page templ.Component, s *hws.Server, r *http.Request, w http.ResponseWriter) {
|
||||||
err := page.Render(r.Context(), w)
|
err := page.Render(r.Context(), w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
|
throw.InternalServiceError(s, w, r, "Failed to render page", errors.Wrap(err, "page."))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
@@ -14,7 +15,7 @@ func Index(s *hws.Server) http.Handler {
|
|||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path != "/" {
|
if r.URL.Path != "/" {
|
||||||
throwNotFound(s, w, r, r.URL.Path)
|
throw.NotFound(s, w, r, r.URL.Path)
|
||||||
}
|
}
|
||||||
renderSafely(page.Index(), s, r, w)
|
renderSafely(page.Index(), s, r, w)
|
||||||
},
|
},
|
||||||
|
|||||||
49
internal/handlers/isunique.go
Normal file
49
internal/handlers/isunique.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsUnique creates a handler that checks field uniqueness
|
||||||
|
// Returns 200 OK if unique, 409 Conflict if not unique
|
||||||
|
func IsUnique(
|
||||||
|
s *hws.Server,
|
||||||
|
conn *bun.DB,
|
||||||
|
model any,
|
||||||
|
field string,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
getter, err := validation.ParseForm(r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := getter.String(field).TrimSpace().Required().Value
|
||||||
|
if !getter.Validate() {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
unique := false
|
||||||
|
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
unique, err = db.IsUnique(ctx, tx, model, field, value)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.IsUnique")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if unique {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusConflict)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
|
||||||
|
|
||||||
func IsUsernameUnique(
|
|
||||||
server *hws.Server,
|
|
||||||
conn *bun.DB,
|
|
||||||
cfg *config.Config,
|
|
||||||
store *store.Store,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
username := r.FormValue("username")
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
if !unique {
|
|
||||||
w.WriteHeader(http.StatusConflict)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -11,7 +11,9 @@ import (
|
|||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +33,7 @@ func Login(
|
|||||||
|
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
notifyServiceUnavailable(s, r, "Login currently unavailable", err)
|
notify.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -40,46 +42,29 @@ func Login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
throw.ServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
||||||
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||||
|
|
||||||
if exceeded {
|
if exceeded {
|
||||||
err := errors.Errorf(
|
err = track.Error(attempts)
|
||||||
"login redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
|
||||||
attempts,
|
|
||||||
track.IP,
|
|
||||||
track.UserAgent,
|
|
||||||
track.Path,
|
|
||||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
|
||||||
)
|
|
||||||
|
|
||||||
st.ClearRedirectTrack(r, "/login")
|
st.ClearRedirectTrack(r, "/login")
|
||||||
|
throw.BadRequest(s, w, r, "Too many redirects. Please clear your browser cookies and try again", err)
|
||||||
throwError(
|
|
||||||
s,
|
|
||||||
w,
|
|
||||||
r,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
"Login failed: Too many redirect attempts. Please clear your browser cookies and try again.",
|
|
||||||
err,
|
|
||||||
"warn",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Failed to generate state token", err)
|
throw.InternalServiceError(s, w, r, "Failed to generate state token", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||||
|
|
||||||
link, err := discordAPI.GetOAuthLink(state)
|
link, err := discordAPI.GetOAuthLink(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
|
throw.InternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
st.ClearRedirectTrack(r, "/login")
|
st.ClearRedirectTrack(r, "/login")
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/golib/hwsauth"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
@@ -21,42 +21,31 @@ func Logout(
|
|||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
user := db.CurrentUser(r.Context())
|
user := db.CurrentUser(r.Context())
|
||||||
if user == nil {
|
if user == nil {
|
||||||
// JIC - should be impossible to get here if route is protected by LoginReq
|
// JIC - should be impossible to get here if route is protected by LoginReq
|
||||||
w.Header().Set("HX-Redirect", "/")
|
w.Header().Set("HX-Redirect", "/")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
if ok := db.WithWriteTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
if err != nil {
|
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if token != nil {
|
|
||||||
err = discordAPI.RevokeToken(token.Convert())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
return false, errors.Wrap(err, "user.DeleteDiscordTokens")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
if token != nil {
|
||||||
err = auth.Logout(tx, w, r)
|
err = discordAPI.RevokeToken(token.Convert())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
|
throw.InternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||||
return
|
return false, nil
|
||||||
}
|
}
|
||||||
err = tx.Commit()
|
}
|
||||||
if err != nil {
|
err = auth.Logout(tx, w, r)
|
||||||
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit"))
|
if err != nil {
|
||||||
|
throw.InternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("HX-Redirect", "/")
|
w.Header().Set("HX-Redirect", "/")
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
|
"git.haelnorr.com/h/timefmt"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
@@ -31,248 +32,62 @@ func NewSeasonSubmit(
|
|||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Parse form data
|
getter, ok := validation.ParseFormOrNotify(s, w, r)
|
||||||
err := r.ParseForm()
|
if !ok {
|
||||||
if err != nil {
|
return
|
||||||
err = notifyWarn(s, r, "Invalid Form", "Please check your input and try again.", nil)
|
}
|
||||||
if err != nil {
|
name := getter.String("name").
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
TrimSpace().Required().
|
||||||
}
|
MaxLength(20).MinLength(5).Value
|
||||||
|
shortName := getter.String("short_name").
|
||||||
|
TrimSpace().ToUpper().Required().
|
||||||
|
MaxLength(6).MinLength(2).Value
|
||||||
|
format := timefmt.NewBuilder().
|
||||||
|
DayNumeric2().Slash().
|
||||||
|
MonthNumeric2().Slash().
|
||||||
|
Year4().Build()
|
||||||
|
startDate := getter.Time("start_date", format).Required().Value
|
||||||
|
if !getter.ValidateAndNotify(s, w, r) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get form values
|
nameUnique := false
|
||||||
name := strings.TrimSpace(r.FormValue("name"))
|
shortNameUnique := false
|
||||||
shortName := strings.TrimSpace(strings.ToUpper(r.FormValue("short_name")))
|
var season *db.Season
|
||||||
startDateStr := r.FormValue("start_date")
|
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
var err error
|
||||||
// Validate required fields
|
nameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "name", name)
|
||||||
if name == "" || shortName == "" || startDateStr == "" {
|
|
||||||
err = notifyWarn(s, r, "Missing Fields", "All fields are required.", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
return false, errors.Wrap(err, "db.IsSeasonNameUnique")
|
||||||
}
|
}
|
||||||
return
|
shortNameUnique, err = db.IsUnique(ctx, tx, (*db.Season)(nil), "short_name", shortName)
|
||||||
}
|
|
||||||
|
|
||||||
// Validate field lengths
|
|
||||||
if len(name) > 20 {
|
|
||||||
err = notifyWarn(s, r, "Invalid Name", "Season name must be 20 characters or less.", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
return false, errors.Wrap(err, "db.IsSeasonShortNameUnique")
|
||||||
}
|
}
|
||||||
return
|
if !nameUnique || !shortNameUnique {
|
||||||
}
|
return true, nil
|
||||||
|
}
|
||||||
if len(shortName) > 6 {
|
season, err = db.NewSeason(ctx, tx, name, shortName, startDate)
|
||||||
err = notifyWarn(s, r, "Invalid Short Name", "Short name must be 6 characters or less.", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
return false, errors.Wrap(err, "db.NewSeason")
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate short name is alphanumeric only
|
|
||||||
if !isAlphanumeric(shortName) {
|
|
||||||
err = notifyWarn(s, r, "Invalid Short Name", "Short name must contain only letters and numbers.", nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse start date (DD/MM/YYYY format)
|
|
||||||
startDate, err := time.Parse("02/01/2006", startDateStr)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyWarn(s, r, "Invalid Date", "Please provide a valid start date in DD/MM/YYYY format.", nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Begin database transaction
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
// Double-check uniqueness (race condition protection)
|
|
||||||
nameUnique, err := db.IsSeasonNameUnique(ctx, tx, name)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !nameUnique {
|
if !nameUnique {
|
||||||
err = notifyWarn(s, r, "Duplicate Name", "This season name is already taken.", nil)
|
notify.Warn(s, w, r, "Duplicate Name", "This season name is already taken.", nil)
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
shortNameUnique, err := db.IsSeasonShortNameUnique(ctx, tx, shortName)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !shortNameUnique {
|
if !shortNameUnique {
|
||||||
err = notifyWarn(s, r, "Duplicate Short Name", "This short name is already taken.", nil)
|
notify.Warn(s, w, r, "Duplicate Short Name", "This short name is already taken.", nil)
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the season
|
notify.Success(s, w, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
|
||||||
season, err := db.NewSeason(ctx, tx, name, shortName, startDate)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Failed to create season", errors.Wrap(err, "db.NewSeason"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit transaction
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "tx.Commit"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send success notification
|
|
||||||
err = notifySuccess(s, r, "Season Created", fmt.Sprintf("Successfully created season: %s", name), nil)
|
|
||||||
if err != nil {
|
|
||||||
// Log but don't fail the request
|
|
||||||
s.LogError(hws.HWSError{
|
|
||||||
StatusCode: http.StatusInternalServerError,
|
|
||||||
Message: "Failed to send success notification",
|
|
||||||
Error: err,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Redirect to the season detail page
|
|
||||||
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName))
|
w.Header().Set("HX-Redirect", fmt.Sprintf("/seasons/%s", season.ShortName))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to validate alphanumeric strings
|
|
||||||
func isAlphanumeric(s string) bool {
|
|
||||||
for _, r := range s {
|
|
||||||
if ((r < 'A') || (r > 'Z')) && ((r < '0') || (r > '9')) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsSeasonNameUnique(
|
|
||||||
s *hws.Server,
|
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
// Trim whitespace for consistency
|
|
||||||
name := strings.TrimSpace(r.FormValue("name"))
|
|
||||||
|
|
||||||
unique, err := db.IsSeasonNameUnique(ctx, tx, name)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonNameUnique"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
if !unique {
|
|
||||||
w.WriteHeader(http.StatusConflict)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsSeasonShortNameUnique(
|
|
||||||
s *hws.Server,
|
|
||||||
conn *bun.DB,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
// Get short name and convert to uppercase for consistency
|
|
||||||
shortname := strings.ToUpper(strings.TrimSpace(r.FormValue("short_name")))
|
|
||||||
|
|
||||||
unique, err := db.IsSeasonShortNameUnique(ctx, tx, shortname)
|
|
||||||
if err != nil {
|
|
||||||
err = notifyInternalServiceError(s, r, "Database error", errors.Wrap(err, "db.IsSeasonShortNameUnique"))
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
if !unique {
|
|
||||||
w.WriteHeader(http.StatusConflict)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
|
||||||
"git.haelnorr.com/h/golib/notify"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
func notifyClient(
|
|
||||||
s *hws.Server,
|
|
||||||
r *http.Request,
|
|
||||||
level notify.Level,
|
|
||||||
title, message, details string,
|
|
||||||
action any,
|
|
||||||
) error {
|
|
||||||
subCookie, err := r.Cookie("ws_sub_id")
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "r.Cookie")
|
|
||||||
}
|
|
||||||
subID := notify.Target(subCookie.Value)
|
|
||||||
nt := notify.Notification{
|
|
||||||
Target: subID,
|
|
||||||
Title: title,
|
|
||||||
Message: message,
|
|
||||||
Details: details,
|
|
||||||
Action: action,
|
|
||||||
Level: level,
|
|
||||||
}
|
|
||||||
s.NotifySub(nt)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err error) error {
|
|
||||||
return notifyClient(s, r, notify.LevelError, "Internal Service Error", msg,
|
|
||||||
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error {
|
|
||||||
return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg,
|
|
||||||
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
|
||||||
return notifyClient(s, r, notify.LevelWarn, title, msg, "", action)
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyInfo(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
|
||||||
return notifyClient(s, r, notify.LevelInfo, title, msg, "", action)
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifySuccess(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
|
||||||
return notifyClient(s, r, notify.LevelSuccess, title, msg, "", action)
|
|
||||||
}
|
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/notify"
|
"git.haelnorr.com/h/golib/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/component/popup"
|
"git.haelnorr.com/h/oslstats/internal/view/component/popup"
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
@@ -22,7 +23,7 @@ func NotificationWS(
|
|||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Upgrade") != "websocket" {
|
if r.Header.Get("Upgrade") != "websocket" {
|
||||||
throwNotFound(s, w, r, r.URL.Path)
|
throw.NotFound(s, w, r, r.URL.Path)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nc, err := setupClient(s, w, r)
|
nc, err := setupClient(s, w, r)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,12 +17,6 @@ func NotifyTester(s *hws.Server) http.Handler {
|
|||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks")
|
testErr := errors.New("This is a stack trace. No really i swear. Just pretend ok? Thanks")
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
// page, _ := ErrorPage(hws.HWSError{
|
|
||||||
// StatusCode: http.StatusTeapot,
|
|
||||||
// Message: "This error has been rendered as a test",
|
|
||||||
// Error: testErr,
|
|
||||||
// })
|
|
||||||
// page.Render(r.Context(), w)
|
|
||||||
renderSafely(page.Test(), s, r, w)
|
renderSafely(page.Test(), s, r, w)
|
||||||
} else {
|
} else {
|
||||||
_ = r.ParseForm()
|
_ = r.ParseForm()
|
||||||
@@ -30,19 +25,15 @@ func NotifyTester(s *hws.Server) http.Handler {
|
|||||||
level := r.Form.Get("type")
|
level := r.Form.Get("type")
|
||||||
message := r.Form.Get("message")
|
message := r.Form.Get("message")
|
||||||
|
|
||||||
var err error
|
|
||||||
switch level {
|
switch level {
|
||||||
case "success":
|
case "success":
|
||||||
err = notifySuccess(s, r, title, message, nil)
|
notify.Success(s, w, r, title, message, nil)
|
||||||
case "info":
|
case "info":
|
||||||
err = notifyInfo(s, r, title, message, nil)
|
notify.Info(s, w, r, title, message, nil)
|
||||||
case "warn":
|
case "warn":
|
||||||
err = notifyWarn(s, r, title, message, nil)
|
notify.Warn(s, w, r, title, message, nil)
|
||||||
case "error":
|
case "error":
|
||||||
err = notifyInternalServiceError(s, r, message, testErr)
|
notify.InternalServiceError(s, w, r, message, testErr)
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Error notifying client", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pageOptsFromForm(r *http.Request) (*db.PageOpts, error) {
|
|
||||||
var pageNum, perPage int
|
|
||||||
var order bun.Order
|
|
||||||
var orderBy string
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if pageStr := r.FormValue("page"); pageStr != "" {
|
|
||||||
pageNum, err = strconv.Atoi(pageStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "invalid page number")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if perPageStr := r.FormValue("per_page"); perPageStr != "" {
|
|
||||||
perPage, err = strconv.Atoi(perPageStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "invalid per_page number")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
order = bun.Order(r.FormValue("order"))
|
|
||||||
orderBy = r.FormValue("order_by")
|
|
||||||
|
|
||||||
pageOpts := &db.PageOpts{
|
|
||||||
Page: pageNum,
|
|
||||||
PerPage: perPage,
|
|
||||||
Order: order,
|
|
||||||
OrderBy: orderBy,
|
|
||||||
}
|
|
||||||
return pageOpts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func pageOptsFromQuery(r *http.Request) (*db.PageOpts, error) {
|
|
||||||
var pageNum, perPage int
|
|
||||||
var order bun.Order
|
|
||||||
var orderBy string
|
|
||||||
var err error
|
|
||||||
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
|
|
||||||
pageNum, err = strconv.Atoi(pageStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "invalid page number")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if perPageStr := r.URL.Query().Get("per_page"); perPageStr != "" {
|
|
||||||
perPage, err = strconv.Atoi(perPageStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "invalid per_page number")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
order = bun.Order(r.URL.Query().Get("order"))
|
|
||||||
orderBy = r.URL.Query().Get("order_by")
|
|
||||||
pageOpts := &db.PageOpts{
|
|
||||||
Page: pageNum,
|
|
||||||
PerPage: perPage,
|
|
||||||
Order: order,
|
|
||||||
OrderBy: orderBy,
|
|
||||||
}
|
|
||||||
return pageOpts, nil
|
|
||||||
}
|
|
||||||
45
internal/handlers/pageopt_helpers.go
Normal file
45
internal/handlers/pageopt_helpers.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/validation"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pageOptsFromForm calls r.ParseForm and gets the pageOpts from the formdata.
|
||||||
|
// It renders a Bad Request error page on fail
|
||||||
|
// PageOpts will be nil on fail
|
||||||
|
func pageOptsFromForm(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
||||||
|
getter, ok := validation.ParseFormOrError(s, w, r)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return getPageOpts(s, w, r, getter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pageOptsFromQuery gets the pageOpts from the request query and renders a Bad Request error page on fail
|
||||||
|
// PageOpts will be nil on fail
|
||||||
|
func pageOptsFromQuery(s *hws.Server, w http.ResponseWriter, r *http.Request) *db.PageOpts {
|
||||||
|
return getPageOpts(s, w, r, validation.NewQueryGetter(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPageOpts(s *hws.Server, w http.ResponseWriter, r *http.Request, g validation.Getter) *db.PageOpts {
|
||||||
|
page := g.Int("page").Min(1).Value
|
||||||
|
perPage := g.Int("per_page").Min(1).Max(100).Value
|
||||||
|
order := g.String("order").TrimSpace().ToUpper().AllowedValues([]string{"ASC", "DESC"}).Value
|
||||||
|
orderBy := g.String("order_by").TrimSpace().ToLower().Value
|
||||||
|
valid := g.ValidateAndError(s, w, r)
|
||||||
|
if !valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pageOpts := &db.PageOpts{
|
||||||
|
Page: page,
|
||||||
|
PerPage: perPage,
|
||||||
|
Order: bun.Order(order),
|
||||||
|
OrderBy: orderBy,
|
||||||
|
}
|
||||||
|
return pageOpts
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,7 +18,7 @@ func PermTester(s *hws.Server, conn *bun.DB) http.Handler {
|
|||||||
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
|
isAdmin, err := user.HasRole(r.Context(), tx, roles.Admin)
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(s, w, r, "Error", err)
|
throw.InternalServiceError(s, w, r, "Error", err)
|
||||||
}
|
}
|
||||||
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
|
_, _ = w.Write([]byte(strconv.FormatBool(isAdmin)))
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/cookies"
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
@@ -14,6 +13,7 @@ import (
|
|||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/store"
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,27 +29,9 @@ func Register(
|
|||||||
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
|
attempts, exceeded, track := store.TrackRedirect(r, "/register", 3)
|
||||||
|
|
||||||
if exceeded {
|
if exceeded {
|
||||||
err := errors.Errorf(
|
err := track.Error(attempts)
|
||||||
"registration redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s ssl=%t",
|
|
||||||
attempts,
|
|
||||||
track.IP,
|
|
||||||
track.UserAgent,
|
|
||||||
track.Path,
|
|
||||||
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
|
||||||
cfg.HWSAuth.SSL,
|
|
||||||
)
|
|
||||||
|
|
||||||
store.ClearRedirectTrack(r, "/register")
|
store.ClearRedirectTrack(r, "/register")
|
||||||
|
throw.BadRequest(s, w, r, "Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again", err)
|
||||||
throwError(
|
|
||||||
s,
|
|
||||||
w,
|
|
||||||
r,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
"Registration failed: Cookies appear to be blocked or disabled. Please enable cookies in your browser and try again. If this problem persists, try a different browser or contact support.",
|
|
||||||
err,
|
|
||||||
"warn",
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,68 +47,45 @@ func Register(
|
|||||||
}
|
}
|
||||||
|
|
||||||
store.ClearRedirectTrack(r, "/register")
|
store.ClearRedirectTrack(r, "/register")
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
if r.Method == "GET" {
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database transaction failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
method := r.Method
|
|
||||||
if method == "GET" {
|
|
||||||
tx.Commit()
|
|
||||||
renderSafely(page.Register(details.DiscordUser.Username), s, r, w)
|
renderSafely(page.Register(details.DiscordUser.Username), s, r, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if method == "POST" {
|
username := r.FormValue("username")
|
||||||
username := r.FormValue("username")
|
unique := false
|
||||||
user, err := registerUser(ctx, tx, username, details)
|
var user *db.User
|
||||||
|
if ok := db.WithNotifyTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
|
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = notifyInternalServiceError(s, r, "Registration failed", err)
|
return false, errors.Wrap(err, "db.IsUsernameUnique")
|
||||||
if err != nil {
|
}
|
||||||
throwInternalServiceError(s, w, r, "Registration failed", err)
|
if !unique {
|
||||||
}
|
return true, nil
|
||||||
|
}
|
||||||
|
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.CreateUser")
|
||||||
|
}
|
||||||
|
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.UpdateDiscordToken")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !unique {
|
||||||
|
w.WriteHeader(http.StatusConflict)
|
||||||
|
} else {
|
||||||
|
err = auth.Login(w, r, user, true)
|
||||||
|
if err != nil {
|
||||||
|
throw.InternalServiceError(s, w, r, "Login failed", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
pageFrom := cookies.CheckPageFrom(w, r)
|
||||||
if user == nil {
|
w.Header().Set("HX-Redirect", pageFrom)
|
||||||
w.WriteHeader(http.StatusConflict)
|
|
||||||
} else {
|
|
||||||
err = auth.Login(w, r, user, true)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Login failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pageFrom := cookies.CheckPageFrom(w, r)
|
|
||||||
w.Header().Set("HX-Redirect", pageFrom)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerUser(
|
|
||||||
ctx context.Context,
|
|
||||||
tx bun.Tx,
|
|
||||||
username string,
|
|
||||||
details *store.RegistrationSession,
|
|
||||||
) (*db.User, error) {
|
|
||||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.IsUsernameUnique")
|
|
||||||
}
|
|
||||||
if !unique {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.CreateUser")
|
|
||||||
}
|
|
||||||
err = user.UpdateDiscordToken(ctx, tx, details.Token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -17,23 +17,20 @@ func SeasonPage(
|
|||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
seasonStr := r.PathValue("season_short_name")
|
seasonStr := r.PathValue("season_short_name")
|
||||||
season, err := db.GetSeason(ctx, tx, seasonStr)
|
var season *db.Season
|
||||||
if err != nil {
|
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.GetSeason"))
|
var err error
|
||||||
|
season, err = db.GetSeason(ctx, tx, seasonStr)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.GetSeason")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
|
||||||
if season == nil {
|
if season == nil {
|
||||||
throwNotFound(s, w, r, r.URL.Path)
|
throw.NotFound(s, w, r, r.URL.Path)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
renderSafely(page.SeasonPage(season), s, r, w)
|
renderSafely(page.SeasonPage(season), s, r, w)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
@@ -17,25 +16,21 @@ func SeasonsPage(
|
|||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
pageOpts := pageOptsFromQuery(s, w, r)
|
||||||
defer cancel()
|
if pageOpts == nil {
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
var seasons *db.SeasonList
|
||||||
pageOpts, err := pageOptsFromQuery(r)
|
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
if err != nil {
|
var err error
|
||||||
throwBadRequest(s, w, r, "invalid query", err)
|
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.ListSeasons")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
renderSafely(page.SeasonsPage(seasons), s, r, w)
|
renderSafely(page.SeasonsPage(seasons), s, r, w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -45,37 +40,21 @@ func SeasonsList(
|
|||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
pageOpts := pageOptsFromForm(s, w, r)
|
||||||
defer cancel()
|
if pageOpts == nil {
|
||||||
|
|
||||||
// Parse form values
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
throwBadRequest(s, w, r, "Invalid form data", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var seasons *db.SeasonList
|
||||||
pageOpts, err := pageOptsFromForm(r)
|
if ok := db.WithReadTx(s, w, r, conn, func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
if err != nil {
|
var err error
|
||||||
throwBadRequest(s, w, r, "invalid form data", err)
|
seasons, err = db.ListSeasons(ctx, tx, pageOpts)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "db.ListSeasons")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Database query
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
seasons, err := db.ListSeasons(ctx, tx, pageOpts)
|
|
||||||
if err != nil {
|
|
||||||
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "db.ListSeasons"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
|
|
||||||
// Return only the list component (hx-push-url handles URL update client-side)
|
|
||||||
renderSafely(page.SeasonsList(seasons), s, r, w)
|
renderSafely(page.SeasonsList(seasons), s, r, w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StaticFS handles requests for static files, without allowing access to the
|
// StaticFS handles requests for static files, without allowing access to the
|
||||||
@@ -16,7 +17,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// If we can't create the file server, return a handler that always errors
|
// If we can't create the file server, return a handler that always errors
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
throw.InternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
63
internal/notify/notify.go
Normal file
63
internal/notify/notify.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
// Package notify provides utility functions for sending notifications to clients
|
||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func notifyClient(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
level notify.Level,
|
||||||
|
title, message, details string,
|
||||||
|
action any,
|
||||||
|
) {
|
||||||
|
subCookie, err := r.Cookie("ws_sub_id")
|
||||||
|
if err != nil {
|
||||||
|
throw.InternalServiceError(s, w, r, "Notification failed. Do you have cookies enabled?", errors.Wrap(err, "r.Cookie"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
subID := notify.Target(subCookie.Value)
|
||||||
|
nt := notify.Notification{
|
||||||
|
Target: subID,
|
||||||
|
Title: title,
|
||||||
|
Message: message,
|
||||||
|
Details: details,
|
||||||
|
Action: action,
|
||||||
|
Level: level,
|
||||||
|
}
|
||||||
|
s.NotifySub(nt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InternalServiceError notifies with error level
|
||||||
|
func InternalServiceError(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
|
||||||
|
notifyClient(s, w, r, notify.LevelError, "Internal Service Error", msg,
|
||||||
|
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServiceUnavailable notifies with error level
|
||||||
|
func ServiceUnavailable(s *hws.Server, w http.ResponseWriter, r *http.Request, msg string, err error) {
|
||||||
|
notifyClient(s, w, r, notify.LevelError, "Service Unavailable", msg,
|
||||||
|
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn notifies with warn level
|
||||||
|
func Warn(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||||
|
notifyClient(s, w, r, notify.LevelWarn, title, msg, "", action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info notifies with info level
|
||||||
|
func Info(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||||
|
notifyClient(s, w, r, notify.LevelInfo, title, msg, "", action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success notifies with success level
|
||||||
|
func Success(s *hws.Server, w http.ResponseWriter, r *http.Request, title, msg string, action any) {
|
||||||
|
notifyClient(s, w, r, notify.LevelSuccess, title, msg, "", action)
|
||||||
|
}
|
||||||
53
internal/notify/util.go
Normal file
53
internal/notify/util.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorDetails contains structured error information for WebSocket error modals
|
||||||
|
type ErrorDetails struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Stacktrace string `json:"stacktrace"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SerializeErrorDetails creates a JSON string with code and stacktrace
|
||||||
|
// This is exported so it can be used when creating error notifications
|
||||||
|
func SerializeErrorDetails(code int, err error) string {
|
||||||
|
details := ErrorDetails{
|
||||||
|
Code: code,
|
||||||
|
Stacktrace: FormatErrorDetails(err),
|
||||||
|
}
|
||||||
|
jsonData, jsonErr := json.Marshal(details)
|
||||||
|
if jsonErr != nil {
|
||||||
|
// Fallback if JSON encoding fails
|
||||||
|
return fmt.Sprintf(`{"code":%d,"stacktrace":"Failed to serialize error"}`, code)
|
||||||
|
}
|
||||||
|
return string(jsonData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseErrorDetails extracts code and stacktrace from JSON Details field
|
||||||
|
// Returns (code, stacktrace). If parsing fails, returns (500, original details string)
|
||||||
|
func ParseErrorDetails(details string) (int, string) {
|
||||||
|
if details == "" {
|
||||||
|
return 500, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var errDetails ErrorDetails
|
||||||
|
err := json.Unmarshal([]byte(details), &errDetails)
|
||||||
|
if err != nil {
|
||||||
|
// Not JSON or malformed - treat as plain stacktrace with default code
|
||||||
|
return 500, details
|
||||||
|
}
|
||||||
|
|
||||||
|
return errDetails.Code, errDetails.Stacktrace
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatErrorDetails extracts and formats error details from wrapped errors
|
||||||
|
func FormatErrorDetails(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Use %+v format to get stack trace from github.com/pkg/errors
|
||||||
|
return fmt.Sprintf("%+v", err)
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package rbac
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/oslstats/internal/contexts"
|
"git.haelnorr.com/h/oslstats/internal/contexts"
|
||||||
@@ -11,6 +10,7 @@ import (
|
|||||||
"git.haelnorr.com/h/oslstats/internal/permissions"
|
"git.haelnorr.com/h/oslstats/internal/permissions"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadPermissionsMiddleware loads user permissions into context after authentication
|
// LoadPermissionsMiddleware loads user permissions into context after authentication
|
||||||
@@ -26,46 +26,28 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start transaction for loading permissions
|
var roles_ []*db.Role
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
var perms []*db.Permission
|
||||||
defer cancel()
|
if err := db.WithTxFailSilently(r.Context(), c.conn, func(ctx context.Context, tx bun.Tx) error {
|
||||||
tx, err := c.conn.BeginTx(ctx, nil)
|
var err error
|
||||||
if err != nil {
|
roles_, err = user.GetRoles(ctx, tx)
|
||||||
// Log but don't block - permission checks will fail gracefully
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "user.GetRoles")
|
||||||
|
}
|
||||||
|
perms, err = user.GetPermissions(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "user.GetPermissions")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
c.s.LogError(hws.HWSError{
|
c.s.LogError(hws.HWSError{
|
||||||
Message: "Failed to start database transaction",
|
Message: "Database error",
|
||||||
Error: errors.Wrap(err, "c.conn.BeginTx"),
|
Error: err,
|
||||||
Level: hws.ErrorERROR,
|
Level: hws.ErrorERROR,
|
||||||
})
|
})
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
|
|
||||||
// Load user's roles_ and permissions
|
|
||||||
roles_, err := user.GetRoles(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
c.s.LogError(hws.HWSError{
|
|
||||||
Message: "Failed to get user roles",
|
|
||||||
Error: errors.Wrap(err, "user.GetRoles"),
|
|
||||||
Level: hws.ErrorERROR,
|
|
||||||
})
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
perms, err := user.GetPermissions(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
c.s.LogError(hws.HWSError{
|
|
||||||
Message: "Failed to get user permissions",
|
|
||||||
Error: errors.Wrap(err, "user.GetPermissions"),
|
|
||||||
Level: hws.ErrorERROR,
|
|
||||||
})
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tx.Commit() // read only transaction
|
|
||||||
|
|
||||||
// Build permission cache
|
// Build permission cache
|
||||||
cache := &contexts.PermissionCache{
|
cache := &contexts.PermissionCache{
|
||||||
@@ -88,7 +70,7 @@ func (c *Checker) LoadPermissionsMiddleware() hws.Middleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add cache to context (type-safe)
|
// Add cache to context (type-safe)
|
||||||
ctx = context.WithValue(ctx, contexts.PermissionCacheKey, cache)
|
ctx := context.WithValue(r.Context(), contexts.PermissionCacheKey, cache)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
||||||
@@ -93,3 +95,14 @@ func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
|
|||||||
key := redirectKey(ip, userAgent, path)
|
key := redirectKey(ip, userAgent, path)
|
||||||
s.redirectTracks.Delete(key)
|
s.redirectTracks.Delete(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RedirectTrack) Error(attempts int) error {
|
||||||
|
return errors.Errorf(
|
||||||
|
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
||||||
|
attempts,
|
||||||
|
t.IP,
|
||||||
|
t.UserAgent,
|
||||||
|
t.Path,
|
||||||
|
t.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
118
internal/throw/throw.go
Normal file
118
internal/throw/throw.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
// Package throw provides utility functions for throwing HTTP errors that render an error page
|
||||||
|
package throw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// throwError is a generic helper that all throw* functions use internally
|
||||||
|
func throwError(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
statusCode int,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
level hws.ErrorLevel,
|
||||||
|
) {
|
||||||
|
s.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Message: msg,
|
||||||
|
Error: err,
|
||||||
|
Level: level,
|
||||||
|
RenderErrorPage: true, // throw* family always renders error pages
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// InternalServiceError handles 500 errors (server failures)
|
||||||
|
func InternalServiceError(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusInternalServerError, msg, err, hws.ErrorERROR)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServiceUnavailable handles 503 errors
|
||||||
|
func ServiceUnavailable(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, hws.ErrorERROR)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadRequest handles 400 errors (malformed requests)
|
||||||
|
func BadRequest(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusBadRequest, msg, err, hws.ErrorDEBUG)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forbidden handles 403 errors (normal permission denials)
|
||||||
|
func Forbidden(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorDEBUG)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
||||||
|
func ForbiddenSecurity(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusForbidden, msg, err, hws.ErrorWARN)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unauthorized handles 401 errors (not authenticated)
|
||||||
|
func Unauthorized(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorDEBUG)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
||||||
|
func UnauthorizedSecurity(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
msg string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
throwError(s, w, r, http.StatusUnauthorized, msg, err, hws.ErrorWARN)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotFound handles 404 errors
|
||||||
|
func NotFound(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
path string,
|
||||||
|
) {
|
||||||
|
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
||||||
|
err := errors.New("Resource not found")
|
||||||
|
throwError(s, w, r, http.StatusNotFound, msg, err, hws.ErrorDEBUG)
|
||||||
|
}
|
||||||
101
internal/validation/forms.go
Normal file
101
internal/validation/forms.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
|
"git.haelnorr.com/h/timefmt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormGetter wraps http.Request to get form values
|
||||||
|
type FormGetter struct {
|
||||||
|
r *http.Request
|
||||||
|
checks []*ValidationRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFormGetter(r *http.Request) *FormGetter {
|
||||||
|
return &FormGetter{r: r, checks: []*ValidationRule{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) Get(key string) string {
|
||||||
|
return f.r.FormValue(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) getChecks() []*ValidationRule {
|
||||||
|
return f.checks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) AddCheck(check *ValidationRule) {
|
||||||
|
f.checks = append(f.checks, check)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) ValidateChecks() []*ValidationRule {
|
||||||
|
return validate(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) String(key string) *StringField {
|
||||||
|
return newStringField(key, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) Int(key string) *IntField {
|
||||||
|
return newIntField(key, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||||
|
return newTimeField(key, format, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseForm(r *http.Request) (*FormGetter, error) {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "r.ParseForm")
|
||||||
|
}
|
||||||
|
return NewFormGetter(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFormOrNotify attempts to parse the form data and notifies the user on fail
|
||||||
|
func ParseFormOrNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||||
|
getter, err := ParseForm(r)
|
||||||
|
if err != nil {
|
||||||
|
notify.Warn(s, w, r, "Invalid Form", "Please check your input and try again.", nil)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return getter, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFormOrError attempts to parse the form data and renders an error page on fail
|
||||||
|
func ParseFormOrError(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||||
|
getter, err := ParseForm(r)
|
||||||
|
if err != nil {
|
||||||
|
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return getter, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FormGetter) Validate() bool {
|
||||||
|
return len(validate(f)) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||||
|
// Returns true if all checks passed
|
||||||
|
func (f *FormGetter) ValidateAndNotify(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) bool {
|
||||||
|
return validateAndNotify(s, w, r, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||||
|
// Returns true if all checks passed
|
||||||
|
func (f *FormGetter) ValidateAndError(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) bool {
|
||||||
|
return validateAndError(s, w, r, f)
|
||||||
|
}
|
||||||
71
internal/validation/integerfield.go
Normal file
71
internal/validation/integerfield.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IntField struct {
|
||||||
|
Field
|
||||||
|
Value int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIntField(key string, g Getter) *IntField {
|
||||||
|
raw := g.Get(key)
|
||||||
|
var val int
|
||||||
|
if raw != "" {
|
||||||
|
var err error
|
||||||
|
val, err = strconv.Atoi(raw)
|
||||||
|
if err != nil {
|
||||||
|
g.AddCheck(newFailedCheck(
|
||||||
|
"Value is not a number",
|
||||||
|
fmt.Sprintf("%s must be an integer: %s provided", key, raw),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &IntField{
|
||||||
|
Value: val,
|
||||||
|
Field: newField(key, g),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required enforces a non-zero value
|
||||||
|
func (i *IntField) Required() *IntField {
|
||||||
|
if i.Value == 0 {
|
||||||
|
i.getter.AddCheck(newFailedCheck(
|
||||||
|
"Value cannot be 0",
|
||||||
|
fmt.Sprintf("%s is required", i.Key),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional will skip all validations if value is empty
|
||||||
|
func (i *IntField) Optional() *IntField {
|
||||||
|
if i.Value == 0 {
|
||||||
|
i.optional = true
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max enforces a maxmium value
|
||||||
|
func (i *IntField) Max(max int) *IntField {
|
||||||
|
if i.Value > max && !i.optional {
|
||||||
|
i.getter.AddCheck(newFailedCheck(
|
||||||
|
"Value too large",
|
||||||
|
fmt.Sprintf("%s is too large, max %v", i.Key, max),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min enforces a minimum value
|
||||||
|
func (i *IntField) Min(min int) *IntField {
|
||||||
|
if i.Value < min && !i.optional {
|
||||||
|
i.getter.AddCheck(newFailedCheck(
|
||||||
|
"Value too small",
|
||||||
|
fmt.Sprintf("%s is too small, min %v", i.Key, min),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
70
internal/validation/querys.go
Normal file
70
internal/validation/querys.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/timefmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueryGetter wraps http.Request to get query values
|
||||||
|
type QueryGetter struct {
|
||||||
|
r *http.Request
|
||||||
|
checks []*ValidationRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueryGetter(r *http.Request) *QueryGetter {
|
||||||
|
return &QueryGetter{r: r, checks: []*ValidationRule{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) Get(key string) string {
|
||||||
|
return q.r.URL.Query().Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) getChecks() []*ValidationRule {
|
||||||
|
return q.checks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) AddCheck(check *ValidationRule) {
|
||||||
|
q.checks = append(q.checks, check)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) ValidateChecks() []*ValidationRule {
|
||||||
|
return validate(q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) String(key string) *StringField {
|
||||||
|
return newStringField(key, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) Int(key string) *IntField {
|
||||||
|
return newIntField(key, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||||
|
return newTimeField(key, format, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryGetter) Validate() bool {
|
||||||
|
return len(validate(q)) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||||
|
// Returns true if all checks passed
|
||||||
|
func (q *QueryGetter) ValidateAndNotify(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) bool {
|
||||||
|
return validateAndNotify(s, w, r, q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||||
|
// Returns true if all checks passed
|
||||||
|
func (q *QueryGetter) ValidateAndError(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) bool {
|
||||||
|
return validateAndError(s, w, r, q)
|
||||||
|
}
|
||||||
106
internal/validation/stringfield.go
Normal file
106
internal/validation/stringfield.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StringField struct {
|
||||||
|
Field
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStringField(key string, g Getter) *StringField {
|
||||||
|
return &StringField{
|
||||||
|
Value: g.Get(key),
|
||||||
|
Field: newField(key, g),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Required enforces a non empty string
|
||||||
|
func (s *StringField) Required() *StringField {
|
||||||
|
if s.Value == "" {
|
||||||
|
s.getter.AddCheck(newFailedCheck(
|
||||||
|
"Field not provided",
|
||||||
|
fmt.Sprintf("%s is required", s.Key),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional will skip all validations if value is empty
|
||||||
|
func (s *StringField) Optional() *StringField {
|
||||||
|
if s.Value == "" {
|
||||||
|
s.optional = true
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxLength enforces a maximum string length
|
||||||
|
func (s *StringField) MaxLength(length int) *StringField {
|
||||||
|
if len(s.Value) > length && !s.optional {
|
||||||
|
s.getter.AddCheck(newFailedCheck(
|
||||||
|
"Input too long",
|
||||||
|
fmt.Sprintf("%s is too long, max %v chars", s.Key, length),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinLength enforces a minimum string length
|
||||||
|
func (s *StringField) MinLength(length int) *StringField {
|
||||||
|
if len(s.Value) < length && !s.optional {
|
||||||
|
s.getter.AddCheck(newFailedCheck(
|
||||||
|
"Input too short",
|
||||||
|
fmt.Sprintf("%s is too short, min %v chars", s.Key, length),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlphaNumeric enforces the string contains only letters and numbers
|
||||||
|
func (s *StringField) AlphaNumeric() *StringField {
|
||||||
|
if s.optional {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
for _, r := range s.Value {
|
||||||
|
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
|
||||||
|
s.getter.AddCheck(newFailedCheck(
|
||||||
|
"Invalid characters",
|
||||||
|
fmt.Sprintf("%s must contain only letters and numbers.", s.Key),
|
||||||
|
))
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StringField) AllowedValues(allowed []string) *StringField {
|
||||||
|
if !slices.Contains(allowed, s.Value) && !s.optional {
|
||||||
|
s.getter.AddCheck(newFailedCheck(
|
||||||
|
"Value not allowed",
|
||||||
|
fmt.Sprintf("%s must be one of: %s", s.Key, strings.Join(allowed, ",")),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToUpper transforms the string to uppercase
|
||||||
|
func (s *StringField) ToUpper() *StringField {
|
||||||
|
s.Value = strings.ToUpper(s.Value)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToLower transforms the string to lowercase
|
||||||
|
func (s *StringField) ToLower() *StringField {
|
||||||
|
s.Value = strings.ToLower(s.Value)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrimSpace removes leading and trailing whitespace
|
||||||
|
func (s *StringField) TrimSpace() *StringField {
|
||||||
|
s.Value = strings.TrimSpace(s.Value)
|
||||||
|
return s
|
||||||
|
}
|
||||||
50
internal/validation/timefield.go
Normal file
50
internal/validation/timefield.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/timefmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TimeField struct {
|
||||||
|
Field
|
||||||
|
Value time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTimeField(key string, format *timefmt.Format, g Getter) *TimeField {
|
||||||
|
raw := g.Get(key)
|
||||||
|
var startDate time.Time
|
||||||
|
if raw != "" {
|
||||||
|
var err error
|
||||||
|
startDate, err = format.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
g.AddCheck(newFailedCheck(
|
||||||
|
"Invalid date/time format",
|
||||||
|
fmt.Sprintf("%s should be in format %s", key, format.LDML()),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &TimeField{
|
||||||
|
Value: startDate,
|
||||||
|
Field: newField(key, g),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TimeField) Required() *TimeField {
|
||||||
|
if t.Value.IsZero() {
|
||||||
|
t.getter.AddCheck(newFailedCheck(
|
||||||
|
"Date/Time not provided",
|
||||||
|
fmt.Sprintf("%s must be provided", t.Key),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional will skip all validations if value is empty
|
||||||
|
func (t *TimeField) Optional() *TimeField {
|
||||||
|
if t.Value.IsZero() {
|
||||||
|
t.optional = true
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
93
internal/validation/validation.go
Normal file
93
internal/validation/validation.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
// Package validation provides utilities for parsing and validating request data
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
|
"git.haelnorr.com/h/timefmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ValidationRule struct {
|
||||||
|
Condition bool // Condition on which to fail validation
|
||||||
|
Title string // Title for warning message
|
||||||
|
Message string // Warning message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter abstracts getting values from either form or query
|
||||||
|
type Getter interface {
|
||||||
|
Get(key string) string
|
||||||
|
AddCheck(check *ValidationRule)
|
||||||
|
String(key string) *StringField
|
||||||
|
Int(key string) *IntField
|
||||||
|
Time(key string, format *timefmt.Format) *TimeField
|
||||||
|
ValidateChecks() []*ValidationRule
|
||||||
|
Validate() bool
|
||||||
|
ValidateAndNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||||
|
ValidateAndError(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||||
|
|
||||||
|
getChecks() []*ValidationRule
|
||||||
|
}
|
||||||
|
type Field struct {
|
||||||
|
Key string
|
||||||
|
optional bool
|
||||||
|
getter Getter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newField(key string, g Getter) Field {
|
||||||
|
return Field{
|
||||||
|
Key: key,
|
||||||
|
getter: g,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFailedCheck(title, message string) *ValidationRule {
|
||||||
|
return &ValidationRule{
|
||||||
|
Condition: true,
|
||||||
|
Title: title,
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validate(g Getter) []*ValidationRule {
|
||||||
|
failed := []*ValidationRule{}
|
||||||
|
for _, check := range g.getChecks() {
|
||||||
|
if check != nil && check.Condition {
|
||||||
|
failed = append(failed, check)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return failed
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAndNotify(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
g Getter,
|
||||||
|
) bool {
|
||||||
|
failedChecks := g.ValidateChecks()
|
||||||
|
for _, check := range failedChecks {
|
||||||
|
notify.Warn(s, w, r, check.Title, check.Message, nil)
|
||||||
|
}
|
||||||
|
return len(failedChecks) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAndError(
|
||||||
|
s *hws.Server,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
g Getter,
|
||||||
|
) bool {
|
||||||
|
failedChecks := g.ValidateChecks()
|
||||||
|
var err error
|
||||||
|
for _, check := range failedChecks {
|
||||||
|
err_ := fmt.Errorf("%s: %s", check.Title, check.Message)
|
||||||
|
err = errors.Join(err, err_)
|
||||||
|
}
|
||||||
|
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||||
|
return len(failedChecks) == 0
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user