added discord token check to auth
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@ tmp/
|
||||
static/css/output.css
|
||||
internal/view/**/*_templ.go
|
||||
internal/view/**/*_templ.txt
|
||||
cmd/test/*
|
||||
|
||||
# Database backups (compressed)
|
||||
backups/*.sql.gz
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func setupAuth(
|
||||
config *hwsauth.Config,
|
||||
cfg *hwsauth.Config,
|
||||
logger *hlog.Logger,
|
||||
conn *bun.DB,
|
||||
server *hws.Server,
|
||||
@@ -24,7 +24,7 @@ func setupAuth(
|
||||
return tx, err
|
||||
}
|
||||
auth, err := hwsauth.NewAuthenticator(
|
||||
config,
|
||||
cfg,
|
||||
db.GetUserByID,
|
||||
server,
|
||||
beginTx,
|
||||
@@ -36,11 +36,12 @@ func setupAuth(
|
||||
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||
}
|
||||
|
||||
auth.IgnorePaths(ignoredPaths...)
|
||||
err = auth.IgnorePaths(ignoredPaths...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.IgnorePaths")
|
||||
}
|
||||
|
||||
db.CurrentUser = auth.CurrentModel
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh
|
||||
|
||||
@@ -40,7 +40,8 @@ func setupHTTPServer(
|
||||
"/ws/notifications",
|
||||
}
|
||||
|
||||
auth, err := setupAuth(cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
|
||||
auth, err := setupAuth(
|
||||
cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "setupAuth")
|
||||
}
|
||||
@@ -74,7 +75,7 @@ func setupHTTPServer(
|
||||
return nil, errors.Wrap(err, "addRoutes")
|
||||
}
|
||||
|
||||
err = addMiddleware(httpServer, auth, cfg, perms)
|
||||
err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "addMiddleware")
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func main() {
|
||||
}
|
||||
return
|
||||
}
|
||||
//
|
||||
|
||||
// Setup the logger
|
||||
logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,13 +4,17 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/contexts"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/rbac"
|
||||
"git.haelnorr.com/h/oslstats/internal/store"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
@@ -21,9 +25,11 @@ func addMiddleware(
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
cfg *config.Config,
|
||||
perms *rbac.Checker,
|
||||
discordAPI *discord.APIClient,
|
||||
store *store.Store,
|
||||
) error {
|
||||
err := server.AddMiddleware(
|
||||
auth.Authenticate(),
|
||||
auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
|
||||
perms.LoadPermissionsMiddleware(),
|
||||
devMode(cfg),
|
||||
)
|
||||
@@ -51,3 +57,111 @@ func devMode(cfg *config.Config) hws.Middleware {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func tokenRefresh(
|
||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||
discordAPI *discord.APIClient,
|
||||
store *store.Store,
|
||||
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
||||
return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
||||
success, err := refreshToken(ctx, store, discordAPI, user, tx)
|
||||
if err != nil {
|
||||
return false, &hws.HWSError{
|
||||
Error: errors.Wrap(err, "refreshToken"),
|
||||
Message: "Error refreshing discord token",
|
||||
Level: hws.ErrorERROR,
|
||||
RenderErrorPage: true,
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
err = auth.Logout(tx, w, r)
|
||||
if err != nil {
|
||||
return false, &hws.HWSError{
|
||||
Error: errors.Wrap(err, "auth.Logout"),
|
||||
Message: "Logout failed",
|
||||
Level: hws.ErrorERROR,
|
||||
RenderErrorPage: true,
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
func refreshToken(
|
||||
ctx context.Context,
|
||||
store *store.Store,
|
||||
discordAPI *discord.APIClient,
|
||||
user *db.User,
|
||||
tx bun.Tx,
|
||||
) (bool, error) {
|
||||
token := store.CheckToken(user)
|
||||
if token != nil {
|
||||
return true, nil
|
||||
}
|
||||
// Get the token
|
||||
token, err := user.GetDiscordToken(ctx, tx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "user.GetDiscordToken")
|
||||
}
|
||||
|
||||
tokenstatus, err := tokenStatus(token)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tokenStatus")
|
||||
}
|
||||
switch tokenstatus {
|
||||
case "revoked":
|
||||
return false, nil
|
||||
case "expired", "expiring":
|
||||
newtoken, err := discordAPI.RefreshToken(token.Convert())
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "discordAPI.RefreshToken")
|
||||
}
|
||||
err = user.UpdateDiscordToken(ctx, tx, newtoken)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "user.UpdateDiscordToken")
|
||||
}
|
||||
err = store.NewTokenCheck(user, token)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "store.NewTokenCheck")
|
||||
}
|
||||
return true, nil
|
||||
case "valid":
|
||||
err = store.NewTokenCheck(user, token)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "store.NewTokenCheck")
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, errors.New("unexpected error occured validating discord token for user")
|
||||
}
|
||||
}
|
||||
|
||||
func tokenStatus(token *db.DiscordToken) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
dayfromnow := now + int64(24*time.Hour/time.Second)
|
||||
oauthtoken := token.Convert()
|
||||
session, err := discord.NewOAuthSession(oauthtoken)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "discord.NewOAuthSession")
|
||||
}
|
||||
_, err = session.GetUser()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "HTTP 401") {
|
||||
// Error not related to token status
|
||||
return "", errors.Wrap(err, "session.GetUser")
|
||||
}
|
||||
// Token not valid
|
||||
if token.ExpiresAt < now {
|
||||
return "expired", nil
|
||||
}
|
||||
return "revoked", nil
|
||||
}
|
||||
if token.ExpiresAt < dayfromnow {
|
||||
return "expiring", nil
|
||||
}
|
||||
return "valid", nil
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -7,7 +7,7 @@ require (
|
||||
git.haelnorr.com/h/golib/ezconf v0.1.1
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||
git.haelnorr.com/h/golib/hws v0.5.0
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.5
|
||||
git.haelnorr.com/h/golib/hwsauth v0.6.1
|
||||
git.haelnorr.com/h/golib/notify v0.1.0
|
||||
github.com/a-h/templ v0.3.977
|
||||
github.com/coder/websocket v1.8.14
|
||||
|
||||
6
go.sum
6
go.sum
@@ -8,8 +8,10 @@ git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4V
|
||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.6.0 h1:qaRsgfcD7M6xCasfXwP6Ww9RM4TwDqYMFK2YtO6nt6c=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.6.0/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.6.1 h1:3BiM6hwuYDjgfu02hshvUtr592DnWi9Epj//3N13ti0=
|
||||
git.haelnorr.com/h/golib/hwsauth v0.6.1/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package store provides a session store for caching data
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -22,6 +23,7 @@ type Store struct {
|
||||
sessions sync.Map // key: string, value: *RegistrationSession
|
||||
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
||||
cleanup *time.Ticker
|
||||
tokenchecks sync.Map // key: int, value: *TokenCheck
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
@@ -42,6 +44,7 @@ func NewStore() *Store {
|
||||
func (s *Store) Delete(id string) {
|
||||
s.sessions.Delete(id)
|
||||
}
|
||||
|
||||
func (s *Store) cleanupExpired() {
|
||||
now := time.Now()
|
||||
|
||||
@@ -62,10 +65,19 @@ func (s *Store) cleanupExpired() {
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
s.tokenchecks.Range(func(key, value any) bool {
|
||||
check := value.(*TokenCheck)
|
||||
if now.After(check.ExpiresAt) {
|
||||
s.tokenchecks.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func generateID() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
|
||||
39
internal/store/tokencheck.go
Normal file
39
internal/store/tokencheck.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
)
|
||||
|
||||
type TokenCheck struct {
|
||||
Token *db.DiscordToken
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func (s *Store) NewTokenCheck(user *db.User, token *db.DiscordToken) error {
|
||||
if user == nil {
|
||||
return errors.New("user cannot be nil")
|
||||
}
|
||||
if token == nil {
|
||||
return errors.New("token cannot be nil")
|
||||
}
|
||||
s.tokenchecks.Store(user.ID, &TokenCheck{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) CheckToken(user *db.User) *db.DiscordToken {
|
||||
res, ok := s.tokenchecks.Load(user.ID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
check := (res).(*TokenCheck)
|
||||
check.ExpiresAt = time.Now().Add(5 * time.Minute)
|
||||
s.tokenchecks.Delete(user.ID)
|
||||
s.tokenchecks.Store(user.ID, check)
|
||||
return check.Token
|
||||
}
|
||||
Reference in New Issue
Block a user