diff --git a/.gitignore b/.gitignore index 72d2520..41d3e99 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ tmp/ static/css/output.css internal/view/**/*_templ.go internal/view/**/*_templ.txt +cmd/test/* # Database backups (compressed) backups/*.sql.gz diff --git a/cmd/oslstats/auth.go b/cmd/oslstats/auth.go index d2b6ff5..781d40a 100644 --- a/cmd/oslstats/auth.go +++ b/cmd/oslstats/auth.go @@ -13,7 +13,7 @@ import ( ) func setupAuth( - config *hwsauth.Config, + cfg *hwsauth.Config, logger *hlog.Logger, conn *bun.DB, server *hws.Server, @@ -24,7 +24,7 @@ func setupAuth( return tx, err } auth, err := hwsauth.NewAuthenticator( - config, + cfg, db.GetUserByID, server, beginTx, @@ -36,11 +36,12 @@ func setupAuth( return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") } - auth.IgnorePaths(ignoredPaths...) + err = auth.IgnorePaths(ignoredPaths...) + if err != nil { + return nil, errors.Wrap(err, "auth.IgnorePaths") + } db.CurrentUser = auth.CurrentModel return auth, nil } - -// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index b0528b4..46acaac 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -40,7 +40,8 @@ func setupHTTPServer( "/ws/notifications", } - auth, err := setupAuth(cfg.HWSAuth, logger, bun, httpServer, ignoredPaths) + auth, err := setupAuth( + cfg.HWSAuth, logger, bun, httpServer, ignoredPaths) if err != nil { return nil, errors.Wrap(err, "setupAuth") } @@ -74,7 +75,7 @@ func setupHTTPServer( return nil, errors.Wrap(err, "addRoutes") } - err = addMiddleware(httpServer, auth, cfg, perms) + err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store) if err != nil { return nil, errors.Wrap(err, "addMiddleware") } diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index cdef7fa..51871c8 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -38,7 +38,7 @@ func main() { } return } - // + // Setup the logger logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout) if err != nil { diff --git a/cmd/oslstats/middleware.go b/cmd/oslstats/middleware.go index 6af763b..fdcd73e 100644 --- a/cmd/oslstats/middleware.go +++ b/cmd/oslstats/middleware.go @@ -4,13 +4,17 @@ import ( "context" "net/http" "strconv" + "strings" + "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/contexts" "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" "git.haelnorr.com/h/oslstats/internal/rbac" + "git.haelnorr.com/h/oslstats/internal/store" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -21,9 +25,11 @@ func addMiddleware( auth *hwsauth.Authenticator[*db.User, bun.Tx], cfg *config.Config, perms *rbac.Checker, + discordAPI *discord.APIClient, + store *store.Store, ) error { err := server.AddMiddleware( - auth.Authenticate(), + auth.Authenticate(tokenRefresh(auth, discordAPI, store)), perms.LoadPermissionsMiddleware(), devMode(cfg), ) @@ -51,3 +57,111 @@ func devMode(cfg *config.Config) hws.Middleware { ) } } + +func tokenRefresh( + auth *hwsauth.Authenticator[*db.User, bun.Tx], + discordAPI *discord.APIClient, + store *store.Store, +) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) { + return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) { + success, err := refreshToken(ctx, store, discordAPI, user, tx) + if err != nil { + return false, &hws.HWSError{ + Error: errors.Wrap(err, "refreshToken"), + Message: "Error refreshing discord token", + Level: hws.ErrorERROR, + RenderErrorPage: true, + StatusCode: http.StatusInternalServerError, + } + } + if !success { + err = auth.Logout(tx, w, r) + if err != nil { + return false, &hws.HWSError{ + Error: errors.Wrap(err, "auth.Logout"), + Message: "Logout failed", + Level: hws.ErrorERROR, + RenderErrorPage: true, + StatusCode: http.StatusInternalServerError, + } + } + return false, nil + } + return true, nil + } +} + +func refreshToken( + ctx context.Context, + store *store.Store, + discordAPI *discord.APIClient, + user *db.User, + tx bun.Tx, +) (bool, error) { + token := store.CheckToken(user) + if token != nil { + return true, nil + } + // Get the token + token, err := user.GetDiscordToken(ctx, tx) + if err != nil { + return false, errors.Wrap(err, "user.GetDiscordToken") + } + + tokenstatus, err := tokenStatus(token) + if err != nil { + return false, errors.Wrap(err, "tokenStatus") + } + switch tokenstatus { + case "revoked": + return false, nil + case "expired", "expiring": + newtoken, err := discordAPI.RefreshToken(token.Convert()) + if err != nil { + return false, errors.Wrap(err, "discordAPI.RefreshToken") + } + err = user.UpdateDiscordToken(ctx, tx, newtoken) + if err != nil { + return false, errors.Wrap(err, "user.UpdateDiscordToken") + } + err = store.NewTokenCheck(user, token) + if err != nil { + return false, errors.Wrap(err, "store.NewTokenCheck") + } + return true, nil + case "valid": + err = store.NewTokenCheck(user, token) + if err != nil { + return false, errors.Wrap(err, "store.NewTokenCheck") + } + return true, nil + default: + return false, errors.New("unexpected error occured validating discord token for user") + } +} + +func tokenStatus(token *db.DiscordToken) (string, error) { + now := time.Now().Unix() + dayfromnow := now + int64(24*time.Hour/time.Second) + oauthtoken := token.Convert() + session, err := discord.NewOAuthSession(oauthtoken) + if err != nil { + return "", errors.Wrap(err, "discord.NewOAuthSession") + } + _, err = session.GetUser() + if err != nil { + if !strings.Contains(err.Error(), "HTTP 401") { + // Error not related to token status + return "", errors.Wrap(err, "session.GetUser") + } + // Token not valid + if token.ExpiresAt < now { + return "expired", nil + } + return "revoked", nil + } + if token.ExpiresAt < dayfromnow { + return "expiring", nil + } + return "valid", nil +} diff --git a/go.mod b/go.mod index 91d57dd..57684ab 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hws v0.5.0 - git.haelnorr.com/h/golib/hwsauth v0.5.5 + git.haelnorr.com/h/golib/hwsauth v0.6.1 git.haelnorr.com/h/golib/notify v0.1.0 github.com/a-h/templ v0.3.977 github.com/coder/websocket v1.8.14 diff --git a/go.sum b/go.sum index 3721fe7..35b9693 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,10 @@ git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4V git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c= git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM= -git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo= -git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA= +git.haelnorr.com/h/golib/hwsauth v0.6.0 h1:qaRsgfcD7M6xCasfXwP6Ww9RM4TwDqYMFK2YtO6nt6c= +git.haelnorr.com/h/golib/hwsauth v0.6.0/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA= +git.haelnorr.com/h/golib/hwsauth v0.6.1 h1:3BiM6hwuYDjgfu02hshvUtr592DnWi9Epj//3N13ti0= +git.haelnorr.com/h/golib/hwsauth v0.6.1/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= diff --git a/internal/store/store.go b/internal/store/store.go index 5620e58..7fd61c2 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,3 +1,4 @@ +// Package store provides a session store for caching data package store import ( @@ -22,6 +23,7 @@ type Store struct { sessions sync.Map // key: string, value: *RegistrationSession redirectTracks sync.Map // key: string, value: *RedirectTrack cleanup *time.Ticker + tokenchecks sync.Map // key: int, value: *TokenCheck } func NewStore() *Store { @@ -42,6 +44,7 @@ func NewStore() *Store { func (s *Store) Delete(id string) { s.sessions.Delete(id) } + func (s *Store) cleanupExpired() { now := time.Now() @@ -62,10 +65,19 @@ func (s *Store) cleanupExpired() { } return true }) + + s.tokenchecks.Range(func(key, value any) bool { + check := value.(*TokenCheck) + if now.After(check.ExpiresAt) { + s.tokenchecks.Delete(key) + } + return true + }) } + func generateID() string { b := make([]byte, 32) - rand.Read(b) + _, _ = rand.Read(b) return base64.RawURLEncoding.EncodeToString(b) } diff --git a/internal/store/tokencheck.go b/internal/store/tokencheck.go new file mode 100644 index 0000000..b25de79 --- /dev/null +++ b/internal/store/tokencheck.go @@ -0,0 +1,39 @@ +package store + +import ( + "errors" + "time" + + "git.haelnorr.com/h/oslstats/internal/db" +) + +type TokenCheck struct { + Token *db.DiscordToken + ExpiresAt time.Time +} + +func (s *Store) NewTokenCheck(user *db.User, token *db.DiscordToken) error { + if user == nil { + return errors.New("user cannot be nil") + } + if token == nil { + return errors.New("token cannot be nil") + } + s.tokenchecks.Store(user.ID, &TokenCheck{ + Token: token, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + return nil +} + +func (s *Store) CheckToken(user *db.User) *db.DiscordToken { + res, ok := s.tokenchecks.Load(user.ID) + if !ok { + return nil + } + check := (res).(*TokenCheck) + check.ExpiresAt = time.Now().Add(5 * time.Minute) + s.tokenchecks.Delete(user.ID) + s.tokenchecks.Store(user.ID, check) + return check.Token +}