package db

import (
	"context"
	"database/sql"
	"time"

	"git.haelnorr.com/h/oslstats/internal/discord"
	"github.com/pkg/errors"
	"github.com/uptrace/bun"
)

// DiscordToken is the database row for a user's Discord OAuth token, stored
// in the discord_tokens table and keyed by the user's Discord ID.
//
// ExpiresAt is an absolute Unix timestamp (seconds) computed when the token
// was stored; Convert turns it back into a relative ExpiresIn.
type DiscordToken struct {
	bun.BaseModel `bun:"table:discord_tokens,alias:dt"`

	DiscordID    string `bun:"discord_id,pk,notnull"`
	AccessToken  string `bun:"access_token,notnull"`
	RefreshToken string `bun:"refresh_token,notnull"`
	ExpiresAt    int64  `bun:"expires_at,notnull"`
	Scope        string `bun:"scope,notnull"`
	TokenType    string `bun:"token_type,notnull"`
}

// UpdateDiscordToken stores the provided Discord token for the user.
// If the user already has a token stored, every token column of the existing
// row (access token, refresh token, expiry, scope and token type) is replaced
// with the new values.
//
// The token's relative ExpiresIn is converted to an absolute Unix timestamp
// using the current time, so the stored expiry is only accurate if this is
// called soon after the token was issued.
// It returns an error if token is nil or the insert fails.
func (u *User) UpdateDiscordToken(ctx context.Context, tx bun.Tx, token *discord.Token) error {
	if token == nil {
		return errors.New("token cannot be nil")
	}
	expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix()
	discordToken := &DiscordToken{
		DiscordID:    u.DiscordID,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
		ExpiresAt:    expiresAt,
		Scope:        token.Scope,
		TokenType:    token.TokenType,
	}
	// Scope and token type must be refreshed too: a re-authorisation can grant
	// a different scope, and keeping the old value would misreport the
	// permissions of the stored token.
	_, err := tx.NewInsert().
		Model(discordToken).
		On("CONFLICT (discord_id) DO UPDATE").
		Set("access_token = EXCLUDED.access_token").
		Set("refresh_token = EXCLUDED.refresh_token").
		Set("expires_at = EXCLUDED.expires_at").
		Set("scope = EXCLUDED.scope").
		Set("token_type = EXCLUDED.token_type").
		Exec(ctx)
	if err != nil {
		return errors.Wrap(err, "tx.NewInsert")
	}
	return nil
}

// DeleteDiscordTokens deletes the user's Discord OAuth token from the database.
// It returns the deleted DiscordToken so that it can be revoked via the
// Discord API. If the user has no stored token it returns (nil, nil).
func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
	token, err := u.GetDiscordToken(ctx, tx)
	if err != nil {
		return nil, errors.Wrap(err, "user.GetDiscordToken")
	}
	if token == nil {
		// Nothing stored, so nothing to delete or revoke.
		return nil, nil
	}
	_, err = tx.NewDelete().
		Model((*DiscordToken)(nil)).
		Where("discord_id = ?", u.DiscordID).
		Exec(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "tx.NewDelete")
	}
	return token, nil
}

// GetDiscordToken retrieves the user's Discord token from the database.
// It returns (nil, nil) when no token is stored for the user.
func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, error) {
	token := new(DiscordToken)
	err := tx.NewSelect().
		Model(token).
		Where("discord_id = ?", u.DiscordID).
		Limit(1).
		Scan(ctx)
	if err != nil {
		// A missing row is not an error for callers; they check for nil.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, errors.Wrap(err, "tx.NewSelect")
	}
	return token, nil
}

// Convert turns the stored token back into a *discord.Token.
// ExpiresIn is recomputed as the number of seconds remaining until ExpiresAt,
// clamped to 0 for tokens that have already expired.
func (t *DiscordToken) Convert() *discord.Token {
	expiresIn := max(t.ExpiresAt-time.Now().Unix(), 0)
	return &discord.Token{
		AccessToken:  t.AccessToken,
		RefreshToken: t.RefreshToken,
		ExpiresIn:    int(expiresIn),
		Scope:        t.Scope,
		TokenType:    t.TokenType,
	}
}