// Package db contains small generic helpers for building and running
// list-style SELECT queries on top of bun.
package db

import (
	"context"
	"database/sql"

	"github.com/pkg/errors"
	"github.com/uptrace/bun"
)

// listgetter is a fluent wrapper around a bun select query whose model is a
// slice of *T. q is the query being built; items is the slice the query scans
// its rows into.
type listgetter[T any] struct {
	q     *bun.SelectQuery
	items *[]*T
}

// List is the result of a paged query: the scanned items, the row count
// obtained before page options were applied, and the page options that were
// used for the query.
type List[T any] struct {
	Items    []*T
	Total    int
	PageOpts PageOpts
}

// Filter describes a single "Field Comparator Value" condition.
type Filter struct {
	Field      string
	Value      any
	Comparator Comparator
}

// Comparator is the SQL operator used by a Filter.
type Comparator string

// Comparators understood by (*listgetter[T]).Filter. In is handled specially
// there; every other value is written into the SQL text as-is.
const (
	Equal        Comparator = "="
	Less         Comparator = "<"
	LessEqual    Comparator = "<="
	Greater      Comparator = ">"
	GreaterEqual Comparator = ">="
	In           Comparator = "IN"
)

// ListFilter accumulates Filter conditions in the order they are added.
type ListFilter struct {
	filters []Filter
}

// NewListFilter returns a ListFilter with no conditions.
func NewListFilter() *ListFilter {
	return &ListFilter{filters: []Filter{}}
}

// Equals appends a "field = value" condition.
func (f *ListFilter) Equals(field string, value any) {
	f.filters = append(f.filters, Filter{Field: field, Value: value, Comparator: Equal})
}

// LessThan appends a "field < value" condition.
func (f *ListFilter) LessThan(field string, value any) {
	f.filters = append(f.filters, Filter{Field: field, Value: value, Comparator: Less})
}

// LessEqualThan appends a "field <= value" condition.
func (f *ListFilter) LessEqualThan(field string, value any) {
	f.filters = append(f.filters, Filter{Field: field, Value: value, Comparator: LessEqual})
}

// GreaterThan appends a "field > value" condition.
func (f *ListFilter) GreaterThan(field string, value any) {
	f.filters = append(f.filters, Filter{Field: field, Value: value, Comparator: Greater})
}

// GreaterEqualThan appends a "field >= value" condition.
func (f *ListFilter) GreaterEqualThan(field string, value any) {
	f.filters = append(f.filters, Filter{Field: field, Value: value, Comparator: GreaterEqual})
}

// In appends a "field IN (values)" condition. values is handed to bun.In when
// the filter is applied, so it is presumably expected to be a slice.
func (f *ListFilter) In(field string, values any) {
	f.filters = append(f.filters, Filter{Field: field, Value: values, Comparator: In})
}

// GetList starts a select query on tx whose model is a fresh []*T. Chain the
// builder methods on the result and finish with GetPaged, GetAll or Count.
func GetList[T any](tx bun.Tx) *listgetter[T] {
	l := &listgetter[T]{
		items: new([]*T),
	}
	l.q = tx.NewSelect().Model(l.items)
	return l
}

// String returns the string form of the underlying bun query.
func (l *listgetter[T]) String() string {
	return l.q.String()
}

// Join forwards to bun's SelectQuery.Join and returns l for chaining.
func (l *listgetter[T]) Join(join string, args ...any) *listgetter[T] {
	l.q = l.q.Join(join, args...)
	return l
}

// Where forwards to bun's SelectQuery.Where and returns l for chaining.
func (l *listgetter[T]) Where(query string, args ...any) *listgetter[T] {
	l.q = l.q.Where(query, args...)
	return l
}

// Order forwards to bun's SelectQuery.Order and returns l for chaining.
func (l *listgetter[T]) Order(orders ...string) *listgetter[T] {
	l.q = l.q.Order(orders...)
	return l
}

// Relation forwards to bun's SelectQuery.Relation and returns l for chaining.
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
	l.q = l.q.Relation(name, apply...)
	return l
}

// Filter adds one WHERE condition per filter and returns l for chaining.
//
// The field name is passed as a bun.Ident and the value as a query argument.
// The comparator, however, is inserted through bun.Safe, i.e. it becomes part
// of the SQL text unescaped. It must therefore be one of the Comparator
// constants and never originate from untrusted input.
func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
	for _, filter := range filters {
		if filter.Comparator == In {
			l.q = l.q.Where("? IN (?)", bun.Ident(filter.Field), bun.In(filter.Value))
			continue
		}
		l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
	}
	return l
}

// GetPaged counts the rows matched by the query, applies the page options via
// setPageOpts, runs the query and returns the items together with the total
// and the page options setPageOpts produced.
//
// defaults must not be nil. Any Count or Scan failure is returned wrapped;
// sql.ErrNoRows from Scan is not treated as a failure.
func (l *listgetter[T]) GetPaged(ctx context.Context, pageOpts, defaults *PageOpts) (*List[T], error) {
	if defaults == nil {
		return nil, errors.New("default pageopts is nil")
	}

	// Count first so the total is taken before setPageOpts modifies the query.
	total, err := l.q.Count(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "query.Count")
	}

	// NOTE(review): pageOpts is dereferenced below, so setPageOpts is assumed
	// to return a non-nil value even when the caller passed nil — confirm.
	l.q, pageOpts = setPageOpts(l.q, pageOpts, defaults, total)

	// An empty result is fine; every other scan error must reach the caller.
	if err := l.q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, errors.Wrap(err, "query.Scan")
	}

	return &List[T]{
		Items:    *l.items,
		Total:    total,
		PageOpts: *pageOpts,
	}, nil
}

// Count returns the result of calling Count on the underlying query, wrapping
// any error.
func (l *listgetter[T]) Count(ctx context.Context) (int, error) {
	count, err := l.q.Count(ctx)
	if err != nil {
		return 0, errors.Wrap(err, "query.Count")
	}
	return count, nil
}

// GetAll runs the query without paging and returns the scanned items.
// sql.ErrNoRows is not treated as a failure; any other scan error is returned
// wrapped.
func (l *listgetter[T]) GetAll(ctx context.Context) ([]*T, error) {
	// An empty result is fine; every other scan error must reach the caller.
	if err := l.q.Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, errors.Wrap(err, "query.Scan")
	}
	return *l.items, nil
}