151 lines
3.3 KiB
Go
151 lines
3.3 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
type listgetter[T any] struct {
|
|
q *bun.SelectQuery
|
|
items *[]*T
|
|
}
|
|
|
|
type List[T any] struct {
|
|
Items []*T
|
|
Total int
|
|
PageOpts PageOpts
|
|
}
|
|
|
|
type Filter struct {
|
|
Field string
|
|
Value any
|
|
Comparator Comparator
|
|
}
|
|
|
|
// Comparator is the SQL operator text used by a Filter.
type Comparator string
|
|
|
|
const (
|
|
Equal Comparator = "="
|
|
Less Comparator = "<"
|
|
LessEqual Comparator = "<="
|
|
Greater Comparator = ">"
|
|
GreaterEqual Comparator = ">="
|
|
In Comparator = "IN"
|
|
)
|
|
|
|
type ListFilter struct {
|
|
filters []Filter
|
|
}
|
|
|
|
func NewListFilter() *ListFilter {
|
|
return &ListFilter{[]Filter{}}
|
|
}
|
|
|
|
func (f *ListFilter) Equals(field string, value any) {
|
|
f.filters = append(f.filters, Filter{field, value, Equal})
|
|
}
|
|
|
|
func (f *ListFilter) LessThan(field string, value any) {
|
|
f.filters = append(f.filters, Filter{field, value, Less})
|
|
}
|
|
|
|
func (f *ListFilter) LessEqualThan(field string, value any) {
|
|
f.filters = append(f.filters, Filter{field, value, LessEqual})
|
|
}
|
|
|
|
func (f *ListFilter) GreaterThan(field string, value any) {
|
|
f.filters = append(f.filters, Filter{field, value, Greater})
|
|
}
|
|
|
|
func (f *ListFilter) GreaterEqualThan(field string, value any) {
|
|
f.filters = append(f.filters, Filter{field, value, GreaterEqual})
|
|
}
|
|
|
|
func (f *ListFilter) In(field string, values any) {
|
|
f.filters = append(f.filters, Filter{field, values, In})
|
|
}
|
|
|
|
func GetList[T any](tx bun.Tx) *listgetter[T] {
|
|
l := &listgetter[T]{
|
|
items: new([]*T),
|
|
}
|
|
l.q = tx.NewSelect().
|
|
Model(l.items)
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) String() string {
|
|
return l.q.String()
|
|
}
|
|
|
|
func (l *listgetter[T]) Join(join string, args ...any) *listgetter[T] {
|
|
l.q = l.q.Join(join, args...)
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) Where(query string, args ...any) *listgetter[T] {
|
|
l.q = l.q.Where(query, args...)
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) Order(orders ...string) *listgetter[T] {
|
|
l.q = l.q.Order(orders...)
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *listgetter[T] {
|
|
l.q = l.q.Relation(name, apply...)
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
|
|
for _, filter := range filters {
|
|
if filter.Comparator == In {
|
|
l.q = l.q.Where("? IN (?)", bun.Ident(filter.Field), bun.In(filter.Value))
|
|
} else {
|
|
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
|
|
}
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (l *listgetter[T]) GetPaged(ctx context.Context, pageOpts, defaults *PageOpts) (*List[T], error) {
|
|
if defaults == nil {
|
|
return nil, errors.New("default pageopts is nil")
|
|
}
|
|
total, err := l.q.Count(ctx)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "query.Count")
|
|
}
|
|
l.q, pageOpts = setPageOpts(l.q, pageOpts, defaults, total)
|
|
err = l.q.Scan(ctx)
|
|
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
|
return nil, errors.Wrap(err, "query.Scan")
|
|
}
|
|
list := &List[T]{
|
|
Items: *l.items,
|
|
Total: total,
|
|
PageOpts: *pageOpts,
|
|
}
|
|
return list, nil
|
|
}
|
|
|
|
func (l *listgetter[T]) Count(ctx context.Context) (int, error) {
|
|
count, err := l.q.Count(ctx)
|
|
if err != nil {
|
|
return 0, errors.Wrap(err, "query.Count")
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (l *listgetter[T]) GetAll(ctx context.Context) ([]*T, error) {
|
|
err := l.q.Scan(ctx)
|
|
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
|
return nil, errors.Wrap(err, "query.Scan")
|
|
}
|
|
return *l.items, nil
|
|
}
|