diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go
index d4106eb..5ea9d10 100644
--- a/cmd/oslstats/db.go
+++ b/cmd/oslstats/db.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
+ "time"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
@@ -17,10 +18,15 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s",
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
+
+ sqldb.SetMaxOpenConns(25)
+ sqldb.SetMaxIdleConns(10)
+ sqldb.SetConnMaxLifetime(5 * time.Minute)
+ sqldb.SetConnMaxIdleTime(5 * time.Minute)
+
conn = bun.NewDB(sqldb, pgdialect.New())
close = sqldb.Close
- // Simple table creation for backward compatibility
err = loadModels(ctx, conn)
if err != nil {
return nil, nil, errors.Wrap(err, "loadModels")
diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go
index 005dc74..1cdb585 100644
--- a/cmd/oslstats/main.go
+++ b/cmd/oslstats/main.go
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
+ "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/oslstats/internal/config"
"github.com/pkg/errors"
)
@@ -23,6 +24,12 @@ func main() {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
os.Exit(1)
}
+ // Setup the logger
+ logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to init logger"))
+ os.Exit(1)
+ }
// Handle utility flags
if flags.EnvDoc || flags.ShowEnv {
@@ -38,8 +45,7 @@ func main() {
// Handle migration file creation (doesn't need DB connection)
if flags.MigrateCreate != "" {
if err := createMigration(flags.MigrateCreate); err != nil {
- fmt.Fprintf(os.Stderr, "Error creating migration: %v\n", err)
- os.Exit(1)
+ logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration")
}
return
}
@@ -52,8 +58,7 @@ func main() {
// Setup database connection
conn, close, err := setupBun(ctx, cfg)
if err != nil {
- fmt.Fprintf(os.Stderr, "Error setting up database: %v\n", err)
- os.Exit(1)
+ logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "setupBun"))).Msg("Error setting up database")
}
defer close()
@@ -71,15 +76,13 @@ func main() {
}
if err != nil {
- fmt.Fprintf(os.Stderr, "Error: %v\n", err)
- os.Exit(1)
+ logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "dbFlags"))).Msg("Error migrating database")
}
return
}
// Normal server startup
- if err := run(ctx, os.Stdout, cfg); err != nil {
- fmt.Fprintf(os.Stderr, "%s\n", err)
- os.Exit(1)
+ if err := run(ctx, logger, cfg); err != nil {
+ logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "run"))).Msg("Error starting server")
}
}
diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go
index 28b4696..fc22b64 100644
--- a/cmd/oslstats/routes.go
+++ b/cmd/oslstats/routes.go
@@ -38,8 +38,8 @@ func addRoutes(
},
{
Path: "/login",
- Method: hws.MethodGET,
- Handler: auth.LogoutReq(handlers.Login(s, cfg, store, discordAPI)),
+ Methods: []hws.Method{hws.MethodGET, hws.MethodPOST},
+ Handler: auth.LogoutReq(handlers.Login(s, conn, cfg, store, discordAPI)),
},
{
Path: "/auth/callback",
diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go
index 83fc280..acb46f4 100644
--- a/cmd/oslstats/run.go
+++ b/cmd/oslstats/run.go
@@ -2,7 +2,7 @@ package main
import (
"context"
- "io"
+ "fmt"
"os"
"os/signal"
"sync"
@@ -18,16 +18,10 @@ import (
)
// Initializes and runs the server
-func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
+func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
- // Setup the logger
- logger, err := hlog.NewLogger(cfg.HLOG, w)
- if err != nil {
- return errors.Wrap(err, "hlog.NewLogger")
- }
-
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
@@ -78,7 +72,7 @@ func run(ctx context.Context, w io.Writer, cfg *config.Config) error {
logger.Info().Msg("Shut down requested, waiting 60 seconds...")
err := httpServer.Shutdown(shutdownCtx)
if err != nil {
- logger.Error().Err(err).Msg("Graceful shutdown failed")
+ logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Graceful shutdown failed")
}
})
wg.Wait()
diff --git a/go.mod b/go.mod
index 6e20442..414ad68 100644
--- a/go.mod
+++ b/go.mod
@@ -6,8 +6,8 @@ require (
git.haelnorr.com/h/golib/env v0.9.1
git.haelnorr.com/h/golib/ezconf v0.1.1
git.haelnorr.com/h/golib/hlog v0.10.4
- git.haelnorr.com/h/golib/hws v0.4.0
- git.haelnorr.com/h/golib/hwsauth v0.5.2
+ git.haelnorr.com/h/golib/hws v0.4.3
+ git.haelnorr.com/h/golib/hwsauth v0.5.3
git.haelnorr.com/h/golib/notify v0.1.0
github.com/a-h/templ v0.3.977
github.com/coder/websocket v1.8.14
diff --git a/go.sum b/go.sum
index cb5707d..536c99a 100644
--- a/go.sum
+++ b/go.sum
@@ -6,10 +6,12 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI
git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8=
git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ=
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
-git.haelnorr.com/h/golib/hws v0.4.0 h1:T2JfRz4zpgsNXj0Vyfzxdf/60Tee/7H30osFmr5jDh0=
-git.haelnorr.com/h/golib/hws v0.4.0/go.mod h1:UqB83p9lGjidDkk0pWRqxxOFrCkg8t+9J6uGtBOjNLo=
+git.haelnorr.com/h/golib/hws v0.4.3 h1:rpqe0Dcbm3b5XZ/Bfy0LUhph6RR7+bmANrSA/W81l0A=
+git.haelnorr.com/h/golib/hws v0.4.3/go.mod h1:UqB83p9lGjidDkk0pWRqxxOFrCkg8t+9J6uGtBOjNLo=
git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag=
git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ=
+git.haelnorr.com/h/golib/hwsauth v0.5.3 h1:Vgw8khDQZJRCc3m7z9QlbL9CYPyFB9JXUC3+omKRZPc=
+git.haelnorr.com/h/golib/hwsauth v0.5.3/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
diff --git a/internal/discord/api.go b/internal/discord/api.go
index aa490b5..fb8c96e 100644
--- a/internal/discord/api.go
+++ b/internal/discord/api.go
@@ -10,26 +10,6 @@ import (
"github.com/pkg/errors"
)
-type OAuthSession struct {
- *discordgo.Session
-}
-
-func NewOAuthSession(token *Token) (*OAuthSession, error) {
- session, err := discordgo.New("Bearer " + token.AccessToken)
- if err != nil {
- return nil, errors.Wrap(err, "discordgo.New")
- }
- return &OAuthSession{Session: session}, nil
-}
-
-func (s *OAuthSession) GetUser() (*discordgo.User, error) {
- user, err := s.User("@me")
- if err != nil {
- return nil, errors.Wrap(err, "s.User")
- }
- return user, nil
-}
-
// APIClient is an HTTP client wrapper that handles Discord API rate limits
type APIClient struct {
cfg *Config
@@ -38,6 +18,7 @@ type APIClient struct {
mu sync.RWMutex
buckets map[string]*RateLimitState
trustedHost string
+ bot *BotSession
}
// NewAPIClient creates a new Discord API client with rate limit handling
@@ -51,11 +32,20 @@ func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APICli
if trustedhost == "" {
return nil, errors.New("trustedhost cannot be empty")
}
+ bot, err := newBotSession(cfg)
+ if err != nil {
+ return nil, errors.Wrap(err, "newBotSession")
+ }
return &APIClient{
client: &http.Client{Timeout: 30 * time.Second},
logger: logger,
buckets: make(map[string]*RateLimitState),
cfg: cfg,
trustedHost: trustedhost,
+ bot: bot,
}, nil
}
+
+func (api *APIClient) Ping() (*discordgo.Application, error) {
+ return api.bot.Application("@me")
+}
diff --git a/internal/discord/bot.go b/internal/discord/bot.go
new file mode 100644
index 0000000..cc96280
--- /dev/null
+++ b/internal/discord/bot.go
@@ -0,0 +1,22 @@
+package discord
+
+import (
+ "github.com/bwmarrin/discordgo"
+ "github.com/pkg/errors"
+)
+
+type BotSession struct {
+ *discordgo.Session
+}
+
+func newBotSession(cfg *Config) (*BotSession, error) {
+ session, err := discordgo.New("Bot " + cfg.BotToken)
+ if err != nil {
+ return nil, errors.Wrap(err, "discordgo.New")
+ }
+ return &BotSession{Session: session}, nil
+}
+
+func (api *APIClient) Bot() *BotSession {
+ return api.bot
+}
diff --git a/internal/discord/config.go b/internal/discord/config.go
index cc086ff..11af109 100644
--- a/internal/discord/config.go
+++ b/internal/discord/config.go
@@ -12,6 +12,7 @@ type Config struct {
ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required)
OAuthScopes string // Authorisation scopes for OAuth
RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required)
+ BotToken string // ENV DISCORD_BOT_TOKEN: Token for the discord bot (required)
}
func ConfigFromEnv() (any, error) {
@@ -20,6 +21,7 @@ func ConfigFromEnv() (any, error) {
ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""),
OAuthScopes: getOAuthScopes(),
RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""),
+ BotToken: env.String("DISCORD_BOT_TOKEN", ""),
}
// Check required fields
@@ -32,6 +34,9 @@ func ConfigFromEnv() (any, error) {
if cfg.RedirectPath == "" {
return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH")
}
+ if cfg.BotToken == "" {
+ return nil, errors.New("Envar not set: DISCORD_BOT_TOKEN")
+ }
return cfg, nil
}
diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go
index faff919..5bb8d04 100644
--- a/internal/discord/oauth.go
+++ b/internal/discord/oauth.go
@@ -8,9 +8,30 @@ import (
"net/url"
"strings"
+ "github.com/bwmarrin/discordgo"
"github.com/pkg/errors"
)
+type OAuthSession struct {
+ *discordgo.Session
+}
+
+func NewOAuthSession(token *Token) (*OAuthSession, error) {
+ session, err := discordgo.New("Bearer " + token.AccessToken)
+ if err != nil {
+ return nil, errors.Wrap(err, "discordgo.New")
+ }
+ return &OAuthSession{Session: session}, nil
+}
+
+func (s *OAuthSession) GetUser() (*discordgo.User, error) {
+ user, err := s.User("@me")
+ if err != nil {
+ return nil, errors.Wrap(err, "s.User")
+ }
+ return user, nil
+}
+
// Token represents a response from the Discord OAuth API after a successful authorization request
type Token struct {
AccessToken string `json:"access_token"`
diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go
index bad79d6..8438a13 100644
--- a/internal/handlers/errors.go
+++ b/internal/handlers/errors.go
@@ -42,6 +42,17 @@ func throwInternalServiceError(
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
}
+// throwServiceUnavailable handles 503 errors
+func throwServiceUnavailable(
+ s *hws.Server,
+ w http.ResponseWriter,
+ r *http.Request,
+ msg string,
+ err error,
+) {
+ throwError(s, w, r, http.StatusServiceUnavailable, msg, err, "error")
+}
+
// throwBadRequest handles 400 errors (malformed requests)
func throwBadRequest(
s *hws.Server,
diff --git a/internal/handlers/login.go b/internal/handlers/login.go
index d76fcb6..1a635f6 100644
--- a/internal/handlers/login.go
+++ b/internal/handlers/login.go
@@ -1,11 +1,13 @@
package handlers
import (
+ stderrors "errors"
"net/http"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"github.com/pkg/errors"
+ "github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/discord"
@@ -13,16 +15,34 @@ import (
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
-func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler {
+func Login(
+ s *hws.Server,
+ conn *bun.DB,
+ cfg *config.Config,
+ st *store.Store,
+ discordAPI *discord.APIClient,
+) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
- // TODO: check DB is connected
- // check discord API is working
+ errDB := conn.Ping()
+ _, errDisc := discordAPI.Ping()
+ err := stderrors.Join(errors.Wrap(errDB, "conn.Ping"), errors.Wrap(errDisc, "discordAPI.Ping"))
+ err = errors.Wrap(err, "login error")
+
if r.Method == "POST" {
- // if either fail, notify the client that login is unavailable right now
- // otherwise proceed redirect to GET method
+ if err != nil {
+ notifyServiceUnavailable(s, r, "Login currently unavailable", err)
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+ w.Header().Set("HX-Redirect", "/login")
+ return
+ }
+
+ if err != nil {
+ throwServiceUnavailable(s, w, r, "Login currently unavailable", err)
+ return
}
- // if either fail and method is GET, show service not available page
cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost)
attempts, exceeded, track := st.TrackRedirect(r, "/login", 5)
@@ -39,7 +59,7 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
st.ClearRedirectTrack(r, "/login")
throwError(
- server,
+ s,
w,
r,
http.StatusBadRequest,
@@ -52,14 +72,14 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
if err != nil {
- throwInternalServiceError(server, w, r, "Failed to generate state token", err)
+ throwInternalServiceError(s, w, r, "Failed to generate state token", err)
return
}
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
link, err := discordAPI.GetOAuthLink(state)
if err != nil {
- throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
+ throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err)
return
}
st.ClearRedirectTrack(r, "/login")
diff --git a/internal/handlers/notifications.go b/internal/handlers/notifications.go
index 047ff62..ee0af94 100644
--- a/internal/handlers/notifications.go
+++ b/internal/handlers/notifications.go
@@ -37,6 +37,11 @@ func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err
SerializeErrorDetails(http.StatusInternalServerError, err), nil)
}
+func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error {
+ return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg,
+ SerializeErrorDetails(http.StatusServiceUnavailable, err), nil)
+}
+
func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error {
return notifyClient(s, r, notify.LevelWarn, title, msg, "", action)
}
diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ
index 88a5531..fa27c16 100644
--- a/internal/view/component/nav/navbarright.templ
+++ b/internal/view/component/nav/navbarright.templ
@@ -81,13 +81,14 @@ templ navRight() {
} else {
-
Login
-
+
}