Files
oslstats/internal/server/middleware.go
2026-02-14 19:48:59 +11:00

170 lines
4.4 KiB
Go

package server
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/contexts"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
// addMiddleware builds the application's middleware and registers it on server
// in a single server.AddMiddleware call. The middleware are passed in this
// order: authentication (with the Discord token refresh hook), preview-role
// loading, permission loading, and the dev-mode context injector.
//
// Any error from server.AddMiddleware is returned wrapped with the call name.
func addMiddleware(
	server *hws.Server,
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	cfg *config.Config,
	perms *rbac.Checker,
	discordAPI *discord.APIClient,
	store *store.Store,
	conn *db.DB,
) error {
	// Built one by one, in registration order, so each piece is named.
	authenticate := auth.Authenticate(tokenRefresh(auth, discordAPI, store))
	previewRole := rbac.LoadPreviewRoleMiddleware(server, conn)
	permissions := perms.LoadPermissionsMiddleware()
	dev := devMode(cfg)

	if err := server.AddMiddleware(authenticate, previewRole, permissions, dev); err != nil {
		return errors.Wrap(err, "server.AddMiddleware")
	}
	return nil
}
// devMode returns middleware that, while cfg.Flags.DevMode is set, stores a
// contexts.DevInfo value in the request context under contexts.DevModeKey
// before handing the request to the next handler. With the flag unset the
// request is forwarded unchanged.
func devMode(cfg *config.Config) hws.Middleware {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			// The flag is consulted on every request, not once at setup.
			if !cfg.Flags.DevMode {
				next.ServeHTTP(w, r)
				return
			}

			port := strconv.FormatUint(cfg.HWS.Port, 10)
			info := contexts.DevInfo{
				WebsocketBase: "ws://" + cfg.HWS.Host + ":" + port,
				HTMXLog:       true,
			}
			withInfo := context.WithValue(r.Context(), contexts.DevModeKey, info)
			next.ServeHTTP(w, r.WithContext(withInfo))
		}
		return http.HandlerFunc(handle)
	}
}
// tokenRefresh builds the callback passed to auth.Authenticate. For each call
// the callback runs refreshToken for user and returns:
//   - (true, nil) when refreshToken reports the Discord token as usable;
//   - (false, nil) when it does not, after logging the user out via auth.Logout;
//   - (false, *hws.HWSError) when refreshToken or the logout fails. The error
//     carries status 500, level hws.ErrorERROR and asks for the error page.
func tokenRefresh(
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	discordAPI *discord.APIClient,
	store *store.Store,
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
	return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
		usable, err := refreshToken(ctx, store, discordAPI, user, tx)
		if err != nil {
			return false, &hws.HWSError{
				Error:           errors.Wrap(err, "refreshToken"),
				Message:         "Error refreshing discord token",
				Level:           hws.ErrorERROR,
				RenderErrorPage: true,
				StatusCode:      http.StatusInternalServerError,
			}
		}
		if usable {
			return true, nil
		}

		// The token can no longer be used: end the session.
		err = auth.Logout(tx, w, r)
		if err != nil {
			return false, &hws.HWSError{
				Error:           errors.Wrap(err, "auth.Logout"),
				Message:         "Logout failed",
				Level:           hws.ErrorERROR,
				RenderErrorPage: true,
				StatusCode:      http.StatusInternalServerError,
			}
		}
		return false, nil
	}
}
// refreshToken reports whether user's Discord token is usable, refreshing it
// through discordAPI when it has expired or is about to.
//
// A non-nil result from store.CheckToken is treated as an already recorded
// check and returns true immediately. Otherwise the token is loaded through
// tx, classified by tokenStatus, and handled by status:
//   - "valid": the token is recorded in store and true is returned;
//   - "expired" / "expiring": the token is refreshed, persisted through tx,
//     reloaded, and the reloaded token is recorded in store;
//   - "revoked": false is returned with a nil error.
//
// Every failure is returned wrapped with the name of the failing call; an
// unrecognised status is reported as an error.
func refreshToken(
	ctx context.Context,
	store *store.Store,
	discordAPI *discord.APIClient,
	user *db.User,
	tx bun.Tx,
) (bool, error) {
	if cached := store.CheckToken(user); cached != nil {
		return true, nil
	}

	token, err := user.GetDiscordToken(ctx, tx)
	if err != nil {
		return false, errors.Wrap(err, "user.GetDiscordToken")
	}
	status, err := tokenStatus(token)
	if err != nil {
		return false, errors.Wrap(err, "tokenStatus")
	}

	switch status {
	case "revoked":
		return false, nil
	case "expired", "expiring":
		refreshed, err := discordAPI.RefreshToken(token.Convert())
		if err != nil {
			return false, errors.Wrap(err, "discordAPI.RefreshToken")
		}
		err = user.UpdateDiscordToken(ctx, tx, refreshed)
		if err != nil {
			return false, errors.Wrap(err, "user.UpdateDiscordToken")
		}
		// Reload through the same transaction so the store records the
		// refreshed credentials instead of the pre-refresh token, which was
		// what got cached before.
		token, err = user.GetDiscordToken(ctx, tx)
		if err != nil {
			return false, errors.Wrap(err, "user.GetDiscordToken")
		}
	case "valid":
		// Nothing to refresh; record the token as loaded.
	default:
		return false, errors.New("unexpected error occured validating discord token for user")
	}

	err = store.NewTokenCheck(user, token)
	if err != nil {
		return false, errors.Wrap(err, "store.NewTokenCheck")
	}
	return true, nil
}
// tokenStatus classifies token by probing Discord with it and comparing
// token.ExpiresAt (Unix seconds) with the current time. It returns one of:
//   - "valid":    session.GetUser succeeded and the token expires more than 24 hours from now;
//   - "expiring": session.GetUser succeeded but the token expires within 24 hours;
//   - "expired":  the probe failed with an "HTTP 401" error and ExpiresAt is in the past;
//   - "revoked":  the probe failed with an "HTTP 401" error although ExpiresAt is not yet past.
//
// Failure to build the session, or a probe error whose text lacks "HTTP 401",
// is returned wrapped with an empty status.
//
// NOTE(review): token is dereferenced without a nil check — confirm callers
// never pass nil.
func tokenStatus(token *db.DiscordToken) (string, error) {
	// Both reference instants are taken before any network call.
	checkedAt := time.Now()
	nowUnix := checkedAt.Unix()
	dayAheadUnix := checkedAt.Add(24 * time.Hour).Unix()

	session, err := discord.NewOAuthSession(token.Convert())
	if err != nil {
		return "", errors.Wrap(err, "discord.NewOAuthSession")
	}

	_, err = session.GetUser()
	switch {
	case err == nil && token.ExpiresAt < dayAheadUnix:
		return "expiring", nil
	case err == nil:
		return "valid", nil
	case !strings.Contains(err.Error(), "HTTP 401"):
		// Not a token rejection; surface the underlying failure.
		return "", errors.Wrap(err, "session.GetUser")
	case token.ExpiresAt < nowUnix:
		return "expired", nil
	default:
		// Rejected although not yet past its expiry time.
		return "revoked", nil
	}
}