diff --git a/internal/db/auditlog.go b/internal/db/auditlog.go index 2fe7a93..16b9bc3 100644 --- a/internal/db/auditlog.go +++ b/internal/db/auditlog.go @@ -59,11 +59,9 @@ type AuditLogFilters struct { // GetAuditLogs retrieves audit logs with optional filters and pagination func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) { - pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at") query := tx.NewSelect(). Model((*AuditLog)(nil)). - Relation("User"). - OrderBy(pageOpts.OrderBy, pageOpts.Order) + Relation("User") // Apply filters if provided if filters != nil { @@ -88,11 +86,9 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *A } // Get paginated results + query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at") logs := new([]*AuditLog) - err = query. - Offset(pageOpts.PerPage*(pageOpts.Page-1)). - Limit(pageOpts.PerPage). - Scan(ctx, &logs) + err = query.Scan(ctx, &logs) if err != nil && err != sql.ErrNoRows { return nil, errors.Wrap(err, "query.Scan") } diff --git a/internal/db/discordtokens.go b/internal/db/discordtokens.go index bf70d5b..6c94f67 100644 --- a/internal/db/discordtokens.go +++ b/internal/db/discordtokens.go @@ -2,6 +2,7 @@ package db import ( "context" + "database/sql" "time" "git.haelnorr.com/h/oslstats/internal/discord" @@ -57,6 +58,9 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke if err != nil { return nil, errors.Wrap(err, "user.GetDiscordToken") } + if token == nil { + return nil, nil + } _, err = tx.NewDelete(). Model((*DiscordToken)(nil)). Where("discord_id = ?", u.DiscordID). @@ -76,6 +80,9 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e Limit(1). Scan(ctx) if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return token, nil diff --git a/internal/db/paginate.go b/internal/db/paginate.go index 27cc3a8..706065e 100644 --- a/internal/db/paginate.go +++ b/internal/db/paginate.go @@ -1,6 +1,10 @@ package db -import "github.com/uptrace/bun" +import ( + "strings" + + "github.com/uptrace/bun" +) type PageOpts struct { Page int @@ -15,7 +19,7 @@ type OrderOpts struct { Label string } -func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts { +func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) { if p == nil { p = new(PageOpts) } @@ -31,7 +35,46 @@ func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby if p.OrderBy == "" { p.OrderBy = orderby } - return p + p.OrderBy = sanitiseOrderBy(p.OrderBy) + q = q.OrderBy(p.OrderBy, p.Order). + Limit(p.PerPage). + Offset(p.PerPage * (p.Page - 1)) + return q, p +} + +func sanitiseOrderBy(orderby string) string { + result := strings.ToLower(orderby) + var builder strings.Builder + for _, r := range result { + if isValidChar(r) { + builder.WriteRune(r) + } + } + sanitized := builder.String() + + if sanitized == "" { + return "_" + } + + if !isValidFirstChar(rune(sanitized[0])) { + sanitized = "_" + sanitized + } + + if len(sanitized) > 63 { + sanitized = sanitized[:63] + } + + return sanitized +} + +func isValidChar(r rune) bool { + return (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '_' +} + +func isValidFirstChar(r rune) bool { + return (r >= 'a' && r <= 'z') || r == '_' } // TotalPages calculates the total number of pages diff --git a/internal/db/permission.go b/internal/db/permission.go index 8443127..e36f05d 100644 --- a/internal/db/permission.go +++ b/internal/db/permission.go @@ -37,7 +37,10 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis Where("name = ?", name). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return perm, nil @@ -56,7 +59,10 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err Where("id = ?", id). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return perm, nil @@ -115,9 +121,22 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error { if perm == nil { return errors.New("permission cannot be nil") } + if perm.Name == "" { + return errors.New("name cannot be empty") + } + if perm.DisplayName == "" { + return errors.New("display name cannot be empty") + } + if perm.Resource == "" { + return errors.New("resource cannot be empty") + } + if perm.Action == "" { + return errors.New("action cannot be empty") + } _, err := tx.NewInsert(). Model(perm). + Returning("id"). Exec(ctx) if err != nil { return errors.Wrap(err, "tx.NewInsert") diff --git a/internal/db/role.go b/internal/db/role.go index ed60194..181e573 100644 --- a/internal/db/role.go +++ b/internal/db/role.go @@ -46,7 +46,10 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro Where("name = ?", name). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return role, nil @@ -65,7 +68,10 @@ func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) { Where("id = ?", id). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return role, nil @@ -84,7 +90,10 @@ func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, erro Relation("Permissions"). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return role, nil @@ -112,6 +121,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error { _, err := tx.NewInsert(). Model(role). + Returning("id"). Exec(ctx) if err != nil { return errors.Wrap(err, "tx.NewInsert") diff --git a/internal/db/season.go b/internal/db/season.go index 69dd267..f08c569 100644 --- a/internal/db/season.go +++ b/internal/db/season.go @@ -42,6 +42,7 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim } _, err := tx.NewInsert(). Model(season). + Returning("id"). Exec(ctx) if err != nil { return nil, errors.Wrap(err, "tx.NewInsert") @@ -50,22 +51,18 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim } func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { - pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date") seasons := new([]*Season) - err := tx.NewSelect(). - Model(seasons). - OrderBy(pageOpts.OrderBy, pageOpts.Order). - Offset(pageOpts.PerPage * (pageOpts.Page - 1)). - Limit(pageOpts.PerPage). - Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "tx.NewSelect") - } - total, err := tx.NewSelect(). - Model(seasons). - Count(ctx) + query := tx.NewSelect(). + Model(seasons) + + total, err := query.Count(ctx) if err != nil { - return nil, errors.Wrap(err, "tx.NewSelect") + return nil, errors.Wrap(err, "query.Count") + } + query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date") + err = query.Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "query.Scan") } sl := &SeasonList{ Seasons: *seasons, @@ -82,7 +79,10 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error Where("short_name = ?", strings.ToUpper(shortname)). Limit(1). Scan(ctx) - if err != nil && err != sql.ErrNoRows { + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, errors.Wrap(err, "tx.NewSelect") } return season, nil diff --git a/internal/db/user.go b/internal/db/user.go index 093af91..c8a1457 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -3,7 +3,6 @@ package db import ( "context" "database/sql" - "fmt" "time" "git.haelnorr.com/h/golib/hwsauth" @@ -50,6 +49,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di _, err := tx.NewInsert(). Model(user). + Returning("id"). Exec(ctx) if err != nil { return nil, errors.Wrap(err, "tx.NewInsert") @@ -61,7 +61,6 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { - fmt.Printf("user id requested: %v", id) user := new(User) err := tx.NewSelect(). Model(user). @@ -69,7 +68,7 @@ func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { Limit(1). Scan(ctx) if err != nil { - if err.Error() == "sql: no rows in result set" { + if err == sql.ErrNoRows { return nil, nil } return nil, errors.Wrap(err, "tx.NewSelect") @@ -87,10 +86,10 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, Limit(1). Scan(ctx) if err != nil { - if err.Error() == "sql: no rows in result set" { + if err == sql.ErrNoRows { return nil, nil } - return nil, errors.Wrap(err, "tx.Select") + return nil, errors.Wrap(err, "tx.NewSelect") } return user, nil } @@ -105,7 +104,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User Limit(1). Scan(ctx) if err != nil { - if err.Error() == "sql: no rows in result set" { + if err == sql.ErrNoRows { return nil, nil } return nil, errors.Wrap(err, "tx.NewSelect") @@ -201,22 +200,17 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) { } func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { - pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id") users := new([]*User) - err := tx.NewSelect(). - Model(users). - OrderBy(pageOpts.OrderBy, pageOpts.Order). - Limit(pageOpts.PerPage). - Offset(pageOpts.PerPage * (pageOpts.Page - 1)). - Scan(ctx) - if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "tx.NewSelect") - } - total, err := tx.NewSelect(). - Model(users). - Count(ctx) + query := tx.NewSelect(). + Model(users) + total, err := query.Count(ctx) if err != nil { - return nil, errors.Wrap(err, "tx.NewSelect") + return nil, errors.Wrap(err, "query.Count") + } + query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id") + err = query.Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, errors.Wrap(err, "query.Scan") } list := &Users{ Users: *users, diff --git a/internal/db/userrole.go b/internal/db/userrole.go index da538c8..27bdc58 100644 --- a/internal/db/userrole.go +++ b/internal/db/userrole.go @@ -2,6 +2,7 @@ package db import ( "context" + "database/sql" "git.haelnorr.com/h/oslstats/internal/roles" "github.com/pkg/errors" @@ -72,9 +73,12 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b err := tx.NewSelect(). Model(user). Relation("Roles"). - Where("u.id = ? ", userID). + Where("u.id = ?", userID). Scan(ctx) if err != nil { + if err == sql.ErrNoRows { + return false, nil + } return false, errors.Wrap(err, "tx.NewSelect") } for _, role := range user.Roles { diff --git a/internal/handlers/logout.go b/internal/handlers/logout.go index 02b9a9d..8db5a56 100644 --- a/internal/handlers/logout.go +++ b/internal/handlers/logout.go @@ -14,7 +14,7 @@ import ( ) func Logout( - server *hws.Server, + s *hws.Server, auth *hwsauth.Authenticator[*db.User, bun.Tx], conn *bun.DB, discordAPI *discord.APIClient, @@ -26,10 +26,10 @@ func Logout( tx, err := conn.BeginTx(ctx, nil) if err != nil { - throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) + throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) return } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() user := db.CurrentUser(r.Context()) if user == nil { @@ -39,20 +39,26 @@ func Logout( } token, err := user.DeleteDiscordTokens(ctx, tx) if err != nil { - throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) + throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) return } - err = discordAPI.RevokeToken(token.Convert()) - if err != nil { - throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) - return + if token != nil { + err = discordAPI.RevokeToken(token.Convert()) + if err != nil { + throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) + return + } } err = auth.Logout(tx, w, r) if err != nil { - throwInternalServiceError(server, w, r, "Logout failed", err) + throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout")) + return + } + err = tx.Commit() + if err != nil { + throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit")) return } - tx.Commit() w.Header().Set("HX-Redirect", "/") }, )