Files
oslstats/internal/handlers/callback.go
2026-02-15 12:27:36 +11:00

204 lines
5.2 KiB
Go

package handlers
import (
"context"
"net/http"
"git.haelnorr.com/h/golib/cookies"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"git.haelnorr.com/h/oslstats/internal/config"
"git.haelnorr.com/h/oslstats/internal/db"
"git.haelnorr.com/h/oslstats/internal/discord"
"git.haelnorr.com/h/oslstats/internal/store"
"git.haelnorr.com/h/oslstats/internal/throw"
"git.haelnorr.com/h/oslstats/internal/validation"
"git.haelnorr.com/h/oslstats/pkg/oauth"
)
// Callback returns the handler for the OAuth redirect endpoint ("/callback").
//
// The handler:
//   - guards against redirect loops with store.TrackRedirect (limit 5) and
//     clears that tracking on every terminal path;
//   - requires the "state" and "code" query parameters, reporting an
//     "access_denied" provider error as 401 and any other missing-parameter
//     case as 400;
//   - verifies the state against the OAuth state cookie (401 when the cookie
//     is unavailable, 403 for any other verification failure);
//   - dispatches on the data embedded in the verified state. Only "login" is
//     handled; any other value is rejected with 400 rather than falling
//     through and leaving an empty 200 response.
func Callback(
	s *hws.Server,
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	conn *db.DB,
	cfg *config.Config,
	store *store.Store,
	discordAPI *discord.APIClient,
) http.Handler {
	return http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
			if exceeded {
				err := track.Error(attempts)
				store.ClearRedirectTrack(r, "/callback")
				throw.BadRequest(s, w, r, "Too many redirects. Please try logging in again.", err)
				return
			}

			getter := validation.NewQueryGetter(r)
			state := getter.String("state").Required().Value
			code := getter.String("code").Required().Value
			if !getter.Validate() {
				store.ClearRedirectTrack(r, "/callback")
				// When the user cancels or denies consent the provider sends an
				// "error" parameter (OAuth 2.0 error response) instead of state/code.
				apiErr := getter.String("error").Value
				errDesc := getter.String("error_description").Value
				if apiErr == "access_denied" {
					throw.Unauthorized(s, w, r, "OAuth login failed or cancelled", errors.New(errDesc))
					return
				}
				throw.BadRequest(s, w, r, "OAuth login failed", errors.New("state or code parameters missing"))
				return
			}

			data, err := verifyState(cfg.OAuth, w, r, state)
			if err != nil {
				store.ClearRedirectTrack(r, "/callback")
				// A missing/unreadable state cookie is treated as an expired
				// session; every other failure is reported as a security issue.
				if vsErr, ok := err.(*verifyStateError); ok && vsErr.IsCookieError() {
					throw.Unauthorized(s, w, r, "OAuth session not found or expired", err)
				} else {
					throw.ForbiddenSecurity(s, w, r, "OAuth state verification failed", err)
				}
				return
			}
			store.ClearRedirectTrack(r, "/callback")

			switch data {
			case "login":
				var redirect func()
				ok := conn.WithWriteTx(s, w, r, func(ctx context.Context, tx bun.Tx) (bool, error) {
					var loginErr error
					redirect, loginErr = login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
					if loginErr != nil {
						// The error response is written here; returning (false, nil)
						// tells WithWriteTx not to commit.
						throw.InternalServiceError(s, w, r, "OAuth login failed", loginErr)
						return false, nil
					}
					return true, nil
				})
				if !ok {
					return
				}
				// Redirect only once WithWriteTx reports success
				// (presumably after the commit — confirm in db.WithWriteTx).
				redirect()
			default:
				// The state is signed by us, so an unknown payload means a flow
				// this handler does not implement; never fall through silently.
				throw.BadRequest(s, w, r, "OAuth login failed", errors.Errorf("unrecognised OAuth state data %q", data))
			}
		},
	)
}
// verifyStateError is returned by verifyState when the OAuth state cannot be
// verified. It records whether the failure happened while reading the state
// cookie or while validating the state against that cookie. This lets the
// handler tell an expired or missing session apart from a possible forgery.
type verifyStateError struct {
	err         error // underlying (wrapped) cause
	cookieError bool  // true when the state cookie could not be retrieved
}

// Error returns the message of the underlying cause.
func (e *verifyStateError) Error() string {
	return e.err.Error()
}

// Unwrap returns the underlying cause so that errors.Is / errors.As can
// inspect the error chain beneath verifyStateError.
func (e *verifyStateError) Unwrap() error {
	return e.err
}

// IsCookieError reports whether verification failed because the state cookie
// could not be retrieved, as opposed to the state itself failing validation.
func (e *verifyStateError) IsCookieError() bool {
	return e.cookieError
}
// verifyState checks the "state" query value against the OAuth state cookie
// and returns the data embedded in the state.
//
// Error cases:
//   - a nil request or empty state: a plain error;
//   - the state cookie cannot be read: *verifyStateError with
//     IsCookieError() == true;
//   - the state does not verify against the cookie: *verifyStateError with
//     IsCookieError() == false.
//
// The state cookie is deleted only after successful verification.
func verifyState(
	cfg *oauth.Config,
	w http.ResponseWriter,
	r *http.Request,
	state string,
) (string, error) {
	switch {
	case r == nil:
		return "", errors.New("request cannot be nil")
	case state == "":
		return "", errors.New("state param field is empty")
	}

	cookieValue, cookieErr := oauth.GetStateCookie(r)
	if cookieErr != nil {
		return "", &verifyStateError{
			cookieError: true,
			err:         errors.Wrap(cookieErr, "oauth.GetStateCookie"),
		}
	}

	payload, verifyErr := oauth.VerifyState(cfg, state, cookieValue)
	if verifyErr != nil {
		// cookieError is left at its zero value (false): the cookie was read,
		// but the state did not verify.
		return "", &verifyStateError{
			err: errors.Wrap(verifyErr, "oauth.VerifyState"),
		}
	}

	oauth.DeleteStateCookie(w)
	return payload, nil
}
// login completes a Discord OAuth login inside the transaction tx.
//
// It exchanges code for a Discord token, fetches the Discord user and looks
// that user up locally by Discord ID. Then:
//   - unknown user: a registration session holding the Discord user and token
//     is created in store. Its ID is set in a 5-minute "registration_session"
//     cookie, and the caller is directed to /register.
//   - known user: the stored Discord token is updated and the admin role is
//     ensured when shouldGrantAdmin allows it. The user is logged in through
//     auth, and the caller is directed to cookies.CheckPageFrom(w, r).
//
// Cookies are written to w during this call. The redirect itself is returned
// as a closure so the caller can issue it after its transaction wrapper
// returns. On error the returned func is nil.
func login(
	ctx context.Context,
	auth *hwsauth.Authenticator[*db.User, bun.Tx],
	tx bun.Tx,
	cfg *config.Config,
	w http.ResponseWriter,
	r *http.Request,
	code string,
	store *store.Store,
	discordAPI *discord.APIClient,
) (func(), error) {
	// redirectTo defers the 303 redirect to dest until the returned func runs.
	redirectTo := func(dest string) func() {
		return func() {
			http.Redirect(w, r, dest, http.StatusSeeOther)
		}
	}

	token, err := discordAPI.AuthorizeWithCode(code)
	if err != nil {
		return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode")
	}
	session, err := discord.NewOAuthSession(token)
	if err != nil {
		return nil, errors.Wrap(err, "discord.NewOAuthSession")
	}
	discordUser, err := session.GetUser()
	if err != nil {
		return nil, errors.Wrap(err, "session.GetUser")
	}

	// NOTE(review): a "bad request" error is tolerated and, with a nil user,
	// leads to registration — presumably it means "no such user"; confirm
	// against db.GetUserByDiscordID.
	user, err := db.GetUserByDiscordID(ctx, tx, discordUser.ID)
	if err != nil && !db.IsBadRequest(err) {
		return nil, errors.Wrap(err, "db.GetUserByDiscordID")
	}

	if user == nil {
		sessionID, err := store.CreateRegistrationSession(discordUser, token)
		if err != nil {
			return nil, errors.Wrap(err, "store.CreateRegistrationSession")
		}
		http.SetCookie(w, &http.Cookie{
			Name:     "registration_session",
			Path:     "/",
			Value:    sessionID,
			MaxAge:   300, // 5 minutes
			HttpOnly: true,
			Secure:   cfg.HWSAuth.SSL,
			SameSite: http.SameSiteLaxMode,
		})
		return redirectTo("/register"), nil
	}

	if err := user.UpdateDiscordToken(ctx, tx, token); err != nil {
		return nil, errors.Wrap(err, "user.UpdateDiscordToken")
	}
	// Environment-based admin role grant.
	if shouldGrantAdmin(user, cfg.RBAC) {
		if err := ensureUserHasAdminRole(ctx, tx, user); err != nil {
			return nil, errors.Wrap(err, "ensureUserHasAdminRole")
		}
	}
	if err := auth.Login(w, r, user, true); err != nil {
		return nil, errors.Wrap(err, "auth.Login")
	}
	return redirectTo(cookies.CheckPageFrom(w, r)), nil
}