refactored for maintainability
This commit is contained in:
101
internal/validation/forms.go
Normal file
101
internal/validation/forms.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// FormGetter wraps http.Request to get form values
|
||||
type FormGetter struct {
|
||||
r *http.Request
|
||||
checks []*ValidationRule
|
||||
}
|
||||
|
||||
func NewFormGetter(r *http.Request) *FormGetter {
|
||||
return &FormGetter{r: r, checks: []*ValidationRule{}}
|
||||
}
|
||||
|
||||
func (f *FormGetter) Get(key string) string {
|
||||
return f.r.FormValue(key)
|
||||
}
|
||||
|
||||
func (f *FormGetter) getChecks() []*ValidationRule {
|
||||
return f.checks
|
||||
}
|
||||
|
||||
func (f *FormGetter) AddCheck(check *ValidationRule) {
|
||||
f.checks = append(f.checks, check)
|
||||
}
|
||||
|
||||
func (f *FormGetter) ValidateChecks() []*ValidationRule {
|
||||
return validate(f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) String(key string) *StringField {
|
||||
return newStringField(key, f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) Int(key string) *IntField {
|
||||
return newIntField(key, f)
|
||||
}
|
||||
|
||||
func (f *FormGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||
return newTimeField(key, format, f)
|
||||
}
|
||||
|
||||
func ParseForm(r *http.Request) (*FormGetter, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "r.ParseForm")
|
||||
}
|
||||
return NewFormGetter(r), nil
|
||||
}
|
||||
|
||||
// ParseFormOrNotify attempts to parse the form data and notifies the user on fail
|
||||
func ParseFormOrNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||
getter, err := ParseForm(r)
|
||||
if err != nil {
|
||||
notify.Warn(s, w, r, "Invalid Form", "Please check your input and try again.", nil)
|
||||
return nil, false
|
||||
}
|
||||
return getter, true
|
||||
}
|
||||
|
||||
// ParseFormOrError attempts to parse the form data and renders an error page on fail
|
||||
func ParseFormOrError(s *hws.Server, w http.ResponseWriter, r *http.Request) (*FormGetter, bool) {
|
||||
getter, err := ParseForm(r)
|
||||
if err != nil {
|
||||
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||
return nil, false
|
||||
}
|
||||
return getter, true
|
||||
}
|
||||
|
||||
func (f *FormGetter) Validate() bool {
|
||||
return len(validate(f)) == 0
|
||||
}
|
||||
|
||||
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||
// Returns true if all checks passed
|
||||
func (f *FormGetter) ValidateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndNotify(s, w, r, f)
|
||||
}
|
||||
|
||||
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||
// Returns true if all checks passed
|
||||
func (f *FormGetter) ValidateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndError(s, w, r, f)
|
||||
}
|
||||
71
internal/validation/integerfield.go
Normal file
71
internal/validation/integerfield.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type IntField struct {
|
||||
Field
|
||||
Value int
|
||||
}
|
||||
|
||||
func newIntField(key string, g Getter) *IntField {
|
||||
raw := g.Get(key)
|
||||
var val int
|
||||
if raw != "" {
|
||||
var err error
|
||||
val, err = strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
g.AddCheck(newFailedCheck(
|
||||
"Value is not a number",
|
||||
fmt.Sprintf("%s must be an integer: %s provided", key, raw),
|
||||
))
|
||||
}
|
||||
}
|
||||
return &IntField{
|
||||
Value: val,
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
// Required enforces a non-zero value
|
||||
func (i *IntField) Required() *IntField {
|
||||
if i.Value == 0 {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value cannot be 0",
|
||||
fmt.Sprintf("%s is required", i.Key),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (i *IntField) Optional() *IntField {
|
||||
if i.Value == 0 {
|
||||
i.optional = true
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Max enforces a maxmium value
|
||||
func (i *IntField) Max(max int) *IntField {
|
||||
if i.Value > max && !i.optional {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value too large",
|
||||
fmt.Sprintf("%s is too large, max %v", i.Key, max),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Min enforces a minimum value
|
||||
func (i *IntField) Min(min int) *IntField {
|
||||
if i.Value < min && !i.optional {
|
||||
i.getter.AddCheck(newFailedCheck(
|
||||
"Value too small",
|
||||
fmt.Sprintf("%s is too small, min %v", i.Key, min),
|
||||
))
|
||||
}
|
||||
return i
|
||||
}
|
||||
70
internal/validation/querys.go
Normal file
70
internal/validation/querys.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
// QueryGetter wraps http.Request to get query values
|
||||
type QueryGetter struct {
|
||||
r *http.Request
|
||||
checks []*ValidationRule
|
||||
}
|
||||
|
||||
func NewQueryGetter(r *http.Request) *QueryGetter {
|
||||
return &QueryGetter{r: r, checks: []*ValidationRule{}}
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Get(key string) string {
|
||||
return q.r.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) getChecks() []*ValidationRule {
|
||||
return q.checks
|
||||
}
|
||||
|
||||
func (q *QueryGetter) AddCheck(check *ValidationRule) {
|
||||
q.checks = append(q.checks, check)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) ValidateChecks() []*ValidationRule {
|
||||
return validate(q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) String(key string) *StringField {
|
||||
return newStringField(key, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Int(key string) *IntField {
|
||||
return newIntField(key, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Time(key string, format *timefmt.Format) *TimeField {
|
||||
return newTimeField(key, format, q)
|
||||
}
|
||||
|
||||
func (q *QueryGetter) Validate() bool {
|
||||
return len(validate(q)) == 0
|
||||
}
|
||||
|
||||
// ValidateAndNotify runs the provided validation checks and sends a notification for each failed check
|
||||
// Returns true if all checks passed
|
||||
func (q *QueryGetter) ValidateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndNotify(s, w, r, q)
|
||||
}
|
||||
|
||||
// ValidateAndError runs the provided validation checks and renders an error page with all the error messages
|
||||
// Returns true if all checks passed
|
||||
func (q *QueryGetter) ValidateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) bool {
|
||||
return validateAndError(s, w, r, q)
|
||||
}
|
||||
106
internal/validation/stringfield.go
Normal file
106
internal/validation/stringfield.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type StringField struct {
|
||||
Field
|
||||
Value string
|
||||
}
|
||||
|
||||
func newStringField(key string, g Getter) *StringField {
|
||||
return &StringField{
|
||||
Value: g.Get(key),
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
// Required enforces a non empty string
|
||||
func (s *StringField) Required() *StringField {
|
||||
if s.Value == "" {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Field not provided",
|
||||
fmt.Sprintf("%s is required", s.Key),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (s *StringField) Optional() *StringField {
|
||||
if s.Value == "" {
|
||||
s.optional = true
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// MaxLength enforces a maximum string length
|
||||
func (s *StringField) MaxLength(length int) *StringField {
|
||||
if len(s.Value) > length && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Input too long",
|
||||
fmt.Sprintf("%s is too long, max %v chars", s.Key, length),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// MinLength enforces a minimum string length
|
||||
func (s *StringField) MinLength(length int) *StringField {
|
||||
if len(s.Value) < length && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Input too short",
|
||||
fmt.Sprintf("%s is too short, min %v chars", s.Key, length),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// AlphaNumeric enforces the string contains only letters and numbers
|
||||
func (s *StringField) AlphaNumeric() *StringField {
|
||||
if s.optional {
|
||||
return s
|
||||
}
|
||||
for _, r := range s.Value {
|
||||
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Invalid characters",
|
||||
fmt.Sprintf("%s must contain only letters and numbers.", s.Key),
|
||||
))
|
||||
return s
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *StringField) AllowedValues(allowed []string) *StringField {
|
||||
if !slices.Contains(allowed, s.Value) && !s.optional {
|
||||
s.getter.AddCheck(newFailedCheck(
|
||||
"Value not allowed",
|
||||
fmt.Sprintf("%s must be one of: %s", s.Key, strings.Join(allowed, ",")),
|
||||
))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ToUpper transforms the string to uppercase
|
||||
func (s *StringField) ToUpper() *StringField {
|
||||
s.Value = strings.ToUpper(s.Value)
|
||||
return s
|
||||
}
|
||||
|
||||
// ToLower transforms the string to lowercase
|
||||
func (s *StringField) ToLower() *StringField {
|
||||
s.Value = strings.ToLower(s.Value)
|
||||
return s
|
||||
}
|
||||
|
||||
// TrimSpace removes leading and trailing whitespace
|
||||
func (s *StringField) TrimSpace() *StringField {
|
||||
s.Value = strings.TrimSpace(s.Value)
|
||||
return s
|
||||
}
|
||||
50
internal/validation/timefield.go
Normal file
50
internal/validation/timefield.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
type TimeField struct {
|
||||
Field
|
||||
Value time.Time
|
||||
}
|
||||
|
||||
func newTimeField(key string, format *timefmt.Format, g Getter) *TimeField {
|
||||
raw := g.Get(key)
|
||||
var startDate time.Time
|
||||
if raw != "" {
|
||||
var err error
|
||||
startDate, err = format.Parse(raw)
|
||||
if err != nil {
|
||||
g.AddCheck(newFailedCheck(
|
||||
"Invalid date/time format",
|
||||
fmt.Sprintf("%s should be in format %s", key, format.LDML()),
|
||||
))
|
||||
}
|
||||
}
|
||||
return &TimeField{
|
||||
Value: startDate,
|
||||
Field: newField(key, g),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TimeField) Required() *TimeField {
|
||||
if t.Value.IsZero() {
|
||||
t.getter.AddCheck(newFailedCheck(
|
||||
"Date/Time not provided",
|
||||
fmt.Sprintf("%s must be provided", t.Key),
|
||||
))
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Optional will skip all validations if value is empty
|
||||
func (t *TimeField) Optional() *TimeField {
|
||||
if t.Value.IsZero() {
|
||||
t.optional = true
|
||||
}
|
||||
return t
|
||||
}
|
||||
93
internal/validation/validation.go
Normal file
93
internal/validation/validation.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Package validation provides utilities for parsing and validating request data
|
||||
package validation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/notify"
|
||||
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||
"git.haelnorr.com/h/timefmt"
|
||||
)
|
||||
|
||||
type ValidationRule struct {
|
||||
Condition bool // Condition on which to fail validation
|
||||
Title string // Title for warning message
|
||||
Message string // Warning message
|
||||
}
|
||||
|
||||
// Getter abstracts getting values from either form or query
|
||||
type Getter interface {
|
||||
Get(key string) string
|
||||
AddCheck(check *ValidationRule)
|
||||
String(key string) *StringField
|
||||
Int(key string) *IntField
|
||||
Time(key string, format *timefmt.Format) *TimeField
|
||||
ValidateChecks() []*ValidationRule
|
||||
Validate() bool
|
||||
ValidateAndNotify(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||
ValidateAndError(s *hws.Server, w http.ResponseWriter, r *http.Request) bool
|
||||
|
||||
getChecks() []*ValidationRule
|
||||
}
|
||||
type Field struct {
|
||||
Key string
|
||||
optional bool
|
||||
getter Getter
|
||||
}
|
||||
|
||||
func newField(key string, g Getter) Field {
|
||||
return Field{
|
||||
Key: key,
|
||||
getter: g,
|
||||
}
|
||||
}
|
||||
|
||||
func newFailedCheck(title, message string) *ValidationRule {
|
||||
return &ValidationRule{
|
||||
Condition: true,
|
||||
Title: title,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func validate(g Getter) []*ValidationRule {
|
||||
failed := []*ValidationRule{}
|
||||
for _, check := range g.getChecks() {
|
||||
if check != nil && check.Condition {
|
||||
failed = append(failed, check)
|
||||
}
|
||||
}
|
||||
return failed
|
||||
}
|
||||
|
||||
func validateAndNotify(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
g Getter,
|
||||
) bool {
|
||||
failedChecks := g.ValidateChecks()
|
||||
for _, check := range failedChecks {
|
||||
notify.Warn(s, w, r, check.Title, check.Message, nil)
|
||||
}
|
||||
return len(failedChecks) == 0
|
||||
}
|
||||
|
||||
func validateAndError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
g Getter,
|
||||
) bool {
|
||||
failedChecks := g.ValidateChecks()
|
||||
var err error
|
||||
for _, check := range failedChecks {
|
||||
err_ := fmt.Errorf("%s: %s", check.Title, check.Message)
|
||||
err = errors.Join(err, err_)
|
||||
}
|
||||
throw.BadRequest(s, w, r, "Invalid form data", err)
|
||||
return len(failedChecks) == 0
|
||||
}
|
||||
Reference in New Issue
Block a user