more refactors :)
This commit is contained in:
@@ -31,7 +31,7 @@ func (d *deleter[T]) Where(query string, args ...any) *deleter[T] {
|
|||||||
func (d *deleter[T]) Delete(ctx context.Context) error {
|
func (d *deleter[T]) Delete(ctx context.Context) error {
|
||||||
_, err := d.q.Exec(ctx)
|
_, err := d.q.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -90,10 +90,12 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e
|
|||||||
|
|
||||||
// Convert reverts the token back into a *discord.Token
|
// Convert reverts the token back into a *discord.Token
|
||||||
func (t *DiscordToken) Convert() *discord.Token {
|
func (t *DiscordToken) Convert() *discord.Token {
|
||||||
|
expiresIn := t.ExpiresAt - time.Now().Unix()
|
||||||
|
expiresIn = max(expiresIn, 0)
|
||||||
token := &discord.Token{
|
token := &discord.Token{
|
||||||
AccessToken: t.AccessToken,
|
AccessToken: t.AccessToken,
|
||||||
RefreshToken: t.RefreshToken,
|
RefreshToken: t.RefreshToken,
|
||||||
ExpiresIn: int(t.ExpiresAt - time.Now().Unix()),
|
ExpiresIn: int(expiresIn),
|
||||||
Scope: t.Scope,
|
Scope: t.Scope,
|
||||||
TokenType: t.TokenType,
|
TokenType: t.TokenType,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) {
|
|||||||
Where("? = ?", bun.Ident(g.field), g.value).
|
Where("? = ?", bun.Ident(g.field), g.value).
|
||||||
Scan(ctx, model)
|
Scan(ctx, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
|
return nil, errors.Wrap(err, "bun.SelectQuery.Scan")
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) {
|
|||||||
}
|
}
|
||||||
l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total)
|
l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total)
|
||||||
err = l.q.Scan(ctx)
|
err = l.q.Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, errors.Wrap(err, "query.Scan")
|
return nil, errors.Wrap(err, "query.Scan")
|
||||||
}
|
}
|
||||||
list := &List[T]{
|
list := &List[T]{
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.Selec
|
|||||||
p.PerPage = d.PerPage
|
p.PerPage = d.PerPage
|
||||||
}
|
}
|
||||||
maxpage := p.TotalPages(totalitems)
|
maxpage := p.TotalPages(totalitems)
|
||||||
if p.Page > maxpage {
|
if p.Page > maxpage && maxpage > 0 {
|
||||||
p.Page = maxpage
|
p.Page = maxpage
|
||||||
}
|
}
|
||||||
if p.Order == "" {
|
if p.Order == "" {
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -64,7 +64,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -78,13 +78,13 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) (
|
|||||||
return nil, errors.New("resource cannot be empty")
|
return nil, errors.New("resource cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
var perms []*Permission
|
perms := []*Permission{}
|
||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(&perms).
|
Model(&perms).
|
||||||
Where("resource = ?", resource).
|
Where("resource = ?", resource).
|
||||||
Order("action ASC").
|
Order("action ASC").
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return perms, nil
|
return perms, nil
|
||||||
@@ -101,7 +101,7 @@ func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permissi
|
|||||||
Model(&perms).
|
Model(&perms).
|
||||||
Where("id IN (?)", bun.In(ids)).
|
Where("id IN (?)", bun.In(ids)).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return perms, nil
|
return perms, nil
|
||||||
@@ -114,7 +114,7 @@ func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) {
|
|||||||
Model(&perms).
|
Model(&perms).
|
||||||
Order("resource ASC", "action ASC").
|
Order("resource ASC", "action ASC").
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return perms, nil
|
return perms, nil
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func WithTxFailSilently(
|
|||||||
) error {
|
) error {
|
||||||
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
fnc := func(ctx context.Context, tx bun.Tx) (bool, error) {
|
||||||
err := fn(ctx, tx)
|
err := fn(ctx, tx)
|
||||||
return err != nil, err
|
return err == nil, err
|
||||||
}
|
}
|
||||||
_, err := withTx(ctx, conn, fnc, true)
|
_, err := withTx(ctx, conn, fnc, true)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
|
|||||||
Where("u.id = ?", userID).
|
Where("u.id = ?", userID).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
return false, errors.Wrap(err, "tx.NewSelect")
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
"git.haelnorr.com/h/oslstats/internal/permissions"
|
"git.haelnorr.com/h/oslstats/internal/permissions"
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/throw"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequirePermission creates middleware that requires a specific permission
|
// RequirePermission creates middleware that requires a specific permission
|
||||||
func (c *Checker) RequirePermission(server *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler {
|
func (c *Checker) RequirePermission(s *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := db.CurrentUser(r.Context())
|
user := db.CurrentUser(r.Context())
|
||||||
@@ -24,26 +26,12 @@ func (c *Checker) RequirePermission(server *hws.Server, permission permissions.P
|
|||||||
|
|
||||||
has, err := c.UserHasPermission(r.Context(), user, permission)
|
has, err := c.UserHasPermission(r.Context(), user, permission)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log error and return 500
|
throw.InternalServiceError(s, w, r, "Permission check failed", errors.Wrap(err, "c.UserHasPermission"))
|
||||||
server.ThrowError(w, r, hws.HWSError{
|
|
||||||
StatusCode: http.StatusInternalServerError,
|
|
||||||
Message: "Permission check failed",
|
|
||||||
Error: err,
|
|
||||||
Level: hws.ErrorERROR,
|
|
||||||
RenderErrorPage: true,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has {
|
if !has {
|
||||||
// User lacks permission - return 403
|
throw.Forbidden(s, w, r, "You don't have permission to access this resource", errors.New("invalid permissions"))
|
||||||
server.ThrowError(w, r, hws.HWSError{
|
|
||||||
StatusCode: http.StatusForbidden,
|
|
||||||
Message: "You don't have permission to access this resource",
|
|
||||||
Error: nil,
|
|
||||||
Level: hws.ErrorDEBUG,
|
|
||||||
RenderErrorPage: true,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +41,7 @@ func (c *Checker) RequirePermission(server *hws.Server, permission permissions.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequireRole creates middleware that requires a specific role
|
// RequireRole creates middleware that requires a specific role
|
||||||
func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Handler) http.Handler {
|
func (c *Checker) RequireRole(s *hws.Server, role roles.Role) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := db.CurrentUser(r.Context())
|
user := db.CurrentUser(r.Context())
|
||||||
@@ -66,27 +54,12 @@ func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Han
|
|||||||
|
|
||||||
has, err := c.UserHasRole(r.Context(), user, role)
|
has, err := c.UserHasRole(r.Context(), user, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log error and return 500
|
throw.InternalServiceError(s, w, r, "Role check failed", errors.Wrap(err, "c.UserHasRole"))
|
||||||
hwserr := hws.HWSError{
|
|
||||||
StatusCode: http.StatusInternalServerError,
|
|
||||||
Message: "Role check failed",
|
|
||||||
Error: err,
|
|
||||||
Level: hws.ErrorERROR,
|
|
||||||
RenderErrorPage: true,
|
|
||||||
}
|
|
||||||
server.ThrowError(w, r, hwserr)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has {
|
if !has {
|
||||||
// User lacks role - return 403
|
throw.Forbidden(s, w, r, "You don't have the required role to access this resource", errors.New("missing role"))
|
||||||
server.ThrowError(w, r, hws.HWSError{
|
|
||||||
StatusCode: http.StatusForbidden,
|
|
||||||
Message: "You don't have the required role to access this resource",
|
|
||||||
Error: nil,
|
|
||||||
Level: hws.ErrorDEBUG,
|
|
||||||
RenderErrorPage: true,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ func (s *Store) CheckToken(user *db.User) *db.DiscordToken {
|
|||||||
check := (res).(*TokenCheck)
|
check := (res).(*TokenCheck)
|
||||||
check.ExpiresAt = time.Now().Add(5 * time.Minute)
|
check.ExpiresAt = time.Now().Add(5 * time.Minute)
|
||||||
s.tokenchecks.Delete(user.ID)
|
s.tokenchecks.Delete(user.ID)
|
||||||
s.tokenchecks.Store(user.ID, check)
|
s.tokenchecks.Store(user.ID, &TokenCheck{
|
||||||
|
Token: check.Token,
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
return check.Token
|
return check.Token
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user