slapid and player now links when registering

This commit is contained in:
2026-02-17 18:33:22 +11:00
parent e50f855206
commit 103da78f0b
11 changed files with 137 additions and 28 deletions

View File

@@ -17,6 +17,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/embedfs" "git.haelnorr.com/h/oslstats/internal/embedfs"
"git.haelnorr.com/h/oslstats/internal/server" "git.haelnorr.com/h/oslstats/internal/server"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/slapshotapi"
) )
// Initializes and runs the server // Initializes and runs the server
@@ -47,8 +48,15 @@ func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
return errors.Wrap(err, "discord.NewAPIClient") return errors.Wrap(err, "discord.NewAPIClient")
} }
// Setup Slapshot API
logger.Debug().Msg("Setting up Slapshot API client")
slapAPI, err := slapshotapi.NewSlapAPIClient(cfg.Slapshot)
if err != nil {
return errors.Wrap(err, "slapshotapi.NewSlapAPIClient")
}
logger.Debug().Msg("Setting up HTTP server") logger.Debug().Msg("Setting up HTTP server")
httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI) httpServer, err := server.Setup(staticFS, cfg, logger, conn, store, discordAPI, slapAPI)
if err != nil { if err != nil {
return errors.Wrap(err, "setupHttpServer") return errors.Wrap(err, "setupHttpServer")
} }

View File

@@ -3,7 +3,6 @@ package db
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -78,7 +77,6 @@ func (a *AuditLogFilter) UserIDs(ids []int) *AuditLogFilter {
} }
func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter { func (a *AuditLogFilter) Actions(actions []string) *AuditLogFilter {
fmt.Println(actions)
if len(actions) > 0 { if len(actions) > 0 {
a.In("al.action", actions) a.In("al.action", actions)
} }

View File

@@ -3,7 +3,6 @@ package db
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -110,7 +109,6 @@ func (l *listgetter[T]) Filter(filters ...Filter) *listgetter[T] {
l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value) l.q = l.q.Where("? ? ?", bun.Ident(filter.Field), bun.Safe(filter.Comparator), filter.Value)
} }
} }
fmt.Println(l.q.String())
return l return l
} }

View File

@@ -11,13 +11,15 @@ type Player struct {
bun.BaseModel `bun:"table:players,alias:p"` bun.BaseModel `bun:"table:players,alias:p"`
ID int `bun:"id,pk,autoincrement" json:"id"` ID int `bun:"id,pk,autoincrement" json:"id"`
SlapID *string `bun:"slap_id,unique" json:"slap_id"` SlapID *uint32 `bun:"slap_id,unique" json:"slap_id"`
DiscordID string `bun:"discord_id,unique,notnull" json:"discord_id"` DiscordID string `bun:"discord_id,unique,notnull" json:"discord_id"`
UserID *int `bun:"user_id,unique" json:"user_id"` UserID *int `bun:"user_id,unique" json:"user_id"`
User *User `bun:"rel:belongs-to,join:user_id=id" json:"-"` User *User `bun:"rel:belongs-to,join:user_id=id" json:"-"`
} }
// NewPlayer creates a new player in the database. If there is an existing user with the same
// discordID, it will automatically link that user to the player
func NewPlayer(ctx context.Context, tx bun.Tx, discordID string, audit *AuditMeta) (*Player, error) { func NewPlayer(ctx context.Context, tx bun.Tx, discordID string, audit *AuditMeta) (*Player, error) {
player := &Player{DiscordID: discordID} player := &Player{DiscordID: discordID}
user, err := GetUserByDiscordID(ctx, tx, discordID) user, err := GetUserByDiscordID(ctx, tx, discordID)
@@ -35,11 +37,46 @@ func NewPlayer(ctx context.Context, tx bun.Tx, discordID string, audit *AuditMet
return player, nil return player, nil
} }
// ConnectPlayer links the user to an existing player, or creates a new player to link if not found
// Populates User.Player on success
func (u *User) ConnectPlayer(ctx context.Context, tx bun.Tx, audit *AuditMeta) error {
player, err := GetByField[Player](tx, "p.discord_id", u.DiscordID).
Relation("User").Get(ctx)
if err != nil {
if !IsBadRequest(err) {
// Unexpected error occured
return errors.Wrap(err, "GetByField")
}
// Player doesn't exist, create a new one
player, err = NewPlayer(ctx, tx, u.DiscordID, audit)
if err != nil {
return errors.Wrap(err, "NewPlayer")
}
// New player should automatically get linked to the user
u.Player = player
return nil
}
// Player was found
if player.UserID != nil {
if player.UserID == &u.ID {
return nil
}
return errors.New("player with that discord_id already linked to a user")
}
player.UserID = &u.ID
err = UpdateByID(tx, player.ID, player).Column("user_id").Exec(ctx)
if err != nil {
return errors.Wrap(err, "UpdateByID")
}
u.Player = player
return nil
}
func GetPlayer(ctx context.Context, tx bun.Tx, playerID int) (*Player, error) { func GetPlayer(ctx context.Context, tx bun.Tx, playerID int) (*Player, error) {
return GetByID[Player](tx, playerID).Relation("User").Get(ctx) return GetByID[Player](tx, playerID).Relation("User").Get(ctx)
} }
func UpdatePlayerSlapID(ctx context.Context, tx bun.Tx, playerID int, slapID string, audit *AuditMeta) error { func UpdatePlayerSlapID(ctx context.Context, tx bun.Tx, playerID int, slapID uint32, audit *AuditMeta) error {
player, err := GetPlayer(ctx, tx, playerID) player, err := GetPlayer(ctx, tx, playerID)
if err != nil { if err != nil {
return errors.Wrap(err, "GetPlayer") return errors.Wrap(err, "GetPlayer")

View File

@@ -56,7 +56,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
// GetUserByID queries the database for a user matching the given ID // GetUserByID queries the database for a user matching the given ID
// Returns a BadRequestNotFound error if no user is found // Returns a BadRequestNotFound error if no user is found
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
return GetByID[User](tx, id).Get(ctx) return GetByID[User](tx, id).Relation("Player").Get(ctx)
} }
// GetUserByUsername queries the database for a user matching the given username // GetUserByUsername queries the database for a user matching the given username
@@ -65,7 +65,7 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
if username == "" { if username == "" {
return nil, errors.New("username not provided") return nil, errors.New("username not provided")
} }
return GetByField[User](tx, "username", username).Get(ctx) return GetByField[User](tx, "username", username).Relation("Player").Get(ctx)
} }
// GetUserByDiscordID queries the database for a user matching the given discord id // GetUserByDiscordID queries the database for a user matching the given discord id
@@ -74,7 +74,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
if discordID == "" { if discordID == "" {
return nil, errors.New("discord_id not provided") return nil, errors.New("discord_id not provided")
} }
return GetByField[User](tx, "discord_id", discordID).Get(ctx) return GetByField[User](tx, "u.discord_id", discordID).Relation("Player").Get(ctx)
} }
// GetRoles loads all the roles for this user // GetRoles loads all the roles for this user
@@ -142,7 +142,7 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) { func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*List[User], error) {
defaults := &PageOpts{1, 50, bun.OrderAsc, "id"} defaults := &PageOpts{1, 50, bun.OrderAsc, "id"}
return GetList[User](tx).GetPaged(ctx, pageOpts, defaults) return GetList[User](tx).Relation("Player").GetPaged(ctx, pageOpts, defaults)
} }
// GetUsersWithRoles queries the database for users with their roles preloaded // GetUsersWithRoles queries the database for users with their roles preloaded

View File

@@ -4,6 +4,8 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var ErrNoSteam error = errors.New("steam connection not found")
func (s *OAuthSession) GetSteamID() (string, error) { func (s *OAuthSession) GetSteamID() (string, error) {
connections, err := s.UserConnections() connections, err := s.UserConnections()
if err != nil { if err != nil {
@@ -14,5 +16,5 @@ func (s *OAuthSession) GetSteamID() (string, error) {
return conn.ID, nil return conn.ID, nil
} }
} }
return "", errors.New("steam connection not found") return "", ErrNoSteam
} }

View File

@@ -2,7 +2,6 @@ package handlers
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
@@ -93,7 +92,6 @@ func UpdateFixtures(
} }
var valid bool var valid bool
fixtures, valid = updateFixtures(fixtures, updates) fixtures, valid = updateFixtures(fixtures, updates)
fmt.Println(len(fixtures))
if !valid { if !valid {
notify.Warn(s, w, r, "Invalid game weeks", "A game week is missing or has no games", nil) notify.Warn(s, w, r, "Invalid game weeks", "A game week is missing or has no games", nil)
return false, nil return false, nil

View File

@@ -12,16 +12,20 @@ import (
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/respond" "git.haelnorr.com/h/oslstats/internal/respond"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw" "git.haelnorr.com/h/oslstats/internal/throw"
authview "git.haelnorr.com/h/oslstats/internal/view/authview" authview "git.haelnorr.com/h/oslstats/internal/view/authview"
"git.haelnorr.com/h/oslstats/pkg/slapshotapi"
) )
func Register( func Register(
s *hws.Server, s *hws.Server,
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
conn *db.DB, conn *db.DB,
slapAPI *slapshotapi.SlapAPI,
cfg *config.Config, cfg *config.Config,
store *store.Store, store *store.Store,
) http.Handler { ) http.Handler {
@@ -56,6 +60,7 @@ func Register(
username := r.FormValue("username") username := r.FormValue("username")
unique := false unique := false
var user *db.User var user *db.User
audit := db.NewAudit(r.RemoteAddr, r.UserAgent(), user)
if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) { if ok := conn.WithNotifyTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username) unique, err = db.IsUnique(ctx, tx, (*db.User)(nil), "username", username)
if err != nil { if err != nil {
@@ -64,19 +69,13 @@ func Register(
if !unique { if !unique {
return true, nil return true, nil
} }
user, err = db.CreateUser(ctx, tx, username, details.DiscordUser, db.NewAuditFromRequest(r)) user, err = registerUser(ctx, tx, username, details, cfg.RBAC, audit)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.CreateUser") return false, errors.Wrap(err, "registerUser")
} }
err = user.UpdateDiscordToken(ctx, tx, details.Token) err = connectSlapID(ctx, tx, user, details.Token, slapAPI, audit)
if err != nil { if err != nil {
return false, errors.Wrap(err, "db.UpdateDiscordToken") return false, errors.Wrap(err, "connectSlapID")
}
if shouldGrantAdmin(user, cfg.RBAC) {
err := ensureUserHasAdminRole(ctx, tx, user)
if err != nil {
return false, errors.Wrap(err, "ensureUserHasAdminRole")
}
} }
return true, nil return true, nil
}); !ok { }); !ok {
@@ -96,3 +95,62 @@ func Register(
}, },
) )
} }
func registerUser(ctx context.Context, tx bun.Tx,
username string, details *store.RegistrationSession,
rbac *rbac.Config, audit *db.AuditMeta,
) (*db.User, error) {
// Register the user
user, err := db.CreateUser(ctx, tx, username, details.DiscordUser, audit)
if err != nil {
return nil, errors.Wrap(err, "db.CreateUser")
}
err = user.UpdateDiscordToken(ctx, tx, details.Token)
if err != nil {
return nil, errors.Wrap(err, "db.UpdateDiscordToken")
}
err = user.ConnectPlayer(ctx, tx, audit)
if err != nil {
return nil, errors.Wrap(err, "db.ConnectPlayer")
}
// Check if they should be an admin
if shouldGrantAdmin(user, rbac) {
err := ensureUserHasAdminRole(ctx, tx, user)
if err != nil {
return nil, errors.Wrap(err, "ensureUserHasAdminRole")
}
}
return user, nil
}
func connectSlapID(ctx context.Context, tx bun.Tx, user *db.User,
token *discord.Token, slapAPI *slapshotapi.SlapAPI, audit *db.AuditMeta,
) error {
// Attempt to setup their player/slapID from steam connection
// If fails due to no steam connection or no slapID, fail silently and proceed with registration
session, err := discord.NewOAuthSession(token)
if err != nil {
return errors.Wrap(err, "discord.NewOAuthSession")
}
steamID, err := session.GetSteamID()
if err != nil {
if err == discord.ErrNoSteam {
return nil
}
return errors.Wrap(err, "session.GetSteamID")
}
slapID, err := slapAPI.GetSlapID(ctx, steamID)
if err != nil {
if err == slapshotapi.ErrNoSlapID {
return nil
}
return errors.Wrap(err, "slapAPI.GetSlapID")
}
// slapID exists, we can update their player connection
err = db.UpdatePlayerSlapID(ctx, tx, user.Player.ID, slapID, audit)
if err != nil {
return errors.Wrap(err, "db.UpdatePlayerSlapID")
}
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/permissions" "git.haelnorr.com/h/oslstats/internal/permissions"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/slapshotapi"
) )
func addRoutes( func addRoutes(
@@ -25,6 +26,7 @@ func addRoutes(
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
slapAPI *slapshotapi.SlapAPI,
perms *rbac.Checker, perms *rbac.Checker,
) error { ) error {
// Create the routes // Create the routes
@@ -55,7 +57,7 @@ func addRoutes(
{ {
Path: "/register", Path: "/register",
Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
Handler: auth.LogoutReq(handlers.Register(s, auth, conn, cfg, store)), Handler: auth.LogoutReq(handlers.Register(s, auth, conn, slapAPI, cfg, store)),
}, },
{ {
Path: "/logout", Path: "/logout",

View File

@@ -15,6 +15,7 @@ import (
"git.haelnorr.com/h/oslstats/internal/handlers" "git.haelnorr.com/h/oslstats/internal/handlers"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store" "git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/pkg/slapshotapi"
) )
func Setup( func Setup(
@@ -24,6 +25,7 @@ func Setup(
conn *db.DB, conn *db.DB,
store *store.Store, store *store.Store,
discordAPI *discord.APIClient, discordAPI *discord.APIClient,
slapAPI *slapshotapi.SlapAPI,
) (server *hws.Server, err error) { ) (server *hws.Server, err error) {
if staticFS == nil { if staticFS == nil {
return nil, errors.New("No filesystem provided") return nil, errors.New("No filesystem provided")
@@ -67,7 +69,7 @@ func Setup(
return nil, errors.Wrap(err, "rbac.NewChecker") return nil, errors.Wrap(err, "rbac.NewChecker")
} }
err = addRoutes(httpServer, &fs, cfg, conn, auth, store, discordAPI, perms) err = addRoutes(httpServer, &fs, cfg, conn, auth, store, discordAPI, slapAPI, perms)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "addRoutes") return nil, errors.Wrap(err, "addRoutes")
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -30,7 +31,9 @@ type idresp struct {
ID uint32 `json:"id"` ID uint32 `json:"id"`
} }
// GetSlapID returns the slapshot ID of the steam user var ErrNoSlapID error = errors.New("slapID not found")
// GetSlapID returns the slapshot ID of the steam user.
func (c *SlapAPI) GetSlapID( func (c *SlapAPI) GetSlapID(
ctx context.Context, ctx context.Context,
steamid string, steamid string,
@@ -38,7 +41,10 @@ func (c *SlapAPI) GetSlapID(
endpoint := getEndpointSteamID(steamid) endpoint := getEndpointSteamID(steamid)
data, err := c.request(ctx, endpoint) data, err := c.request(ctx, endpoint)
if err != nil { if err != nil {
return 0, errors.Wrap(err, "slapapiReq") if strings.Contains(err.Error(), "404") {
return 0, ErrNoSlapID
}
return 0, errors.Wrap(err, "c.request")
} }
resp := idresp{} resp := idresp{}
err = json.Unmarshal(data, &resp) err = json.Unmarshal(data, &resp)