Files
oslstats/internal/db/getbyfield.go
2026-02-08 20:52:58 +11:00

84 lines
1.6 KiB
Go

package db
import (
	"context"
	"database/sql"
	"reflect"
	"strings"

	"github.com/pkg/errors"
	"github.com/uptrace/bun"
)
// ListFilter accumulates (field, value) pairs added through Add.
//
// NOTE(review): how these pairs are turned into query conditions is defined
// outside this file/chunk (the Filter type is declared elsewhere) — confirm
// there before relying on a particular matching semantics.
type ListFilter struct {
	filters []Filter // pairs recorded by Add, in insertion order
}
// NewListFilter returns a ListFilter with no entries, ready for Add calls.
// The internal slice is allocated empty (non-nil).
func NewListFilter() *ListFilter {
	lf := &ListFilter{filters: []Filter{}}
	return lf
}
// Add records a (field, value) pair at the end of the filter list.
func (f *ListFilter) Add(field string, value any) {
	entry := Filter{field, value}
	f.filters = append(f.filters, entry)
}
// fieldgetter is a small query builder that looks up records of type T by
// comparing a single column against a value. Instances are created by
// GetByField or GetByID, can be extended with Relation and Join, and are
// executed with GetFirst or GetAll.
type fieldgetter[T any] struct {
	q     *bun.SelectQuery // query being built; Relation, Join and GetFirst reassign it in place
	field string           // column name compared in the WHERE clause
	value any              // value the column is compared against
}
// get filters the query on g.field = g.value and scans the result into a
// freshly allocated T.
//
// It returns (nil, nil) when no row matches (sql.ErrNoRows is swallowed), so
// callers must nil-check the result rather than rely on an error. When the
// field is "id", the value is validated first (see isValidIDValue) and an
// "invalid id" error is returned without querying the database.
//
// Field names without a "." are qualified with the model's table alias
// (bun's ?TableAlias placeholder). Relation/Join can add tables that share
// column names such as "id", and an unqualified reference would then be
// ambiguous. Names that already contain a "." are used verbatim, so callers
// can still target a joined table explicitly.
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
	// The previous unchecked g.value.(int) assertion panicked for nil,
	// int64, named integer types, string keys, etc.
	if g.field == "id" && !isValidIDValue(g.value) {
		return nil, errors.New("invalid id")
	}

	cond := "? = ?"
	if !strings.Contains(g.field, ".") {
		cond = "?TableAlias.? = ?"
	}

	model := new(T)
	err := g.q.
		Where(cond, bun.Ident(g.field), g.value).
		Scan(ctx, model)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
	}
	return model, nil
}

// isValidIDValue reports whether v is acceptable as an "id" filter value.
// Integer values of any width, including named integer types such as
// `type UserID int`, must be >= 1. A nil value is rejected. Non-integer
// values (e.g. string or UUID keys) are not checked here and are left for
// the database to match or reject.
func isValidIDValue(v any) bool {
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Invalid:
		return false
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() > 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() > 0
	default:
		return true
	}
}
// GetFirst adds LIMIT 1 to the query and executes it, returning the matching
// record or (nil, nil) if none matches. The limit is stored on g, so it also
// applies to any later execution through the same getter.
func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) {
	limited := g.q.Limit(1)
	g.q = limited
	return g.get(ctx)
}
// GetAll executes the query without adding a LIMIT clause.
//
// NOTE(review): despite the name, the result is scanned into a single *T,
// so only one record can come back. Returning every row would require a
// slice return type. Confirm callers' expectations before using this for
// multi-row lookups.
func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) {
	result, err := g.get(ctx)
	return result, err
}
// Relation forwards to bun's SelectQuery.Relation to load the named model
// relation (optionally customised by apply), and returns g for chaining.
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
	updated := g.q.Relation(name, apply...)
	g.q = updated
	return g
}
// Join forwards a raw JOIN clause (with optional placeholder args) to bun's
// SelectQuery.Join and returns g for chaining.
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
	updated := g.q.Join(join, args...)
	g.q = updated
	return g
}
// GetByField returns a query builder for records of type T whose column
// field equals value, scoped to transaction tx.
//
// No query is executed here. Chain Relation/Join as needed, then call
// GetFirst or GetAll to run it. If field is "id", the value is validated
// when the query runs.
func GetByField[T any](
	tx bun.Tx,
	field string,
	value any,
) *fieldgetter[T] {
	// Keyed fields keep this literal correct even if fieldgetter's field
	// order changes.
	return &fieldgetter[T]{
		q:     tx.NewSelect().Model((*T)(nil)),
		field: field,
		value: value,
	}
}
// GetByID returns a query builder for the record of type T whose "id"
// column equals id. It is shorthand for GetByField[T](tx, "id", id); ids
// below 1 are rejected with an "invalid id" error when the query runs.
func GetByID[T any](tx bun.Tx, id int) *fieldgetter[T] {
	const idColumn = "id"
	return GetByField[T](tx, idColumn, id)
}