70 lines
1.3 KiB
Go
70 lines
1.3 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
type fieldgetter[T any] struct {
|
|
q *bun.SelectQuery
|
|
field string
|
|
value any
|
|
model *T
|
|
}
|
|
|
|
func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
|
|
if g.field == "id" && (g.value).(int) < 1 {
|
|
return nil, errors.New("invalid id")
|
|
}
|
|
err := g.q.
|
|
Where("? = ?", bun.Ident(g.field), g.value).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
|
|
}
|
|
return g.model, nil
|
|
}
|
|
|
|
func (g *fieldgetter[T]) Get(ctx context.Context) (*T, error) {
|
|
g.q = g.q.Limit(1)
|
|
return g.get(ctx)
|
|
}
|
|
|
|
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
|
|
g.q = g.q.Relation(name, apply...)
|
|
return g
|
|
}
|
|
|
|
func (g *fieldgetter[T]) Join(join string, args ...any) *fieldgetter[T] {
|
|
g.q = g.q.Join(join, args...)
|
|
return g
|
|
}
|
|
|
|
// GetByField retrieves a single record by field name
|
|
func GetByField[T any](
|
|
tx bun.Tx,
|
|
field string,
|
|
value any,
|
|
) *fieldgetter[T] {
|
|
model := new(T)
|
|
return &fieldgetter[T]{
|
|
tx.NewSelect().Model(model),
|
|
field,
|
|
value,
|
|
model,
|
|
}
|
|
}
|
|
|
|
func GetByID[T any](
|
|
tx bun.Tx,
|
|
id int,
|
|
) *fieldgetter[T] {
|
|
return GetByField[T](tx, "id", id)
|
|
}
|