From 96d534f045eaaf57e11f3ccff535381f097d237d Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Tue, 27 Jan 2026 19:14:12 +1100 Subject: [PATCH] tweaks --- cmd/oslstats/db.go | 8 +- cmd/oslstats/main.go | 21 +- cmd/oslstats/routes.go | 4 +- cmd/oslstats/run.go | 12 +- go.mod | 4 +- go.sum | 6 +- internal/discord/api.go | 30 +- internal/discord/bot.go | 22 ++ internal/discord/config.go | 5 + internal/discord/oauth.go | 21 ++ internal/handlers/errors.go | 11 + internal/handlers/login.go | 38 ++- internal/handlers/notifications.go | 5 + internal/view/component/nav/navbarright.templ | 9 +- pkg/embedfs/files/css/output.css | 256 ------------------ 15 files changed, 138 insertions(+), 314 deletions(-) create mode 100644 internal/discord/bot.go diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index d4106eb..5ea9d10 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "time" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" @@ -17,10 +18,15 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DB, cfg.DB.SSL) sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + + sqldb.SetMaxOpenConns(25) + sqldb.SetMaxIdleConns(10) + sqldb.SetConnMaxLifetime(5 * time.Minute) + sqldb.SetConnMaxIdleTime(5 * time.Minute) + conn = bun.NewDB(sqldb, pgdialect.New()) close = sqldb.Close - // Simple table creation for backward compatibility err = loadModels(ctx, conn) if err != nil { return nil, nil, errors.Wrap(err, "loadModels") diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index 005dc74..1cdb585 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" + "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/oslstats/internal/config" "github.com/pkg/errors" ) @@ -23,6 +24,12 @@ func main() { fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config")) os.Exit(1) } + // Setup the logger + logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to init logger")) + os.Exit(1) + } // Handle utility flags if flags.EnvDoc || flags.ShowEnv { @@ -38,8 +45,7 @@ func main() { // Handle migration file creation (doesn't need DB connection) if flags.MigrateCreate != "" { if err := createMigration(flags.MigrateCreate); err != nil { - fmt.Fprintf(os.Stderr, "Error creating migration: %v\n", err) - os.Exit(1) + logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "createMigration"))).Msg("Error creating migration") } return } @@ -52,8 +58,7 @@ func main() { // Setup database connection conn, close, err := setupBun(ctx, cfg) if err != nil { - fmt.Fprintf(os.Stderr, "Error setting up database: %v\n", err) - os.Exit(1) + logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "setupBun"))).Msg("Error setting up database") } defer close() @@ -71,15 +76,13 @@ func main() { } if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) + logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "dbFlags"))).Msg("Error migrating database") } return } // Normal server startup - if err := run(ctx, os.Stdout, cfg); err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err) - os.Exit(1) + if err := run(ctx, logger, cfg); err != nil { + logger.Fatal().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "run"))).Msg("Error starting server") } } diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index 28b4696..fc22b64 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -38,8 +38,8 @@ func addRoutes( }, { Path: "/login", - Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Login(s, cfg, store, discordAPI)), + Methods: []hws.Method{hws.MethodGET, hws.MethodPOST}, + Handler: auth.LogoutReq(handlers.Login(s, conn, cfg, store, discordAPI)), }, { Path: "/auth/callback", diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 83fc280..acb46f4 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -2,7 +2,7 @@ package main import ( "context" - "io" + "fmt" "os" "os/signal" "sync" @@ -18,16 +18,10 @@ import ( ) // Initializes and runs the server -func run(ctx context.Context, w io.Writer, cfg *config.Config) error { +func run(ctx context.Context, logger *hlog.Logger, cfg *config.Config) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() - // Setup the logger - logger, err := hlog.NewLogger(cfg.HLOG, w) - if err != nil { - return errors.Wrap(err, "hlog.NewLogger") - } - // Setup the database connection logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") @@ -78,7 +72,7 @@ func run(ctx context.Context, w io.Writer, cfg *config.Config) error { logger.Info().Msg("Shut down requested, waiting 60 seconds...") err := httpServer.Shutdown(shutdownCtx) if err != nil { - logger.Error().Err(err).Msg("Graceful shutdown failed") + logger.Error().Err(err).Str("stacktrace", fmt.Sprintf("%+v", errors.Wrap(err, "httpServer.Shutdown"))).Msg("Graceful shutdown failed") } }) wg.Wait() diff --git a/go.mod b/go.mod index 6e20442..414ad68 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.4.0 - git.haelnorr.com/h/golib/hwsauth v0.5.2 + git.haelnorr.com/h/golib/hws v0.4.3 + git.haelnorr.com/h/golib/hwsauth v0.5.3 git.haelnorr.com/h/golib/notify v0.1.0 github.com/a-h/templ v0.3.977 github.com/coder/websocket v1.8.14 diff --git a/go.sum b/go.sum index cb5707d..536c99a 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,12 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.4.0 h1:T2JfRz4zpgsNXj0Vyfzxdf/60Tee/7H30osFmr5jDh0= -git.haelnorr.com/h/golib/hws v0.4.0/go.mod h1:UqB83p9lGjidDkk0pWRqxxOFrCkg8t+9J6uGtBOjNLo= +git.haelnorr.com/h/golib/hws v0.4.3 h1:rpqe0Dcbm3b5XZ/Bfy0LUhph6RR7+bmANrSA/W81l0A= +git.haelnorr.com/h/golib/hws v0.4.3/go.mod h1:UqB83p9lGjidDkk0pWRqxxOFrCkg8t+9J6uGtBOjNLo= git.haelnorr.com/h/golib/hwsauth v0.5.2 h1:K4McXMEHtI5o4fAL3AZrmaMkwORNqSTV3MM6BExNKag= git.haelnorr.com/h/golib/hwsauth v0.5.2/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= +git.haelnorr.com/h/golib/hwsauth v0.5.3 h1:Vgw8khDQZJRCc3m7z9QlbL9CYPyFB9JXUC3+omKRZPc= +git.haelnorr.com/h/golib/hwsauth v0.5.3/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= diff --git a/internal/discord/api.go b/internal/discord/api.go index aa490b5..fb8c96e 100644 --- a/internal/discord/api.go +++ b/internal/discord/api.go @@ -10,26 +10,6 @@ import ( "github.com/pkg/errors" ) -type OAuthSession struct { - *discordgo.Session -} - -func NewOAuthSession(token *Token) (*OAuthSession, error) { - session, err := discordgo.New("Bearer " + token.AccessToken) - if err != nil { - return nil, errors.Wrap(err, "discordgo.New") - } - return &OAuthSession{Session: session}, nil -} - -func (s *OAuthSession) GetUser() (*discordgo.User, error) { - user, err := s.User("@me") - if err != nil { - return nil, errors.Wrap(err, "s.User") - } - return user, nil -} - // APIClient is an HTTP client wrapper that handles Discord API rate limits type APIClient struct { cfg *Config @@ -38,6 +18,7 @@ type APIClient struct { mu sync.RWMutex buckets map[string]*RateLimitState trustedHost string + bot *BotSession } // NewAPIClient creates a new Discord API client with rate limit handling @@ -51,11 +32,20 @@ func NewAPIClient(cfg *Config, logger *hlog.Logger, trustedhost string) (*APICli if trustedhost == "" { return nil, errors.New("trustedhost cannot be empty") } + bot, err := newBotSession(cfg) + if err != nil { + return nil, errors.Wrap(err, "newBotSession") + } return &APIClient{ client: &http.Client{Timeout: 30 * time.Second}, logger: logger, buckets: make(map[string]*RateLimitState), cfg: cfg, trustedHost: trustedhost, + bot: bot, }, nil } + +func (api *APIClient) Ping() (*discordgo.Application, error) { + return api.bot.Application("@me") +} diff --git a/internal/discord/bot.go b/internal/discord/bot.go new file mode 100644 index 0000000..cc96280 --- /dev/null +++ b/internal/discord/bot.go @@ -0,0 +1,22 @@ +package discord + +import ( + "github.com/bwmarrin/discordgo" + "github.com/pkg/errors" +) + +type BotSession struct { + *discordgo.Session +} + +func newBotSession(cfg *Config) (*BotSession, error) { + session, err := discordgo.New("Bot " + cfg.BotToken) + if err != nil { + return nil, errors.Wrap(err, "discordgo.New") + } + return &BotSession{Session: session}, nil +} + +func (api *APIClient) Bot() *BotSession { + return api.bot +} diff --git a/internal/discord/config.go b/internal/discord/config.go index cc086ff..11af109 100644 --- a/internal/discord/config.go +++ b/internal/discord/config.go @@ -12,6 +12,7 @@ type Config struct { ClientSecret string // ENV DISCORD_CLIENT_SECRET: Discord application client secret (required) OAuthScopes string // Authorisation scopes for OAuth RedirectPath string // ENV DISCORD_REDIRECT_PATH: Path for the OAuth redirect handler (required) + BotToken string // ENV DISCORD_BOT_TOKEN: Token for the discord bot (required) } func ConfigFromEnv() (any, error) { @@ -20,6 +21,7 @@ func ConfigFromEnv() (any, error) { ClientSecret: env.String("DISCORD_CLIENT_SECRET", ""), OAuthScopes: getOAuthScopes(), RedirectPath: env.String("DISCORD_REDIRECT_PATH", ""), + BotToken: env.String("DISCORD_BOT_TOKEN", ""), } // Check required fields @@ -32,6 +34,9 @@ func ConfigFromEnv() (any, error) { if cfg.RedirectPath == "" { return nil, errors.New("Envar not set: DISCORD_REDIRECT_PATH") } + if cfg.BotToken == "" { + return nil, errors.New("Envar not set: DISCORD_BOT_TOKEN") + } return cfg, nil } diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go index faff919..5bb8d04 100644 --- a/internal/discord/oauth.go +++ b/internal/discord/oauth.go @@ -8,9 +8,30 @@ import ( "net/url" "strings" + "github.com/bwmarrin/discordgo" "github.com/pkg/errors" ) +type OAuthSession struct { + *discordgo.Session +} + +func NewOAuthSession(token *Token) (*OAuthSession, error) { + session, err := discordgo.New("Bearer " + token.AccessToken) + if err != nil { + return nil, errors.Wrap(err, "discordgo.New") + } + return &OAuthSession{Session: session}, nil +} + +func (s *OAuthSession) GetUser() (*discordgo.User, error) { + user, err := s.User("@me") + if err != nil { + return nil, errors.Wrap(err, "s.User") + } + return user, nil +} + // Token represents a response from the Discord OAuth API after a successful authorization request type Token struct { AccessToken string `json:"access_token"` diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go index bad79d6..8438a13 100644 --- a/internal/handlers/errors.go +++ b/internal/handlers/errors.go @@ -42,6 +42,17 @@ func throwInternalServiceError( throwError(s, w, r, http.StatusInternalServerError, msg, err, "error") } +// throwServiceUnavailable handles 503 errors +func throwServiceUnavailable( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusServiceUnavailable, msg, err, "error") +} + // throwBadRequest handles 400 errors (malformed requests) func throwBadRequest( s *hws.Server, diff --git a/internal/handlers/login.go b/internal/handlers/login.go index d76fcb6..1a635f6 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -1,11 +1,13 @@ package handlers import ( + stderrors "errors" "net/http" "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" + "github.com/uptrace/bun" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/discord" @@ -13,16 +15,34 @@ import ( "git.haelnorr.com/h/oslstats/pkg/oauth" ) -func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { +func Login( + s *hws.Server, + conn *bun.DB, + cfg *config.Config, + st *store.Store, + discordAPI *discord.APIClient, +) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - // TODO: check DB is connected - // check discord API is working + errDB := conn.Ping() + _, errDisc := discordAPI.Ping() + err := stderrors.Join(errors.Wrap(errDB, "conn.Ping"), errors.Wrap(errDisc, "discordAPI.Ping")) + err = errors.Wrap(err, "login error") + if r.Method == "POST" { - // if either fail, notify the client that login is unavailable right now - // otherwise proceed redirect to GET method + if err != nil { + notifyServiceUnavailable(s, r, "Login currently unavailable", err) + w.WriteHeader(http.StatusOK) + return + } + w.Header().Set("HX-Redirect", "/login") + return + } + + if err != nil { + throwServiceUnavailable(s, w, r, "Login currently unavailable", err) + return } - // if either fail and method is GET, show service not available page cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) @@ -39,7 +59,7 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI * st.ClearRedirectTrack(r, "/login") throwError( - server, + s, w, r, http.StatusBadRequest, @@ -52,14 +72,14 @@ func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI * state, uak, err := oauth.GenerateState(cfg.OAuth, "login") if err != nil { - throwInternalServiceError(server, w, r, "Failed to generate state token", err) + throwInternalServiceError(s, w, r, "Failed to generate state token", err) return } oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) link, err := discordAPI.GetOAuthLink(state) if err != nil { - throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) + throwInternalServiceError(s, w, r, "An error occurred trying to generate the login link", err) return } st.ClearRedirectTrack(r, "/login") diff --git a/internal/handlers/notifications.go b/internal/handlers/notifications.go index 047ff62..ee0af94 100644 --- a/internal/handlers/notifications.go +++ b/internal/handlers/notifications.go @@ -37,6 +37,11 @@ func notifyInternalServiceError(s *hws.Server, r *http.Request, msg string, err SerializeErrorDetails(http.StatusInternalServerError, err), nil) } +func notifyServiceUnavailable(s *hws.Server, r *http.Request, msg string, err error) error { + return notifyClient(s, r, notify.LevelError, "Service Unavailable", msg, + SerializeErrorDetails(http.StatusServiceUnavailable, err), nil) +} + func notifyWarn(s *hws.Server, r *http.Request, title, msg string, action any) error { return notifyClient(s, r, notify.LevelWarn, title, msg, "", action) } diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ index 88a5531..fa27c16 100644 --- a/internal/view/component/nav/navbarright.templ +++ b/internal/view/component/nav/navbarright.templ @@ -81,13 +81,14 @@ templ navRight() { } else { - + }