From fa3b8e3982bbe9b7be4843d586e8dfd0567f751b Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Sun, 8 Feb 2026 20:52:58 +1100 Subject: [PATCH] more refactors :) --- internal/db/delete.go | 2 +- internal/db/discordtokens.go | 6 ++-- internal/db/getbyfield.go | 2 +- internal/db/getlist.go | 2 +- internal/db/paginate.go | 2 +- internal/db/permission.go | 12 +++---- internal/db/txhelpers.go | 2 +- internal/db/userrole.go | 2 +- internal/rbac/protection_middleware.go | 43 +++++--------------------- internal/store/tokencheck.go | 5 ++- 10 files changed, 28 insertions(+), 50 deletions(-) diff --git a/internal/db/delete.go b/internal/db/delete.go index 2b6189d..2f3c826 100644 --- a/internal/db/delete.go +++ b/internal/db/delete.go @@ -31,7 +31,7 @@ func (d *deleter[T]) Where(query string, args ...any) *deleter[T] { func (d *deleter[T]) Delete(ctx context.Context) error { _, err := d.q.Exec(ctx) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil } } diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index 6c94f67..c0cd596 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -80,7 +80,7 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e Limit(1). Scan(ctx) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, errors.Wrap(err, "tx.NewSelect") @@ -90,10 +90,12 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e // Convert reverts the token back into a *discord.Token func (t *DiscordToken) Convert() *discord.Token { + expiresIn := t.ExpiresAt - time.Now().Unix() + expiresIn = max(expiresIn, 0) token := &discord.Token{ AccessToken: t.AccessToken, RefreshToken: t.RefreshToken, - ExpiresIn: int(t.ExpiresAt - time.Now().Unix()), + ExpiresIn: int(expiresIn), Scope: t.Scope, TokenType: t.TokenType, } diff --git a/internal/db/getbyfield.go b/internal/db/getbyfield.go index 39dd408..70767e2 100644 --- a/internal/db/getbyfield.go +++ b/internal/db/getbyfield.go @@ -35,7 +35,7 @@ func (g *fieldgetter[T]) get(ctx context.Context) (*T, error) { Where("? = ?", bun.Ident(g.field), g.value). Scan(ctx, model) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, errors.Wrap(err, "bun.SelectQuery.Scan") diff --git a/internal/db/getlist.go b/internal/db/getlist.go index 434db0b..e309c34 100644 --- a/internal/db/getlist.go +++ b/internal/db/getlist.go @@ -59,7 +59,7 @@ func (l *listgetter[T]) GetAll(ctx context.Context) (*List[T], error) { } l.q, l.pageOpts = setPageOpts(l.q, l.pageOpts, l.defaults, total) err = l.q.Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "query.Scan") } list := &List[T]{ diff --git a/internal/db/paginate.go b/internal/db/paginate.go index fbd422a..bae2952 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -30,7 +30,7 @@ func setPageOpts(q *bun.SelectQuery, p, d *PageOpts, totalitems int) (*bun.Selec p.PerPage = d.PerPage } maxpage := p.TotalPages(totalitems) - if p.Page > maxpage { + if p.Page > maxpage && maxpage > 0 { p.Page = maxpage } if p.Order == "" { diff --git a/internal/db/permission.go b/internal/db/permission.go index a95904e..15c1697 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -42,7 +42,7 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis Limit(1). Scan(ctx) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, errors.Wrap(err, "tx.NewSelect") @@ -64,7 +64,7 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err Limit(1). Scan(ctx) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, errors.Wrap(err, "tx.NewSelect") @@ -78,13 +78,13 @@ func GetPermissionsByResource(ctx context.Context, tx bun.Tx, resource string) ( return nil, errors.New("resource cannot be empty") } - var perms []*Permission + perms := []*Permission{} err := tx.NewSelect(). Model(&perms). Where("resource = ?", resource). Order("action ASC"). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "tx.NewSelect") } return perms, nil @@ -101,7 +101,7 @@ func GetPermissionsByIDs(ctx context.Context, tx bun.Tx, ids []int) ([]*Permissi Model(&perms). Where("id IN (?)", bun.In(ids)). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "tx.NewSelect") } return perms, nil @@ -114,7 +114,7 @@ func ListAllPermissions(ctx context.Context, tx bun.Tx) ([]*Permission, error) { Model(&perms). Order("resource ASC", "action ASC"). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "tx.NewSelect") } return perms, nil diff --git a/internal/db/txhelpers.go b/internal/db/txhelpers.go index 3d9a99f..ab3d3ab 100644 --- a/internal/db/txhelpers.go +++ b/internal/db/txhelpers.go @@ -48,7 +48,7 @@ func WithTxFailSilently( ) error { fnc := func(ctx context.Context, tx bun.Tx) (bool, error) { err := fn(ctx, tx) - return err != nil, err + return err == nil, err } _, err := withTx(ctx, conn, fnc, true) return err diff --git a/internal/db/userrole.go b/internal/db/userrole.go index 27bdc58..984e1b2 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -76,7 +76,7 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b Where("u.id = ?", userID). Scan(ctx) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return false, nil } return false, errors.Wrap(err, "tx.NewSelect") diff --git a/internal/rbac/protection_middleware.go b/internal/rbac/protection_middleware.go index b6eaa88..3280985 100644 --- a/internal/rbac/protection_middleware.go +++ b/internal/rbac/protection_middleware.go @@ -8,10 +8,12 @@ import ( "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/permissions" "git.haelnorr.com/h/oslstats/internal/roles" + "git.haelnorr.com/h/oslstats/internal/throw" + "github.com/pkg/errors" ) // RequirePermission creates middleware that requires a specific permission -func (c *Checker) RequirePermission(server *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler { +func (c *Checker) RequirePermission(s *hws.Server, permission permissions.Permission) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := db.CurrentUser(r.Context()) @@ -24,26 +26,12 @@ func (c *Checker) RequirePermission(server *hws.Server, permission permissions.P has, err := c.UserHasPermission(r.Context(), user, permission) if err != nil { - // Log error and return 500 - server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "Permission check failed", - Error: err, - Level: hws.ErrorERROR, - RenderErrorPage: true, - }) + throw.InternalServiceError(s, w, r, "Permission check failed", errors.Wrap(err, "c.UserHasPermission")) return } if !has { - // User lacks permission - return 403 - server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusForbidden, - Message: "You don't have permission to access this resource", - Error: nil, - Level: hws.ErrorDEBUG, - RenderErrorPage: true, - }) + throw.Forbidden(s, w, r, "You don't have permission to access this resource", errors.New("invalid permissions")) return } @@ -53,7 +41,7 @@ func (c *Checker) RequirePermission(server *hws.Server, permission permissions.P } // RequireRole creates middleware that requires a specific role -func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Handler) http.Handler { +func (c *Checker) RequireRole(s *hws.Server, role roles.Role) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := db.CurrentUser(r.Context()) @@ -66,27 +54,12 @@ func (c *Checker) RequireRole(server *hws.Server, role roles.Role) func(http.Han has, err := c.UserHasRole(r.Context(), user, role) if err != nil { - // Log error and return 500 - hwserr := hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "Role check failed", - Error: err, - Level: hws.ErrorERROR, - RenderErrorPage: true, - } - server.ThrowError(w, r, hwserr) + throw.InternalServiceError(s, w, r, "Role check failed", errors.Wrap(err, "c.UserHasRole")) return } if !has { - // User lacks role - return 403 - server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusForbidden, - Message: "You don't have the required role to access this resource", - Error: nil, - Level: hws.ErrorDEBUG, - RenderErrorPage: true, - }) + throw.Forbidden(s, w, r, "You don't have the required role to access this resource", errors.New("missing role")) return } diff --git a/internal/store/tokencheck.go b/internal/store/tokencheck.go index b25de79..0eda4ab 100644 --- a/internal/store/tokencheck.go +++ b/internal/store/tokencheck.go @@ -34,6 +34,9 @@ func (s *Store) CheckToken(user *db.User) *db.DiscordToken { check := (res).(*TokenCheck) check.ExpiresAt = time.Now().Add(5 * time.Minute) s.tokenchecks.Delete(user.ID) - s.tokenchecks.Store(user.ID, check) + s.tokenchecks.Store(user.ID, &TokenCheck{ + Token: check.Token, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) return check.Token }