Compare commits

..

2 Commits

Author SHA1 Message Date
9362448f22 added slapapi 2026-02-17 08:12:07 +11:00
f8090aa0cc added players 2026-02-16 21:31:02 +11:00
31 changed files with 595 additions and 85 deletions

1
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/uptrace/bun v1.2.16 github.com/uptrace/bun v1.2.16
github.com/uptrace/bun/dialect/pgdialect v1.2.16 github.com/uptrace/bun/dialect/pgdialect v1.2.16
github.com/uptrace/bun/driver/pgdriver v1.2.16 github.com/uptrace/bun/driver/pgdriver v1.2.16
golang.org/x/time v0.14.0
) )
require ( require (

2
go.sum
View File

@@ -90,6 +90,8 @@ golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -10,6 +10,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/pkg/oauth" "git.haelnorr.com/h/oslstats/pkg/oauth"
"git.haelnorr.com/h/oslstats/pkg/slapshotapi"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -22,6 +23,7 @@ type Config struct {
Discord *discord.Config Discord *discord.Config
OAuth *oauth.Config OAuth *oauth.Config
RBAC *rbac.Config RBAC *rbac.Config
Slapshot *slapshotapi.Config
Flags *Flags Flags *Flags
} }
@@ -42,6 +44,7 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
discord.NewEZConfIntegration(), discord.NewEZConfIntegration(),
oauth.NewEZConfIntegration(), oauth.NewEZConfIntegration(),
rbac.NewEZConfIntegration(), rbac.NewEZConfIntegration(),
slapshotapi.NewEZConfIntegration(),
) )
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "loader.RegisterIntegrations") return nil, nil, errors.Wrap(err, "loader.RegisterIntegrations")
@@ -93,6 +96,11 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
return nil, nil, errors.New("RBAC Config not loaded") return nil, nil, errors.New("RBAC Config not loaded")
} }
slapcfg, ok := loader.GetConfig("slapshotapi")
if !ok {
return nil, nil, errors.New("SlapshotAPI Config not loaded")
}
config := &Config{ config := &Config{
DB: dbcfg.(*db.Config), DB: dbcfg.(*db.Config),
HWS: hwscfg.(*hws.Config), HWS: hwscfg.(*hws.Config),
@@ -101,6 +109,7 @@ func GetConfig(flags *Flags) (*Config, *ezconf.ConfigLoader, error) {
Discord: discordcfg.(*discord.Config), Discord: discordcfg.(*discord.Config),
OAuth: oauthcfg.(*oauth.Config), OAuth: oauthcfg.(*oauth.Config),
RBAC: rbaccfg.(*rbac.Config), RBAC: rbaccfg.(*rbac.Config),
Slapshot: slapcfg.(*slapshotapi.Config),
Flags: flags, Flags: flags,
} }

View File

@@ -7,15 +7,18 @@ import (
) )
type AuditMeta struct { type AuditMeta struct {
r *http.Request ipAddress string
userAgent string
u *User u *User
} }
func NewAudit(r *http.Request, u *User) *AuditMeta { func NewAudit(ipAdd, agent string, user *User) *AuditMeta {
if u == nil { return &AuditMeta{ipAdd, agent, user}
u = CurrentUser(r.Context()) }
}
return &AuditMeta{r, u} func NewAuditFromRequest(r *http.Request) *AuditMeta {
u := CurrentUser(r.Context())
return &AuditMeta{r.RemoteAddr, r.UserAgent(), u}
} }
// AuditInfo contains metadata for audit logging // AuditInfo contains metadata for audit logging
@@ -45,9 +48,44 @@ func extractTableName[T any]() string {
if bunTag != "" { if bunTag != "" {
// Parse tag: "table:seasons,alias:s" -> "seasons" // Parse tag: "table:seasons,alias:s" -> "seasons"
for part := range strings.SplitSeq(bunTag, ",") { for part := range strings.SplitSeq(bunTag, ",") {
part, _ := strings.CutPrefix(part, "table:") part, match := strings.CutPrefix(part, "table:")
if match {
return part return part
} }
return part
}
}
}
}
// Fallback: use struct name in lowercase + "s"
return strings.ToLower(t.Name()) + "s"
}
// extractTableName gets the bun table alias from a model type using reflection
// Example: Season with `bun:"table:seasons,alias:s"` returns "s"
func extractTableAlias[T any]() string {
var model T
t := reflect.TypeOf(model)
// Handle pointer types
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
// Look for bun.BaseModel field with table tag
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Type.Name() == "BaseModel" {
bunTag := field.Tag.Get("bun")
if bunTag != "" {
// Parse tag: "table:seasons,alias:s" -> "seasons"
for part := range strings.SplitSeq(bunTag, ",") {
part, match := strings.CutPrefix(part, "alias:")
if match {
return part
}
}
} }
} }
} }

View File

@@ -49,9 +49,6 @@ func log(
if meta.u == nil { if meta.u == nil {
return errors.New("user cannot be nil for audit logging") return errors.New("user cannot be nil for audit logging")
} }
if meta.r == nil {
return errors.New("request cannot be nil for audit logging")
}
// Convert resourceID to string // Convert resourceID to string
var resourceIDStr *string var resourceIDStr *string
@@ -70,18 +67,14 @@ func log(
detailsJSON = jsonBytes detailsJSON = jsonBytes
} }
// Extract IP and User-Agent from request
ipAddress := meta.r.RemoteAddr
userAgent := meta.r.UserAgent()
log := &AuditLog{ log := &AuditLog{
UserID: meta.u.ID, UserID: meta.u.ID,
Action: info.Action, Action: info.Action,
ResourceType: info.ResourceType, ResourceType: info.ResourceType,
ResourceID: resourceIDStr, ResourceID: resourceIDStr,
Details: detailsJSON, Details: detailsJSON,
IPAddress: ipAddress, IPAddress: meta.ipAddress,
UserAgent: userAgent, UserAgent: meta.userAgent,
Result: result, Result: result,
ErrorMessage: errorMessage, ErrorMessage: errorMessage,
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),

View File

@@ -37,6 +37,10 @@ func (g *fieldgetter[T]) Get(ctx context.Context) (*T, error) {
return g.get(ctx) return g.get(ctx)
} }
func (g *fieldgetter[T]) String() string {
return g.q.String()
}
func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] { func (g *fieldgetter[T]) Relation(name string, apply ...func(*bun.SelectQuery) *bun.SelectQuery) *fieldgetter[T] {
g.q = g.q.Relation(name, apply...) g.q = g.q.Relation(name, apply...)
return g return g
@@ -66,5 +70,6 @@ func GetByID[T any](
tx bun.Tx, tx bun.Tx,
id int, id int,
) *fieldgetter[T] { ) *fieldgetter[T] {
return GetByField[T](tx, "id", id) prefix := extractTableAlias[T]()
return GetByField[T](tx, prefix+".id", id)
} }

View File

@@ -0,0 +1,37 @@
package migrations
import (
"context"
"git.haelnorr.com/h/oslstats/internal/db"
"github.com/uptrace/bun"
)
func init() {
Migrations.MustRegister(
// UP migration
func(ctx context.Context, conn *bun.DB) error {
// Add your migration code here
_, err := conn.NewCreateTable().
Model((*db.Player)(nil)).
IfNotExists().
Exec(ctx)
if err != nil {
return err
}
return nil
},
// DOWN migration
func(ctx context.Context, conn *bun.DB) error {
// Add your rollback code here
_, err := conn.NewDropTable().
Model((*db.Player)(nil)).
IfExists().
Exec(ctx)
if err != nil {
return err
}
return nil
},
)
}

54
internal/db/player.go Normal file
View File

@@ -0,0 +1,54 @@
package db
import (
"context"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
type Player struct {
bun.BaseModel `bun:"table:players,alias:p"`
ID int `bun:"id,pk,autoincrement" json:"id"`
SlapID *string `bun:"slap_id,unique" json:"slap_id"`
DiscordID string `bun:"discord_id,unique,notnull" json:"discord_id"`
UserID *int `bun:"user_id,unique" json:"user_id"`
User *User `bun:"rel:belongs-to,join:user_id=id" json:"-"`
}
func NewPlayer(ctx context.Context, tx bun.Tx, discordID string, audit *AuditMeta) (*Player, error) {
player := &Player{DiscordID: discordID}
user, err := GetUserByDiscordID(ctx, tx, discordID)
if err != nil && !IsBadRequest(err) {
return nil, errors.Wrap(err, "GetUserByDiscordID")
}
if user != nil {
player.UserID = &user.ID
}
err = Insert(tx, player).
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return nil, errors.Wrap(err, "Insert")
}
return player, nil
}
func GetPlayer(ctx context.Context, tx bun.Tx, playerID int) (*Player, error) {
return GetByID[Player](tx, playerID).Relation("User").Get(ctx)
}
func UpdatePlayerSlapID(ctx context.Context, tx bun.Tx, playerID int, slapID string, audit *AuditMeta) error {
player, err := GetPlayer(ctx, tx, playerID)
if err != nil {
return errors.Wrap(err, "GetPlayer")
}
player.SlapID = &slapID
err = UpdateByID(tx, player.ID, player).Column("slap_id").
WithAudit(audit, nil).Exec(ctx)
if err != nil {
return errors.Wrap(err, "UpdateByID")
}
return nil
}

View File

@@ -33,6 +33,7 @@ func (db *DB) RegisterModels() []any {
(*Permission)(nil), (*Permission)(nil),
(*AuditLog)(nil), (*AuditLog)(nil),
(*Fixture)(nil), (*Fixture)(nil),
(*Player)(nil),
} }
db.RegisterModel(models...) db.RegisterModel(models...)
return models return models

View File

@@ -21,6 +21,7 @@ type User struct {
DiscordID string `bun:"discord_id,unique" json:"discord_id"` DiscordID string `bun:"discord_id,unique" json:"discord_id"`
Roles []*Role `bun:"m2m:user_roles,join:User=Role" json:"-"` Roles []*Role `bun:"m2m:user_roles,join:User=Role" json:"-"`
Player *Player `bun:"rel:has-one,join:id=user_id"`
} }
func (u *User) GetID() int { func (u *User) GetID() int {

View File

@@ -1,3 +1,4 @@
// Package discord provides utilities for interacting with the discord API
package discord package discord
import ( import (

View File

@@ -19,19 +19,19 @@ type RateLimitState struct {
// Do executes an HTTP request with automatic rate limit handling // Do executes an HTTP request with automatic rate limit handling
// It will wait if rate limits are about to be exceeded and retry once if a 429 is received // It will wait if rate limits are about to be exceeded and retry once if a 429 is received
func (c *APIClient) Do(req *http.Request) (*http.Response, error) { func (api *APIClient) Do(req *http.Request) (*http.Response, error) {
if req == nil { if req == nil {
return nil, errors.New("request cannot be nil") return nil, errors.New("request cannot be nil")
} }
// Step 1: Check if we need to wait before making request // Step 1: Check if we need to wait before making request
bucket := c.getBucketFromRequest(req) bucket := api.getBucketFromRequest(req)
if err := c.waitIfNeeded(bucket); err != nil { if err := api.waitIfNeeded(bucket); err != nil {
return nil, err return nil, err
} }
// Step 2: Execute request // Step 2: Execute request
resp, err := c.client.Do(req) resp, err := api.client.Do(req)
if err != nil { if err != nil {
// Check if it's a network timeout // Check if it's a network timeout
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
@@ -41,17 +41,17 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
} }
// Step 3: Update rate limit state from response headers // Step 3: Update rate limit state from response headers
c.updateRateLimit(resp.Header) api.updateRateLimit(resp.Header)
// Step 4: Handle 429 (rate limited) // Step 4: Handle 429 (rate limited)
if resp.StatusCode == http.StatusTooManyRequests { if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close() // Close original response resp.Body.Close() // Close original response
retryAfter := c.parseRetryAfter(resp.Header) retryAfter := api.parseRetryAfter(resp.Header)
// No Retry-After header, can't retry safely // No Retry-After header, can't retry safely
if retryAfter == 0 { if retryAfter == 0 {
c.logger.Warn(). api.logger.Warn().
Str("bucket", bucket). Str("bucket", bucket).
Str("method", req.Method). Str("method", req.Method).
Str("path", req.URL.Path). Str("path", req.URL.Path).
@@ -61,7 +61,7 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
// Retry-After exceeds 30 second cap // Retry-After exceeds 30 second cap
if retryAfter > 30*time.Second { if retryAfter > 30*time.Second {
c.logger.Warn(). api.logger.Warn().
Str("bucket", bucket). Str("bucket", bucket).
Str("method", req.Method). Str("method", req.Method).
Str("path", req.URL.Path). Str("path", req.URL.Path).
@@ -74,7 +74,7 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
} }
// Wait and retry // Wait and retry
c.logger.Warn(). api.logger.Warn().
Str("bucket", bucket). Str("bucket", bucket).
Str("method", req.Method). Str("method", req.Method).
Str("path", req.URL.Path). Str("path", req.URL.Path).
@@ -84,7 +84,7 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
time.Sleep(retryAfter) time.Sleep(retryAfter)
// Retry the request // Retry the request
resp, err = c.client.Do(req) resp, err = api.client.Do(req)
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, errors.Wrap(err, "retry request timed out") return nil, errors.Wrap(err, "retry request timed out")
@@ -93,12 +93,12 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
} }
// Update rate limit again after retry // Update rate limit again after retry
c.updateRateLimit(resp.Header) api.updateRateLimit(resp.Header)
// If STILL rate limited after retry, return error // If STILL rate limited after retry, return error
if resp.StatusCode == http.StatusTooManyRequests { if resp.StatusCode == http.StatusTooManyRequests {
resp.Body.Close() resp.Body.Close()
c.logger.Error(). api.logger.Error().
Str("bucket", bucket). Str("bucket", bucket).
Str("method", req.Method). Str("method", req.Method).
Str("path", req.URL.Path). Str("path", req.URL.Path).
@@ -115,15 +115,15 @@ func (c *APIClient) Do(req *http.Request) (*http.Response, error) {
// getBucketFromRequest extracts or generates bucket ID from request // getBucketFromRequest extracts or generates bucket ID from request
// For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers // For Discord, the bucket is typically METHOD:path until we get the actual bucket from headers
func (c *APIClient) getBucketFromRequest(req *http.Request) string { func (api *APIClient) getBucketFromRequest(req *http.Request) string {
return req.Method + ":" + req.URL.Path return req.Method + ":" + req.URL.Path
} }
// waitIfNeeded checks if we need to delay before request to avoid hitting rate limits // waitIfNeeded checks if we need to delay before request to avoid hitting rate limits
func (c *APIClient) waitIfNeeded(bucket string) error { func (api *APIClient) waitIfNeeded(bucket string) error {
c.mu.RLock() api.mu.RLock()
state, exists := c.buckets[bucket] state, exists := api.buckets[bucket]
c.mu.RUnlock() api.mu.RUnlock()
if !exists { if !exists {
return nil // No state yet, proceed return nil // No state yet, proceed
@@ -138,7 +138,7 @@ func (c *APIClient) waitIfNeeded(bucket string) error {
waitDuration += 100 * time.Millisecond waitDuration += 100 * time.Millisecond
if waitDuration > 0 { if waitDuration > 0 {
c.logger.Debug(). api.logger.Debug().
Str("bucket", bucket). Str("bucket", bucket).
Dur("wait_duration", waitDuration). Dur("wait_duration", waitDuration).
Msg("Proactively waiting for rate limit reset") Msg("Proactively waiting for rate limit reset")
@@ -150,16 +150,16 @@ func (c *APIClient) waitIfNeeded(bucket string) error {
} }
// updateRateLimit parses response headers and updates bucket state // updateRateLimit parses response headers and updates bucket state
func (c *APIClient) updateRateLimit(headers http.Header) { func (api *APIClient) updateRateLimit(headers http.Header) {
bucket := headers.Get("X-RateLimit-Bucket") bucket := headers.Get("X-RateLimit-Bucket")
if bucket == "" { if bucket == "" {
return // No bucket info, can't track return // No bucket info, can't track
} }
// Parse headers // Parse headers
limit := c.parseInt(headers.Get("X-RateLimit-Limit")) limit := api.parseInt(headers.Get("X-RateLimit-Limit"))
remaining := c.parseInt(headers.Get("X-RateLimit-Remaining")) remaining := api.parseInt(headers.Get("X-RateLimit-Remaining"))
resetAfter := c.parseFloat(headers.Get("X-RateLimit-Reset-After")) resetAfter := api.parseFloat(headers.Get("X-RateLimit-Reset-After"))
state := &RateLimitState{ state := &RateLimitState{
Bucket: bucket, Bucket: bucket,
@@ -168,12 +168,12 @@ func (c *APIClient) updateRateLimit(headers http.Header) {
Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))), Reset: time.Now().Add(time.Duration(resetAfter * float64(time.Second))),
} }
c.mu.Lock() api.mu.Lock()
c.buckets[bucket] = state api.buckets[bucket] = state
c.mu.Unlock() api.mu.Unlock()
// Log rate limit state for debugging // Log rate limit state for debugging
c.logger.Debug(). api.logger.Debug().
Str("bucket", bucket). Str("bucket", bucket).
Int("remaining", remaining). Int("remaining", remaining).
Int("limit", limit). Int("limit", limit).
@@ -182,14 +182,14 @@ func (c *APIClient) updateRateLimit(headers http.Header) {
} }
// parseRetryAfter extracts retry delay from Retry-After header // parseRetryAfter extracts retry delay from Retry-After header
func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration { func (api *APIClient) parseRetryAfter(headers http.Header) time.Duration {
retryAfter := headers.Get("Retry-After") retryAfter := headers.Get("Retry-After")
if retryAfter == "" { if retryAfter == "" {
return 0 return 0
} }
// Discord returns seconds as float // Discord returns seconds as float
seconds := c.parseFloat(retryAfter) seconds := api.parseFloat(retryAfter)
if seconds <= 0 { if seconds <= 0 {
return 0 return 0
} }
@@ -198,7 +198,7 @@ func (c *APIClient) parseRetryAfter(headers http.Header) time.Duration {
} }
// parseInt parses an integer from a header value, returns 0 on error // parseInt parses an integer from a header value, returns 0 on error
func (c *APIClient) parseInt(s string) int { func (api *APIClient) parseInt(s string) int {
if s == "" { if s == "" {
return 0 return 0
} }
@@ -207,7 +207,7 @@ func (c *APIClient) parseInt(s string) int {
} }
// parseFloat parses a float from a header value, returns 0 on error // parseFloat parses a float from a header value, returns 0 on error
func (c *APIClient) parseFloat(s string) float64 { func (api *APIClient) parseFloat(s string) float64 {
if s == "" { if s == "" {
return 0 return 0
} }

View File

@@ -0,0 +1,18 @@
package discord
import (
"github.com/pkg/errors"
)
func (s *OAuthSession) GetSteamID() (string, error) {
connections, err := s.UserConnections()
if err != nil {
return "", errors.Wrap(err, "s.UserConnections")
}
for _, conn := range connections {
if conn.Type == "steam" {
return conn.ID, nil
}
}
return "", errors.New("steam connection not found")
}

View File

@@ -84,7 +84,7 @@ func AdminRoleCreate(s *hws.Server, conn *db.DB) http.Handler {
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
} }
err := db.CreateRole(ctx, tx, newRole, db.NewAudit(r, nil)) err := db.CreateRole(ctx, tx, newRole, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.CreateRole") return false, errors.Wrap(err, "db.CreateRole")
} }
@@ -196,7 +196,7 @@ func AdminRoleDelete(s *hws.Server, conn *db.DB) http.Handler {
} }
// Delete the role with audit logging // Delete the role with audit logging
err = db.DeleteRole(ctx, tx, roleID, db.NewAudit(r, nil)) err = db.DeleteRole(ctx, tx, roleID, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.NotFound(w, err) respond.NotFound(w, err)
@@ -320,7 +320,7 @@ func AdminRolePermissionsUpdate(s *hws.Server, conn *db.DB) http.Handler {
} }
return false, errors.Wrap(err, "db.GetRoleByID") return false, errors.Wrap(err, "db.GetRoleByID")
} }
err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAudit(r, nil)) err = role.UpdatePermissions(ctx, tx, permissionIDs, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "role.UpdatePermissions") return false, errors.Wrap(err, "role.UpdatePermissions")
} }

View File

@@ -36,7 +36,7 @@ func GenerateFixtures(
var league *db.League var league *db.League
var fixtures []*db.Fixture var fixtures []*db.Fixture
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
_, err := db.NewRound(ctx, tx, seasonShortName, leagueShortName, round, db.NewAudit(r, nil)) _, err := db.NewRound(ctx, tx, seasonShortName, leagueShortName, round, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.NewRound")) respond.BadRequest(w, errors.Wrap(err, "db.NewRound"))
@@ -98,7 +98,7 @@ func UpdateFixtures(
notify.Warn(s, w, r, "Invalid game weeks", "A game week is missing or has no games", nil) notify.Warn(s, w, r, "Invalid game weeks", "A game week is missing or has no games", nil)
return false, nil return false, nil
} }
err = db.UpdateFixtureGameWeeks(ctx, tx, fixtures, db.NewAudit(r, nil)) err = db.UpdateFixtureGameWeeks(ctx, tx, fixtures, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.UpdateFixtureGameWeeks")) respond.BadRequest(w, errors.Wrap(err, "db.UpdateFixtureGameWeeks"))
@@ -125,7 +125,7 @@ func DeleteFixture(
return return
} }
if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
err := db.DeleteFixture(ctx, tx, fixtureID, db.NewAudit(r, nil)) err := db.DeleteFixture(ctx, tx, fixtureID, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.NotFound(w, errors.Wrap(err, "db.DeleteFixture")) respond.NotFound(w, errors.Wrap(err, "db.DeleteFixture"))

View File

@@ -61,7 +61,7 @@ func NewLeagueSubmit(
if !nameUnique || !shortNameUnique { if !nameUnique || !shortNameUnique {
return true, nil return true, nil
} }
league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAudit(r, nil)) league, err = db.NewLeague(ctx, tx, name, shortname, description, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.NewLeague") return false, errors.Wrap(err, "db.NewLeague")
} }

View File

@@ -64,7 +64,7 @@ func Register(
if !unique { if !unique {
return true, nil return true, nil
} }
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAudit(r, nil)) user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.CreateUser") return false, errors.Wrap(err, "db.CreateUser")
} }

View File

@@ -86,7 +86,7 @@ func SeasonEditSubmit(
} }
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAudit(r, nil)) err = season.Update(ctx, tx, version, start, end, finalsStart, finalsEnd, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "season.Update") return false, errors.Wrap(err, "season.Update")
} }

View File

@@ -18,8 +18,8 @@ func SeasonLeagueAddTeam(
conn *db.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonShortName := r.PathValue("season_short_name")
leagueStr := r.PathValue("league_short_name") leagueShortName := r.PathValue("league_short_name")
getter, ok := validation.ParseFormOrNotify(s, w, r) getter, ok := validation.ParseFormOrNotify(s, w, r)
if !ok { if !ok {
@@ -36,7 +36,7 @@ func SeasonLeagueAddTeam(
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
var err error var err error
team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonStr, leagueStr, teamID, db.NewAudit(r, nil)) team, season, league, err = db.NewTeamParticipation(ctx, tx, seasonShortName, leagueShortName, teamID, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)

View File

@@ -92,7 +92,7 @@ func SeasonLeagueDeleteFixtures(
var league *db.League var league *db.League
var fixtures []*db.Fixture var fixtures []*db.Fixture
if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if !conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
err := db.DeleteAllFixtures(ctx, tx, seasonShortName, leagueShortName, db.NewAudit(r, nil)) err := db.DeleteAllFixtures(ctx, tx, seasonShortName, leagueShortName, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.BadRequest(w, errors.Wrap(err, "db.DeleteAllFixtures")) respond.BadRequest(w, errors.Wrap(err, "db.DeleteAllFixtures"))

View File

@@ -19,13 +19,13 @@ func SeasonAddLeague(
conn *db.DB, conn *db.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seasonStr := r.PathValue("season_short_name") seasonShortName := r.PathValue("season_short_name")
leagueStr := r.PathValue("league_short_name") leagueShortName := r.PathValue("league_short_name")
var season *db.Season var season *db.Season
var allLeagues []*db.League var allLeagues []*db.League
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
err := db.NewSeasonLeague(ctx, tx, seasonStr, leagueStr, db.NewAudit(r, nil)) err := db.NewSeasonLeague(ctx, tx, seasonShortName, leagueShortName, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.BadRequest(w, err) respond.BadRequest(w, err)
@@ -35,7 +35,7 @@ func SeasonAddLeague(
} }
// Reload season with updated leagues // Reload season with updated leagues
season, err = db.GetSeason(ctx, tx, seasonStr) season, err = db.GetSeason(ctx, tx, seasonShortName)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
@@ -75,7 +75,7 @@ func SeasonRemoveLeague(
} }
return false, errors.Wrap(err, "db.GetSeason") return false, errors.Wrap(err, "db.GetSeason")
} }
err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAudit(r, nil)) err = season.RemoveLeague(ctx, tx, leagueStr, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
if db.IsBadRequest(err) { if db.IsBadRequest(err) {
respond.BadRequest(w, err) respond.BadRequest(w, err)

View File

@@ -66,7 +66,7 @@ func NewSeasonSubmit(
if !nameUnique || !shortNameUnique { if !nameUnique || !shortNameUnique {
return true, nil return true, nil
} }
season, err = db.NewSeason(ctx, tx, name, version, shortname, start, db.NewAudit(r, nil)) season, err = db.NewSeason(ctx, tx, name, version, shortname, start, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.NewSeason") return false, errors.Wrap(err, "db.NewSeason")
} }

View File

@@ -71,7 +71,7 @@ func NewTeamSubmit(
if !nameUnique || !shortNameComboUnique { if !nameUnique || !shortNameComboUnique {
return true, nil return true, nil
} }
_, err = db.NewTeam(ctx, tx, name, shortName, altShortName, color, db.NewAudit(r, nil)) _, err = db.NewTeam(ctx, tx, name, shortName, altShortName, color, db.NewAuditFromRequest(r))
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.NewTeam") return false, errors.Wrap(err, "db.NewTeam")
} }

37
pkg/slapshotapi/client.go Normal file
View File

@@ -0,0 +1,37 @@
package slapshotapi
import (
"net/http"
"sync"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
type SlapAPI struct {
client *http.Client
ratelimiter *rate.Limiter
mu sync.Mutex
maxTokens int
key string
env string
}
func NewSlapAPIClient(cfg *Config) (*SlapAPI, error) {
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if cfg.Environment != "api" && cfg.Environment != "staging" {
return nil, errors.New("invalid env specified, must be 'api' or 'staging'")
}
rl := rate.NewLimiter(rate.Inf, 10)
client := &SlapAPI{
client: http.DefaultClient,
ratelimiter: rl,
mu: sync.Mutex{},
maxTokens: 10,
key: cfg.Key,
env: cfg.Environment,
}
return client, nil
}

23
pkg/slapshotapi/config.go Normal file
View File

@@ -0,0 +1,23 @@
// Package slapshotapi provides utilities for interacting with the slapshot public API
package slapshotapi
import (
"git.haelnorr.com/h/golib/env"
"github.com/pkg/errors"
)
type Config struct {
Environment string // ENV SLAPSHOT_ENVIRONMENT: API environment to connect to (default: staging)
Key string // ENV SLAPSHOT_API_KEY: API Key for authorisation with the API (required)
}
func ConfigFromEnv() (any, error) {
cfg := &Config{
Environment: env.String("SLAPSHOT_ENVIRONMENT", "staging"),
Key: env.String("SLAPSHOT_API_KEY", ""),
}
if cfg.Key == "" {
return nil, errors.New("Envar not set: SLAPSHOT_API_KEY")
}
return cfg, nil
}

35
pkg/slapshotapi/enums.go Normal file
View File

@@ -0,0 +1,35 @@
package slapshotapi
const (
RegionEUWest = "eu-west"
RegionNAEast = "na-east"
RegionNACentral = "na-central"
RegionNAWest = "na-west"
RegionOCEEast = "oce-east"
ArenaSlapstadium = "Slapstadium"
ArenaSlapville = "Slapville"
ArenaSlapstadiumMini = "Slapstadium_mini"
ArenaTableHockey = "Table_Hockey"
ArenaColosseum = "Colosseum"
ArenaSlapvilleJumbo = "Slapville_Jumbo"
ArenaSlapstation = "Slapstation"
ArenaSlapstadiumXL = "Slapstadium_XL"
ArenaIsland = "Island"
ArenaObstacles = "Obstacles"
ArenaObstaclesXL = "Obstacles_XL"
EndReasonEndOfReg = "EndOfRegulation"
EndReasonOvertime = "Overtime"
EndReasonHomeTeamLeft = "HomeTeamLeft"
EndReasonAwayTeamLeft = "AwayTeamLeft"
EndReasonMercy = "MercyRule"
EndReasonTie = "Tie"
EndReasonForfeit = "Forfeit"
EndReasonCancelled = "Cancelled"
EndReasonUnknown = "Unknown"
GameModeHockey = "hockey"
GameModeDodgePuck = "dodgepuck"
GameModeTag = "tag"
)

41
pkg/slapshotapi/ezconf.go Normal file
View File

@@ -0,0 +1,41 @@
package slapshotapi
import (
"runtime"
"strings"
)
// EZConfIntegration provides integration with ezconf for automatic configuration
type EZConfIntegration struct {
configFunc func() (any, error)
name string
}
// PackagePath returns the path to the config package for source parsing
func (e EZConfIntegration) PackagePath() string {
_, filename, _, _ := runtime.Caller(0)
// Return directory of this file
return filename[:len(filename)-len("/ezconf.go")]
}
// ConfigFunc returns the ConfigFromEnv function for ezconf
func (e EZConfIntegration) ConfigFunc() func() (any, error) {
return func() (any, error) {
return e.configFunc()
}
}
// Name returns the name to use when registering with ezconf
func (e EZConfIntegration) Name() string {
return strings.ToLower(e.name)
}
// GroupName returns the display name for grouping environment variables
func (e EZConfIntegration) GroupName() string {
return e.name
}
// NewEZConfIntegration creates a new EZConf integration helper
func NewEZConfIntegration() EZConfIntegration {
return EZConfIntegration{name: "SlapshotAPI", configFunc: ConfigFromEnv}
}

View File

@@ -0,0 +1,59 @@
package slapshotapi
import (
"context"
"net/http"
"strconv"
"time"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
func (c *SlapAPI) do(ctx context.Context, req *http.Request) (*http.Response, error) {
for {
err := c.ratelimiter.Wait(ctx)
if err != nil {
return nil, errors.Wrap(err, "c.ratelimiter.Wait")
}
resp, err := c.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "c.client.Do")
}
if resp.StatusCode == http.StatusTooManyRequests {
resetAfter := 30 * time.Second
err := resp.Body.Close()
if err != nil {
return nil, errors.Wrap(err, "resp.Body.Close")
}
if resetAfter > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(resetAfter):
continue
}
}
}
c.updateLimiterFromHeaders(resp.Header)
return resp, nil
}
}
func (c *SlapAPI) updateLimiterFromHeaders(h http.Header) {
c.mu.Lock()
defer c.mu.Unlock()
limit, err1 := strconv.Atoi(h.Get("RateLimit-Limit"))
window, err2 := strconv.Atoi(h.Get("RateLimit-Window"))
if err1 != nil || err2 != nil || limit <= 0 || window <= 0 {
return
}
if limit != c.maxTokens || time.Duration(window) != time.Duration(float64(window)/float64(limit))*time.Second {
c.maxTokens = limit
c.ratelimiter.SetBurst(limit)
c.ratelimiter.SetLimit(rate.Every(time.Duration(window) / time.Duration(limit)))
}
}

View File

@@ -0,0 +1,62 @@
package slapshotapi
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
type endpointMatchmaking struct {
regions []string
}
func getEndpointMatchmaking(regions []string) *endpointMatchmaking {
return &endpointMatchmaking{
regions: regions,
}
}
func (ep *endpointMatchmaking) path() string {
path := "/api/public/matchmaking%s"
filters := ""
if len(ep.regions) > 0 {
filters = "?regions="
for i, region := range ep.regions {
filters = filters + region
if i+1 != len(ep.regions) {
filters = filters + ","
}
}
}
return fmt.Sprintf(path, filters)
}
func (ep *endpointMatchmaking) method() string {
return "GET"
}
type matchmakingresp struct {
Playlists PubsQueue `json:"playlists"`
}
type PubsQueue struct {
InQueue uint16 `json:"in_queue"`
InMatch uint16 `json:"in_match"`
}
// GetQueueStatus gets the number of players in public matchmaking
func (c *SlapAPI) GetQueueStatus(
ctx context.Context,
regions []string,
) (*PubsQueue, error) {
endpoint := getEndpointMatchmaking(regions)
data, err := c.request(ctx, endpoint)
if err != nil {
return nil, errors.Wrap(err, "slapapiReq")
}
resp := matchmakingresp{}
json.Unmarshal(data, &resp)
return &resp.Playlists, nil
}

View File

@@ -0,0 +1,44 @@
package slapshotapi
import (
"context"
"fmt"
"io"
"net/http"
"github.com/pkg/errors"
)
type endpoint interface {
path() string
method() string
}
func (c *SlapAPI) request(
ctx context.Context,
ep endpoint,
) ([]byte, error) {
baseurl := fmt.Sprintf("https://%s.slapshot.gg%s", c.env, ep.path())
req, err := http.NewRequest(ep.method(), baseurl, nil)
if err != nil {
return nil, errors.Wrap(err, "http.NewRequest")
}
req.Header.Add("accept", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", c.key))
res, err := c.do(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "http.DefaultClient.Do")
}
if res.StatusCode != 200 {
return nil, errors.New(fmt.Sprintf("Error making request: %v", res.StatusCode))
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, errors.Wrap(err, "io.ReadAll")
}
err = res.Body.Close()
if err != nil {
return nil, errors.Wrap(err, "resp.Body.Close")
}
return body, nil
}

49
pkg/slapshotapi/slapid.go Normal file
View File

@@ -0,0 +1,49 @@
package slapshotapi
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
type endpointSteamID struct {
steamID string
}
func getEndpointSteamID(steamID string) *endpointSteamID {
return &endpointSteamID{
steamID: steamID,
}
}
func (ep *endpointSteamID) path() string {
return fmt.Sprintf("/api/public/players/steam/%s", ep.steamID)
}
func (ep *endpointSteamID) method() string {
return "GET"
}
type idresp struct {
ID uint32 `json:"id"`
}
// GetSlapID returns the slapshot ID of the steam user
func (c *SlapAPI) GetSlapID(
ctx context.Context,
steamid string,
) (uint32, error) {
endpoint := getEndpointSteamID(steamid)
data, err := c.request(ctx, endpoint)
if err != nil {
return 0, errors.Wrap(err, "slapapiReq")
}
resp := idresp{}
err = json.Unmarshal(data, &resp)
if err != nil {
return 0, errors.Wrap(err, "json.Unmarshal")
}
return resp.ID, nil
}