added discord token check to auth

This commit is contained in:
2026-02-07 18:02:12 +11:00
parent 697bef80e9
commit 7125683e6a
9 changed files with 183 additions and 13 deletions

1
.gitignore vendored
View File

@@ -8,6 +8,7 @@ tmp/
static/css/output.css static/css/output.css
internal/view/**/*_templ.go internal/view/**/*_templ.go
internal/view/**/*_templ.txt internal/view/**/*_templ.txt
cmd/test/*
# Database backups (compressed) # Database backups (compressed)
backups/*.sql.gz backups/*.sql.gz

View File

@@ -13,7 +13,7 @@ import (
) )
func setupAuth( func setupAuth(
config *hwsauth.Config, cfg *hwsauth.Config,
logger *hlog.Logger, logger *hlog.Logger,
conn *bun.DB, conn *bun.DB,
server *hws.Server, server *hws.Server,
@@ -24,7 +24,7 @@ func setupAuth(
return tx, err return tx, err
} }
auth, err := hwsauth.NewAuthenticator( auth, err := hwsauth.NewAuthenticator(
config, cfg,
db.GetUserByID, db.GetUserByID,
server, server,
beginTx, beginTx,
@@ -36,11 +36,12 @@ func setupAuth(
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
} }
auth.IgnorePaths(ignoredPaths...) err = auth.IgnorePaths(ignoredPaths...)
if err != nil {
return nil, errors.Wrap(err, "auth.IgnorePaths")
}
db.CurrentUser = auth.CurrentModel db.CurrentUser = auth.CurrentModel
return auth, nil return auth, nil
} }
// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh

View File

@@ -40,7 +40,8 @@ func setupHTTPServer(
"/ws/notifications", "/ws/notifications",
} }
auth, err := setupAuth(cfg.HWSAuth, logger, bun, httpServer, ignoredPaths) auth, err := setupAuth(
cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "setupAuth") return nil, errors.Wrap(err, "setupAuth")
} }
@@ -74,7 +75,7 @@ func setupHTTPServer(
return nil, errors.Wrap(err, "addRoutes") return nil, errors.Wrap(err, "addRoutes")
} }
err = addMiddleware(httpServer, auth, cfg, perms) err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "addMiddleware") return nil, errors.Wrap(err, "addMiddleware")
} }

View File

@@ -38,7 +38,7 @@ func main() {
} }
return return
} }
//
// Setup the logger // Setup the logger
logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout) logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
if err != nil { if err != nil {

View File

@@ -4,13 +4,17 @@ import (
"context" "context"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time"
"git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/contexts" "git.haelnorr.com/h/oslstats/internal/contexts"
"git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/rbac" "git.haelnorr.com/h/oslstats/internal/rbac"
"git.haelnorr.com/h/oslstats/internal/store"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -21,9 +25,11 @@ func addMiddleware(
auth *hwsauth.Authenticator[*db.User, bun.Tx], auth *hwsauth.Authenticator[*db.User, bun.Tx],
cfg *config.Config, cfg *config.Config,
perms *rbac.Checker, perms *rbac.Checker,
discordAPI *discord.APIClient,
store *store.Store,
) error { ) error {
err := server.AddMiddleware( err := server.AddMiddleware(
auth.Authenticate(), auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
perms.LoadPermissionsMiddleware(), perms.LoadPermissionsMiddleware(),
devMode(cfg), devMode(cfg),
) )
@@ -51,3 +57,111 @@ func devMode(cfg *config.Config) hws.Middleware {
) )
} }
} }
func tokenRefresh(
auth *hwsauth.Authenticator[*db.User, bun.Tx],
discordAPI *discord.APIClient,
store *store.Store,
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
success, err := refreshToken(ctx, store, discordAPI, user, tx)
if err != nil {
return false, &hws.HWSError{
Error: errors.Wrap(err, "refreshToken"),
Message: "Error refreshing discord token",
Level: hws.ErrorERROR,
RenderErrorPage: true,
StatusCode: http.StatusInternalServerError,
}
}
if !success {
err = auth.Logout(tx, w, r)
if err != nil {
return false, &hws.HWSError{
Error: errors.Wrap(err, "auth.Logout"),
Message: "Logout failed",
Level: hws.ErrorERROR,
RenderErrorPage: true,
StatusCode: http.StatusInternalServerError,
}
}
return false, nil
}
return true, nil
}
}
func refreshToken(
ctx context.Context,
store *store.Store,
discordAPI *discord.APIClient,
user *db.User,
tx bun.Tx,
) (bool, error) {
token := store.CheckToken(user)
if token != nil {
return true, nil
}
// Get the token
token, err := user.GetDiscordToken(ctx, tx)
if err != nil {
return false, errors.Wrap(err, "user.GetDiscordToken")
}
tokenstatus, err := tokenStatus(token)
if err != nil {
return false, errors.Wrap(err, "tokenStatus")
}
switch tokenstatus {
case "revoked":
return false, nil
case "expired", "expiring":
newtoken, err := discordAPI.RefreshToken(token.Convert())
if err != nil {
return false, errors.Wrap(err, "discordAPI.RefreshToken")
}
err = user.UpdateDiscordToken(ctx, tx, newtoken)
if err != nil {
return false, errors.Wrap(err, "user.UpdateDiscordToken")
}
err = store.NewTokenCheck(user, token)
if err != nil {
return false, errors.Wrap(err, "store.NewTokenCheck")
}
return true, nil
case "valid":
err = store.NewTokenCheck(user, token)
if err != nil {
return false, errors.Wrap(err, "store.NewTokenCheck")
}
return true, nil
default:
return false, errors.New("unexpected error occured validating discord token for user")
}
}
func tokenStatus(token *db.DiscordToken) (string, error) {
now := time.Now().Unix()
dayfromnow := now + int64(24*time.Hour/time.Second)
oauthtoken := token.Convert()
session, err := discord.NewOAuthSession(oauthtoken)
if err != nil {
return "", errors.Wrap(err, "discord.NewOAuthSession")
}
_, err = session.GetUser()
if err != nil {
if !strings.Contains(err.Error(), "HTTP 401") {
// Error not related to token status
return "", errors.Wrap(err, "session.GetUser")
}
// Token not valid
if token.ExpiresAt < now {
return "expired", nil
}
return "revoked", nil
}
if token.ExpiresAt < dayfromnow {
return "expiring", nil
}
return "valid", nil
}

2
go.mod
View File

@@ -7,7 +7,7 @@ require (
git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/ezconf v0.1.1
git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hlog v0.10.4
git.haelnorr.com/h/golib/hws v0.5.0 git.haelnorr.com/h/golib/hws v0.5.0
git.haelnorr.com/h/golib/hwsauth v0.5.5 git.haelnorr.com/h/golib/hwsauth v0.6.1
git.haelnorr.com/h/golib/notify v0.1.0 git.haelnorr.com/h/golib/notify v0.1.0
github.com/a-h/templ v0.3.977 github.com/a-h/templ v0.3.977
github.com/coder/websocket v1.8.14 github.com/coder/websocket v1.8.14

6
go.sum
View File

@@ -8,8 +8,10 @@ git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4V
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c= git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM= git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo= git.haelnorr.com/h/golib/hwsauth v0.6.0 h1:qaRsgfcD7M6xCasfXwP6Ww9RM4TwDqYMFK2YtO6nt6c=
git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA= git.haelnorr.com/h/golib/hwsauth v0.6.0/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
git.haelnorr.com/h/golib/hwsauth v0.6.1 h1:3BiM6hwuYDjgfu02hshvUtr592DnWi9Epj//3N13ti0=
git.haelnorr.com/h/golib/hwsauth v0.6.1/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10= git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=

View File

@@ -1,3 +1,4 @@
// Package store provides a session store for caching data
package store package store
import ( import (
@@ -22,6 +23,7 @@ type Store struct {
sessions sync.Map // key: string, value: *RegistrationSession sessions sync.Map // key: string, value: *RegistrationSession
redirectTracks sync.Map // key: string, value: *RedirectTrack redirectTracks sync.Map // key: string, value: *RedirectTrack
cleanup *time.Ticker cleanup *time.Ticker
tokenchecks sync.Map // key: int, value: *TokenCheck
} }
func NewStore() *Store { func NewStore() *Store {
@@ -42,6 +44,7 @@ func NewStore() *Store {
func (s *Store) Delete(id string) { func (s *Store) Delete(id string) {
s.sessions.Delete(id) s.sessions.Delete(id)
} }
func (s *Store) cleanupExpired() { func (s *Store) cleanupExpired() {
now := time.Now() now := time.Now()
@@ -62,10 +65,19 @@ func (s *Store) cleanupExpired() {
} }
return true return true
}) })
s.tokenchecks.Range(func(key, value any) bool {
check := value.(*TokenCheck)
if now.After(check.ExpiresAt) {
s.tokenchecks.Delete(key)
}
return true
})
} }
func generateID() string { func generateID() string {
b := make([]byte, 32) b := make([]byte, 32)
rand.Read(b) _, _ = rand.Read(b)
return base64.RawURLEncoding.EncodeToString(b) return base64.RawURLEncoding.EncodeToString(b)
} }

View File

@@ -0,0 +1,39 @@
package store
import (
"errors"
"time"
"git.haelnorr.com/h/oslstats/internal/db"
)
type TokenCheck struct {
Token *db.DiscordToken
ExpiresAt time.Time
}
func (s *Store) NewTokenCheck(user *db.User, token *db.DiscordToken) error {
if user == nil {
return errors.New("user cannot be nil")
}
if token == nil {
return errors.New("token cannot be nil")
}
s.tokenchecks.Store(user.ID, &TokenCheck{
Token: token,
ExpiresAt: time.Now().Add(5 * time.Minute),
})
return nil
}
func (s *Store) CheckToken(user *db.User) *db.DiscordToken {
res, ok := s.tokenchecks.Load(user.ID)
if !ok {
return nil
}
check := (res).(*TokenCheck)
check.ExpiresAt = time.Now().Add(5 * time.Minute)
s.tokenchecks.Delete(user.ID)
s.tokenchecks.Store(user.ID, check)
return check.Token
}