170 lines
4.4 KiB
Go
170 lines
4.4 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.haelnorr.com/h/golib/hws"
|
|
"git.haelnorr.com/h/golib/hwsauth"
|
|
"git.haelnorr.com/h/oslstats/internal/config"
|
|
"git.haelnorr.com/h/oslstats/internal/contexts"
|
|
"git.haelnorr.com/h/oslstats/internal/db"
|
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
|
"git.haelnorr.com/h/oslstats/internal/rbac"
|
|
"git.haelnorr.com/h/oslstats/internal/store"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
func addMiddleware(
|
|
server *hws.Server,
|
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
|
cfg *config.Config,
|
|
perms *rbac.Checker,
|
|
discordAPI *discord.APIClient,
|
|
store *store.Store,
|
|
conn *db.DB,
|
|
) error {
|
|
err := server.AddMiddleware(
|
|
auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
|
|
rbac.LoadPreviewRoleMiddleware(server, conn),
|
|
perms.LoadPermissionsMiddleware(),
|
|
devMode(cfg),
|
|
)
|
|
if err != nil {
|
|
return errors.Wrap(err, "server.AddMiddleware")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func devMode(cfg *config.Config) hws.Middleware {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if cfg.Flags.DevMode {
|
|
devInfo := contexts.DevInfo{
|
|
WebsocketBase: "ws://" + cfg.HWS.Host + ":" + strconv.FormatUint(cfg.HWS.Port, 10),
|
|
HTMXLog: true,
|
|
}
|
|
ctx := context.WithValue(r.Context(), contexts.DevModeKey, devInfo)
|
|
req := r.WithContext(ctx)
|
|
next.ServeHTTP(w, req)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
},
|
|
)
|
|
}
|
|
}
|
|
|
|
func tokenRefresh(
|
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
|
discordAPI *discord.APIClient,
|
|
store *store.Store,
|
|
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
|
return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
|
success, err := refreshToken(ctx, store, discordAPI, user, tx)
|
|
if err != nil {
|
|
return false, &hws.HWSError{
|
|
Error: errors.Wrap(err, "refreshToken"),
|
|
Message: "Error refreshing discord token",
|
|
Level: hws.ErrorERROR,
|
|
RenderErrorPage: true,
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
}
|
|
if !success {
|
|
err = auth.Logout(tx, w, r)
|
|
if err != nil {
|
|
return false, &hws.HWSError{
|
|
Error: errors.Wrap(err, "auth.Logout"),
|
|
Message: "Logout failed",
|
|
Level: hws.ErrorERROR,
|
|
RenderErrorPage: true,
|
|
StatusCode: http.StatusInternalServerError,
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
func refreshToken(
|
|
ctx context.Context,
|
|
store *store.Store,
|
|
discordAPI *discord.APIClient,
|
|
user *db.User,
|
|
tx bun.Tx,
|
|
) (bool, error) {
|
|
token := store.CheckToken(user)
|
|
if token != nil {
|
|
return true, nil
|
|
}
|
|
// Get the token
|
|
token, err := user.GetDiscordToken(ctx, tx)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "user.GetDiscordToken")
|
|
}
|
|
|
|
tokenstatus, err := tokenStatus(token)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "tokenStatus")
|
|
}
|
|
switch tokenstatus {
|
|
case "revoked":
|
|
return false, nil
|
|
case "expired", "expiring":
|
|
newtoken, err := discordAPI.RefreshToken(token.Convert())
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "discordAPI.RefreshToken")
|
|
}
|
|
err = user.UpdateDiscordToken(ctx, tx, newtoken)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "user.UpdateDiscordToken")
|
|
}
|
|
err = store.NewTokenCheck(user, token)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "store.NewTokenCheck")
|
|
}
|
|
return true, nil
|
|
case "valid":
|
|
err = store.NewTokenCheck(user, token)
|
|
if err != nil {
|
|
return false, errors.Wrap(err, "store.NewTokenCheck")
|
|
}
|
|
return true, nil
|
|
default:
|
|
return false, errors.New("unexpected error occured validating discord token for user")
|
|
}
|
|
}
|
|
|
|
func tokenStatus(token *db.DiscordToken) (string, error) {
|
|
now := time.Now().Unix()
|
|
dayfromnow := now + int64(24*time.Hour/time.Second)
|
|
oauthtoken := token.Convert()
|
|
session, err := discord.NewOAuthSession(oauthtoken)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "discord.NewOAuthSession")
|
|
}
|
|
_, err = session.GetUser()
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), "HTTP 401") {
|
|
// Error not related to token status
|
|
return "", errors.Wrap(err, "session.GetUser")
|
|
}
|
|
// Token not valid
|
|
if token.ExpiresAt < now {
|
|
return "expired", nil
|
|
}
|
|
return "revoked", nil
|
|
}
|
|
if token.ExpiresAt < dayfromnow {
|
|
return "expiring", nil
|
|
}
|
|
return "valid", nil
|
|
}
|