package db

import (
	"context"
	"database/sql"
	"reflect"

	"github.com/pkg/errors"
	"github.com/uptrace/bun"
)

// ListFilter accumulates field/value pairs as Filter values. Filter is
// declared elsewhere in this package; how the collected filters are applied
// to a query is not shown here.
type ListFilter struct {
	filters []Filter
}

// NewListFilter returns a ListFilter with an empty, non-nil filter list.
func NewListFilter() *ListFilter {
	return &ListFilter{filters: []Filter{}}
}

// Add appends a filter matching field against value.
func (f *ListFilter) Add(field string, value any) {
	f.filters = append(f.filters, Filter{field, value})
}

// fieldgetter is a small query builder that selects rows of model T where a
// single column equals a value. Create one with GetByField or GetByID,
// optionally refine it with Relation/Join, then execute it with GetFirst or
// GetAll. The builder mutates its query in place and is not meant to be reused.
type fieldgetter[T any] struct {
	q     *bun.SelectQuery
	field string
	value any
}

// get adds a "<field> = <value>" condition to the query and scans the result
// into a new T.
//
// It returns (nil, nil) when no row matches. When field is "id" and value is
// not an acceptable ID (see isValidID), it returns an "invalid id" error
// without querying the database. Other failures are wrapped with the failing
// call's name.
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
	if g.field == "id" && !isValidID(g.value) {
		return nil, errors.New("invalid id")
	}

	model := new(T)
	err := g.q.
		Where("? = ?", bun.Ident(g.field), g.value).
		Scan(ctx, model)
	if err != nil {
		// errors.Is still matches if a lower layer wrapped sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
	}
	return model, nil
}

// isValidID reports whether v may be used as an "id" lookup value.
//
// Integers of any width, signed or unsigned and including named integer
// types, must be at least 1. A nil value is rejected. Values of any other
// type (for example string IDs) are passed through for the database to match.
// Reflection is used so that non-int IDs are validated instead of panicking
// on a failed type assertion.
func isValidID(v any) bool {
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Invalid: // v is nil
		return false
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() >= 1
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() >= 1
	default:
		return true
	}
}

// GetFirst limits the query to one row and returns it, or (nil, nil) when no
// row matches. The limit is applied to the builder's query in place.
func (g *fieldgetter[T]) GetFirst(ctx context.Context) (*T, error) {
	g.q = g.q.Limit(1)
	return g.get(ctx)
}

// GetAll runs the query without a row limit and returns the scanned model,
// or (nil, nil) when no row matches.
//
// NOTE(review): despite its name, this scans into a single *T and so yields at
// most one record. A slice-returning variant is probably what callers expect;
// confirm before relying on it for multi-row results.
func (g *fieldgetter[T]) GetAll(ctx context.Context) (*T, error) {
	return g.get(ctx)
}

// Relation adds the named bun relation to the query, forwarding any apply
// callbacks, and returns the builder for chaining.
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
	g.q = g.q.Relation(name, apply...)
	return g
}

// Join adds a JOIN clause (with optional placeholder args) to the query and
// returns the builder for chaining.
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
	g.q = g.q.Join(join, args...)
	return g
}

// GetByField returns a query builder, running inside tx, that selects T rows
// whose column field equals value. No query is executed until GetFirst or
// GetAll is called on the result.
func GetByField[T any](tx bun.Tx, field string, value any) *fieldgetter[T] {
	return &fieldgetter[T]{
		q:     tx.NewSelect().Model((*T)(nil)),
		field: field,
		value: value,
	}
}

// GetByID is GetByField on the "id" column. Executing the returned builder
// with id < 1 fails with an "invalid id" error before any query is issued.
func GetByID[T any](tx bun.Tx, id int) *fieldgetter[T] {
	return GetByField[T](tx, "id", id)
}