tweaks
This commit is contained in:
@@ -10,26 +10,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OAuthSession struct {
|
||||
*discordgo.Session
|
||||
}
|
||||
|
||||
func NewOAuthSession(token *Token) (*OAuthSession, error) {
|
||||
session, err := discordgo.New("Bearer " + token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discordgo.New")
|
||||
}
|
||||
return &OAuthSession{Session: session}, nil
|
||||
}
|
||||
|
||||
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
|
||||
user, err := s.User("@me")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "s.User")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// APIClient is an HTTP client wrapper that handles Discord API rate limits
|
||||
type APIClient struct {
|
||||
cfg *Config
|
||||
@@ -38,6 +18,7 @@ type APIClient struct {
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*RateLimitState
|
||||
trustedHost string
|
||||
bot *BotSession
|
||||
}
|
||||
|
||||
// NewAPIClient creates a new Discord API client with rate limit handling
|
||||
@@ -51,11 +32,20 @@ func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APICli
|
||||
if trustedhost == "" {
|
||||
return nil, errors.New("trustedhost cannot be empty")
|
||||
}
|
||||
bot, err := newBotSession(cfg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "newBotSession")
|
||||
}
|
||||
return &APIClient{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
logger: logger,
|
||||
buckets: make(map[string]*RateLimitState),
|
||||
cfg: cfg,
|
||||
trustedHost: trustedhost,
|
||||
bot: bot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api *APIClient) Ping() (*discordgo.Application, error) {
|
||||
return api.bot.Application("@me")
|
||||
}
|
||||
|
||||
22
internal/discord/bot.go
Normal file
22
internal/discord/bot.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type BotSession struct {
|
||||
*discordgo.Session
|
||||
}
|
||||
|
||||
func newBotSession(cfg *Config) (*BotSession, error) {
|
||||
session, err := discordgo.New("Bot " + cfg.BotToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discordgo.New")
|
||||
}
|
||||
return &BotSession{Session: session}, nil
|
||||
}
|
||||
|
||||
func (api *APIClient) Bot() *BotSession {
|
||||
return api.bot
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type Config struct {
|
||||
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
|
||||
OAuthScopes string // Authorisation scopes for OAuth
|
||||
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
|
||||
BotToken string // ENV DISCORD_BOT_TOKEN: Token for the discord bot (required)
|
||||
}
|
||||
|
||||
func ConfigFromEnv() (any, error) {
|
||||
@@ -20,6 +21,7 @@ func ConfigFromEnv() (any, error) {
|
||||
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
|
||||
OAuthScopes: getOAuthScopes(),
|
||||
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
|
||||
BotToken: env.String("DISCORD_BOT_TOKEN", ""),
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
@@ -32,6 +34,9 @@ func ConfigFromEnv() (any, error) {
|
||||
if cfg.RedirectPath == "" {
|
||||
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
|
||||
}
|
||||
if cfg.BotToken == "" {
|
||||
return nil, errors.New("Envar not set: DISCORD_BOT_TOKEN")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -8,9 +8,30 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OAuthSession struct {
|
||||
*discordgo.Session
|
||||
}
|
||||
|
||||
func NewOAuthSession(token *Token) (*OAuthSession, error) {
|
||||
session, err := discordgo.New("Bearer " + token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discordgo.New")
|
||||
}
|
||||
return &OAuthSession{Session: session}, nil
|
||||
}
|
||||
|
||||
func (s *OAuthSession) GetUser() (*discordgo.User, error) {
|
||||
user, err := s.User("@me")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "s.User")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Token represents a response from the Discord OAuth API after a successful authorization request
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
@@ -42,6 +42,17 @@ func throwInternalServiceError(
|
||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
|
||||
}
|
||||
|
||||
// throwServiceUnavailable handles 503 errors
|
||||
func throwServiceUnavailable(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusServiceUnavailable, msg, err, "error")
|
||||
}
|
||||
|
||||
// throwBadRequest handles 400 errors (malformed requests)
|
||||
func throwBadRequest(
|
||||
s *hws.Server,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
@@ -13,16 +15,34 @@ import (
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
)
|
||||
|
||||
func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
|
||||
func Login(
|
||||
s *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
st *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: check DB is connected
|
||||
// check discord API is working
|
||||
errDB := conn.Ping()
|
||||
_, errDisc := discordAPI.Ping()
|
||||
err := stderrors.Join(errors.Wrap(errDB, "conn.Ping"), errors.Wrap(errDisc, "discordAPI.Ping"))
|
||||
err = errors.Wrap(err, "login error")
|
||||
|
||||
if r.Method == "POST" {
|
||||
// if either fail, notify the client that login is unavailable right now
|
||||
// otherwise proceed redirect to GET method
|
||||
if err != nil {
|
||||
notifyServiceUnavailable(s, r, "Login currently unavailable", err)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
w.Header().Set("HX-Redirect", "/login")
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
throwServiceUnavailable(s, w, r, "Login currently unavailable", err)
|
||||
return
|
||||
}
|
||||
// if either fail and method is GET, show service not available page
|
||||
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
|
||||
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
|
||||
|
||||
@@ -39,7 +59,7 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
throwError(
|
||||
server,
|
||||
s,
|
||||
w,
|
||||
r,
|
||||
http.StatusBadRequest,
|
||||
@@ -52,14 +72,14 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
|
||||
|
||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
|
||||
throwInternalServiceError(s, w, r, "Failed to generate state token", err)
|
||||
return
|
||||
}
|
||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||
|
||||
link, err := discordAPI.GetOAuthLink(state)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
||||
throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
|
||||
return
|
||||
}
|
||||
st.ClearRedirectTrack(r, "/login")
|
||||
|
||||
@@ -37,6 +37,11 @@ func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err
|
||||
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
|
||||
}
|
||||
|
||||
func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error {
|
||||
return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg,
|
||||
SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
|
||||
}
|
||||
|
||||
func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error {
|
||||
return notifyClient(s, r, notify.LevelWarn, title, msg, "", action)
|
||||
}
|
||||
|
||||
@@ -81,13 +81,14 @@ templ navRight() {
|
||||
</div>
|
||||
</div>
|
||||
} else {
|
||||
<a
|
||||
class="hidden rounded-lg px-4 py-2 sm:block
|
||||
<button
|
||||
class="hidden rounded-lg px-4 py-2 sm:block hover:cursor-pointer
|
||||
bg-green hover:bg-green/75 text-mantle transition"
|
||||
href="/login"
|
||||
hx-post="/login"
|
||||
hx-swap="none"
|
||||
>
|
||||
Login
|
||||
</a>
|
||||
</button>
|
||||
}
|
||||
</div>
|
||||
<button
|
||||
|
||||
Reference in New Issue
Block a user