fixed db issues

This commit is contained in:
2026-02-05 20:07:37 +11:00
parent 4c31c24069
commit 697bef80e9
9 changed files with 140 additions and 61 deletions

View File

@@ -59,11 +59,9 @@ type AuditLogFilters struct {
// GetAuditLogs retrieves audit logs with optional filters and pagination // GetAuditLogs retrieves audit logs with optional filters and pagination
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) { func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at")
query := tx.NewSelect(). query := tx.NewSelect().
Model((*AuditLog)(nil)). Model((*AuditLog)(nil)).
Relation("User"). Relation("User")
OrderBy(pageOpts.OrderBy, pageOpts.Order)
// Apply filters if provided // Apply filters if provided
if filters != nil { if filters != nil {
@@ -88,11 +86,9 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *A
} }
// Get paginated results // Get paginated results
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at")
logs := new([]*AuditLog) logs := new([]*AuditLog)
err = query. err = query.Scan(ctx, &logs)
Offset(pageOpts.PerPage*(pageOpts.Page-1)).
Limit(pageOpts.PerPage).
Scan(ctx, &logs)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan") return nil, errors.Wrap(err, "query.Scan")
} }

View File

@@ -2,6 +2,7 @@ package db
import ( import (
"context" "context"
"database/sql"
"time" "time"
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
@@ -57,6 +58,9 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke
if err != nil { if err != nil {
return nil, errors.Wrap(err, "user.GetDiscordToken") return nil, errors.Wrap(err, "user.GetDiscordToken")
} }
if token == nil {
return nil, nil
}
_, err = tx.NewDelete(). _, err = tx.NewDelete().
Model((*DiscordToken)(nil)). Model((*DiscordToken)(nil)).
Where("discord_id = ?", u.DiscordID). Where("discord_id = ?", u.DiscordID).
@@ -76,6 +80,9 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return token, nil return token, nil

View File

@@ -1,6 +1,10 @@
package db package db
import "github.com/uptrace/bun" import (
"strings"
"github.com/uptrace/bun"
)
type PageOpts struct { type PageOpts struct {
Page int Page int
@@ -15,7 +19,7 @@ type OrderOpts struct {
Label string Label string
} }
func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts { func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) {
if p == nil { if p == nil {
p = new(PageOpts) p = new(PageOpts)
} }
@@ -31,7 +35,46 @@ func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby
if p.OrderBy == "" { if p.OrderBy == "" {
p.OrderBy = orderby p.OrderBy = orderby
} }
return p p.OrderBy = sanitiseOrderBy(p.OrderBy)
q = q.OrderBy(p.OrderBy, p.Order).
Limit(p.PerPage).
Offset(p.PerPage * (p.Page - 1))
return q, p
}
func sanitiseOrderBy(orderby string) string {
result := strings.ToLower(orderby)
var builder strings.Builder
for _, r := range result {
if isValidChar(r) {
builder.WriteRune(r)
}
}
sanitized := builder.String()
if sanitized == "" {
return "_"
}
if !isValidFirstChar(rune(sanitized[0])) {
sanitized = "_" + sanitized
}
if len(sanitized) > 63 {
sanitized = sanitized[:63]
}
return sanitized
}
func isValidChar(r rune) bool {
return (r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_'
}
func isValidFirstChar(r rune) bool {
return (r >= 'a' && r <= 'z') || r == '_'
} }
// TotalPages calculates the total number of pages // TotalPages calculates the total number of pages

View File

@@ -37,7 +37,10 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
Where("name = ?", name). Where("name = ?", name).
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return perm, nil return perm, nil
@@ -56,7 +59,10 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
Where("id = ?", id). Where("id = ?", id).
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return perm, nil return perm, nil
@@ -115,9 +121,22 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
if perm == nil { if perm == nil {
return errors.New("permission cannot be nil") return errors.New("permission cannot be nil")
} }
if perm.Name == "" {
return errors.New("name cannot be empty")
}
if perm.DisplayName == "" {
return errors.New("display name cannot be empty")
}
if perm.Resource == "" {
return errors.New("resource cannot be empty")
}
if perm.Action == "" {
return errors.New("action cannot be empty")
}
_, err := tx.NewInsert(). _, err := tx.NewInsert().
Model(perm). Model(perm).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "tx.NewInsert")

View File

@@ -46,7 +46,10 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
Where("name = ?", name). Where("name = ?", name).
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return role, nil return role, nil
@@ -65,7 +68,10 @@ func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
Where("id = ?", id). Where("id = ?", id).
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return role, nil return role, nil
@@ -84,7 +90,10 @@ func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, erro
Relation("Permissions"). Relation("Permissions").
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return role, nil return role, nil
@@ -112,6 +121,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
_, err := tx.NewInsert(). _, err := tx.NewInsert().
Model(role). Model(role).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "tx.NewInsert") return errors.Wrap(err, "tx.NewInsert")

View File

@@ -42,6 +42,7 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
} }
_, err := tx.NewInsert(). _, err := tx.NewInsert().
Model(season). Model(season).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewInsert") return nil, errors.Wrap(err, "tx.NewInsert")
@@ -50,22 +51,18 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
} }
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) { func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date")
seasons := new([]*Season) seasons := new([]*Season)
err := tx.NewSelect(). query := tx.NewSelect().
Model(seasons). Model(seasons)
OrderBy(pageOpts.OrderBy, pageOpts.Order).
Offset(pageOpts.PerPage * (pageOpts.Page - 1)). total, err := query.Count(ctx)
Limit(pageOpts.PerPage).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
total, err := tx.NewSelect().
Model(seasons).
Count(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "query.Count")
}
query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date")
err = query.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
} }
sl := &SeasonList{ sl := &SeasonList{
Seasons: *seasons, Seasons: *seasons,
@@ -82,7 +79,10 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
Where("short_name = ?", strings.ToUpper(shortname)). Where("short_name = ?", strings.ToUpper(shortname)).
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil && err != sql.ErrNoRows { if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return season, nil return season, nil

View File

@@ -3,7 +3,6 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
@@ -50,6 +49,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
_, err := tx.NewInsert(). _, err := tx.NewInsert().
Model(user). Model(user).
Returning("id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewInsert") return nil, errors.Wrap(err, "tx.NewInsert")
@@ -61,7 +61,6 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns nil, nil if no user is found // Returns nil, nil if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
fmt.Printf("user id requested: %v", id)
user := new(User) user := new(User)
err := tx.NewSelect(). err := tx.NewSelect().
Model(user). Model(user).
@@ -69,7 +68,7 @@ func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err.Error() == "sql: no rows in result set" { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
@@ -87,10 +86,10 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err.Error() == "sql: no rows in result set" { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, errors.Wrap(err, "tx.Select") return nil, errors.Wrap(err, "tx.NewSelect")
} }
return user, nil return user, nil
} }
@@ -105,7 +104,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
Limit(1). Limit(1).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err.Error() == "sql: no rows in result set" { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "tx.NewSelect")
@@ -201,22 +200,17 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
} }
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id")
users := new([]*User) users := new([]*User)
err := tx.NewSelect(). query := tx.NewSelect().
Model(users). Model(users)
OrderBy(pageOpts.OrderBy, pageOpts.Order). total, err := query.Count(ctx)
Limit(pageOpts.PerPage).
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "tx.NewSelect")
}
total, err := tx.NewSelect().
Model(users).
Count(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.NewSelect") return nil, errors.Wrap(err, "query.Count")
}
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id")
err = query.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "query.Scan")
} }
list := &Users{ list := &Users{
Users: *users, Users: *users,

View File

@@ -2,6 +2,7 @@ package db
import ( import (
"context" "context"
"database/sql"
"git.haelnorr.com/h/oslstats/internal/roles" "git.haelnorr.com/h/oslstats/internal/roles"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -72,9 +73,12 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
err := tx.NewSelect(). err := tx.NewSelect().
Model(user). Model(user).
Relation("Roles"). Relation("Roles").
Where("u.id = ? ", userID). Where("u.id = ?", userID).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, errors.Wrap(err, "tx.NewSelect") return false, errors.Wrap(err, "tx.NewSelect")
} }
for _, role := range user.Roles { for _, role := range user.Roles {

View File

@@ -14,7 +14,7 @@ import (
) )
func Logout( func Logout(
server *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *bun.DB, conn *bun.DB,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
@@ -26,10 +26,10 @@ func Logout(
tx, err := conn.BeginTx(ctx, nil) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx")) throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
return return
} }
defer tx.Rollback() defer func() { _ = tx.Rollback() }()
user := db.CurrentUser(r.Context()) user := db.CurrentUser(r.Context())
if user == nil { if user == nil {
@@ -39,20 +39,26 @@ func Logout(
} }
token, err := user.DeleteDiscordTokens(ctx, tx) token, err := user.DeleteDiscordTokens(ctx, tx)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens")) throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
return return
} }
err = discordAPI.RevokeToken(token.Convert()) if token != nil {
if err != nil { err = discordAPI.RevokeToken(token.Convert())
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken")) if err != nil {
return throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
return
}
} }
err = auth.Logout(tx, w, r) err = auth.Logout(tx, w, r)
if err != nil { if err != nil {
throwInternalServiceError(server, w, r, "Logout failed", err) throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
return
}
err = tx.Commit()
if err != nil {
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit"))
return return
} }
tx.Commit()
w.Header().Set("HX-Redirect", "/") w.Header().Set("HX-Redirect", "/")
}, },
) )