package server

import (
	"context"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"git.haelnorr.com/h/golib/hws"
	"git.haelnorr.com/h/golib/hwsauth"
	"git.haelnorr.com/h/oslstats/internal/config"
	"git.haelnorr.com/h/oslstats/internal/contexts"
	"git.haelnorr.com/h/oslstats/internal/db"
	"git.haelnorr.com/h/oslstats/internal/discord"
	"git.haelnorr.com/h/oslstats/internal/rbac"
	"git.haelnorr.com/h/oslstats/internal/store"
	"github.com/pkg/errors"
	"github.com/uptrace/bun"
)

// addMiddleware registers the server-wide middleware on server, passing them
// to server.AddMiddleware in this order:
//   - authentication via auth, with a per-request Discord token check/refresh
//     hook (see tokenRefresh),
//   - preview-role loading (rbac.LoadPreviewRoleMiddleware),
//   - permission loading (perms.LoadPermissionsMiddleware),
//   - dev/staging context injection (see devMode).
//
// NOTE(review): the resulting execution order depends on how hws composes the
// list — confirm against hws.Server.AddMiddleware.
//
// Any registration error is returned wrapped.
func addMiddleware(
	server *hws.Server,
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	cfg *config.Config,
	perms *rbac.Checker,
	discordAPI *discord.APIClient,
	store *store.Store,
	conn *db.DB,
) error {
	err := server.AddMiddleware(
		auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
		rbac.LoadPreviewRoleMiddleware(server, conn),
		perms.LoadPermissionsMiddleware(),
		devMode(cfg),
	)
	if err != nil {
		return errors.Wrap(err, "server.AddMiddleware")
	}
	return nil
}

// devMode returns middleware that, when cfg.Flags.DevMode or
// cfg.Flags.Staging is set, stores a contexts.DevInfo value in the request
// context under contexts.DevModeKey. When neither flag is set the request is
// passed through unchanged.
//
// In dev mode the DevInfo gets a websocket base URL built from the configured
// HWS host and port, and HTMX logging is enabled. In staging the staging
// banner is enabled. Both flags may be set at the same time.
func devMode(cfg *config.Config) hws.Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !cfg.Flags.DevMode && !cfg.Flags.Staging {
				next.ServeHTTP(w, r)
				return
			}
			devInfo := contexts.DevInfo{}
			if cfg.Flags.DevMode {
				// net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port");
				// plain "host:port" concatenation yields an invalid URL for them.
				port := strconv.FormatUint(cfg.HWS.Port, 10)
				devInfo.WebsocketBase = "ws://" + net.JoinHostPort(cfg.HWS.Host, port)
				devInfo.HTMXLog = true
			}
			if cfg.Flags.Staging {
				devInfo.StagingBanner = true
			}
			ctx := context.WithValue(r.Context(), contexts.DevModeKey, devInfo)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// tokenRefresh returns the hook handed to auth.Authenticate. For each
// authenticated user it calls refreshToken to make sure the user's stored
// Discord OAuth token is still usable.
//
// The hook returns:
//   - (true, nil) when the token is usable;
//   - (false, nil) after logging the user out with auth.Logout, when the
//     token is revoked or could not be refreshed;
//   - (false, *hws.HWSError) with status 500 and RenderErrorPage set when
//     refreshToken or auth.Logout fails.
//
// NOTE(review): how hwsauth treats a false result is defined by hwsauth,
// presumably "not authenticated" — confirm there.
func tokenRefresh(
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	discordAPI *discord.APIClient,
	store *store.Store,
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
	return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
		success, err := refreshToken(ctx, store, discordAPI, user, tx)
		if err != nil {
			return false, &hws.HWSError{
				Error:           errors.Wrap(err, "refreshToken"),
				Message:         "Error refreshing discord token",
				Level:           hws.ErrorERROR,
				RenderErrorPage: true,
				StatusCode:      http.StatusInternalServerError,
			}
		}
		if !success {
			err = auth.Logout(tx, w, r)
			if err != nil {
				return false, &hws.HWSError{
					Error:           errors.Wrap(err, "auth.Logout"),
					Message:         "Logout failed",
					Level:           hws.ErrorERROR,
					RenderErrorPage: true,
					StatusCode:      http.StatusInternalServerError,
				}
			}
			return false, nil
		}
		return true, nil
	}
}

// refreshToken reports whether user's Discord token is usable, and refreshes
// it through discordAPI when it has expired or is about to expire.
//
// If store.CheckToken returns a non-nil entry for user, true is returned
// straight away. Presumably this is a cached recent check; confirm the
// store's expiry semantics. Otherwise the token is loaded inside tx, tokenStatus
// classifies it, and:
//   - "valid": the check is recorded with store.NewTokenCheck and true is returned;
//   - "expired"/"expiring": the token is refreshed, the new token is persisted
//     with user.UpdateDiscordToken, the check is recorded, and true is returned;
//   - "revoked", or a refresh rejected with "invalid_grant": false is returned
//     with a nil error so the caller can log the user out.
//
// Every other failure is returned as a wrapped error.
func refreshToken(
	ctx context.Context,
	store *store.Store,
	discordAPI *discord.APIClient,
	user *db.User,
	tx bun.Tx,
) (bool, error) {
	token := store.CheckToken(user)
	if token != nil {
		return true, nil
	}

	// Load the token persisted for this user.
	token, err := user.GetDiscordToken(ctx, tx)
	if err != nil {
		return false, errors.Wrap(err, "user.GetDiscordToken")
	}
	tokenstatus, err := tokenStatus(token)
	if err != nil {
		return false, errors.Wrap(err, "tokenStatus")
	}

	switch tokenstatus {
	case "revoked":
		return false, nil
	case "expired", "expiring":
		newtoken, err := discordAPI.RefreshToken(token.Convert())
		if err != nil {
			// The refresh token itself is no longer accepted; this is detected
			// from the error text, so it depends on the discord client's messages.
			if strings.Contains(err.Error(), "invalid_grant") {
				return false, nil
			}
			return false, errors.Wrap(err, "discordAPI.RefreshToken")
		}
		err = user.UpdateDiscordToken(ctx, tx, newtoken)
		if err != nil {
			return false, errors.Wrap(err, "user.UpdateDiscordToken")
		}
		// NOTE(review): this records the pre-refresh token value. If the store
		// keeps the token itself (not just the fact that a check happened),
		// it will hold stale credentials — verify store.NewTokenCheck.
		err = store.NewTokenCheck(user, token)
		if err != nil {
			return false, errors.Wrap(err, "store.NewTokenCheck")
		}
		return true, nil
	case "valid":
		err = store.NewTokenCheck(user, token)
		if err != nil {
			return false, errors.Wrap(err, "store.NewTokenCheck")
		}
		return true, nil
	default:
		return false, errors.Errorf("unexpected discord token status %q", tokenstatus)
	}
}

// tokenStatus probes token against Discord by fetching the current user
// with it, and classifies the result as:
//   - "valid":    the request succeeded and ExpiresAt is at least 24h from now;
//   - "expiring": the request succeeded but ExpiresAt is less than 24h away;
//   - "expired":  the request failed with HTTP 401 and ExpiresAt is in the past;
//   - "revoked":  the request failed with HTTP 401 but ExpiresAt is not in the past.
//
// ExpiresAt is compared against Unix seconds. Any other request failure, and a
// nil token, is returned as an error.
func tokenStatus(token *db.DiscordToken) (string, error) {
	if token == nil {
		return "", errors.New("discord token is nil")
	}
	now := time.Now().Unix()
	dayfromnow := now + int64(24*time.Hour/time.Second)

	session, err := discord.NewOAuthSession(token.Convert())
	if err != nil {
		return "", errors.Wrap(err, "discord.NewOAuthSession")
	}
	_, err = session.GetUser()
	if err != nil {
		// A 401 is recognised from the error text. Anything else is a transport
		// or API problem that says nothing about the token itself.
		if !strings.Contains(err.Error(), "HTTP 401") {
			return "", errors.Wrap(err, "session.GetUser")
		}
		// Discord rejected the token: either it ran out or it was revoked early.
		if token.ExpiresAt < now {
			return "expired", nil
		}
		return "revoked", nil
	}
	if token.ExpiresAt < dayfromnow {
		return "expiring", nil
	}
	return "valid", nil
}