215 lines
5.1 KiB
Go
215 lines
5.1 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.haelnorr.com/h/golib/cookies"
|
|
"git.haelnorr.com/h/golib/hws"
|
|
"git.haelnorr.com/h/golib/hwsauth"
|
|
"github.com/pkg/errors"
|
|
"github.com/uptrace/bun"
|
|
|
|
"git.haelnorr.com/h/oslstats/internal/config"
|
|
"git.haelnorr.com/h/oslstats/internal/db"
|
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
|
"git.haelnorr.com/h/oslstats/internal/store"
|
|
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
|
)
|
|
|
|
func Callback(
|
|
server *hws.Server,
|
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
|
conn *bun.DB,
|
|
cfg *config.Config,
|
|
store *store.Store,
|
|
discordAPI *discord.APIClient,
|
|
) http.Handler {
|
|
return http.HandlerFunc(
|
|
func(w http.ResponseWriter, r *http.Request) {
|
|
attempts, exceeded, track := store.TrackRedirect(r, "/callback", 5)
|
|
|
|
if exceeded {
|
|
err := errors.Errorf(
|
|
"callback redirect loop detected after %d attempts | ip=%s ua=%s path=%s first_seen=%s",
|
|
attempts,
|
|
track.IP,
|
|
track.UserAgent,
|
|
track.Path,
|
|
track.FirstSeen.Format("2006-01-02T15:04:05Z07:00"),
|
|
)
|
|
|
|
store.ClearRedirectTrack(r, "/callback")
|
|
|
|
throwError(
|
|
server,
|
|
w,
|
|
r,
|
|
http.StatusBadRequest,
|
|
"OAuth callback failed: Too many redirect attempts. Please try logging in again.",
|
|
err,
|
|
"warn",
|
|
)
|
|
return
|
|
}
|
|
|
|
state := r.URL.Query().Get("state")
|
|
code := r.URL.Query().Get("code")
|
|
if state == "" && code == "" {
|
|
http.Redirect(w, r, "/", http.StatusBadRequest)
|
|
return
|
|
}
|
|
data, err := verifyState(cfg.OAuth, w, r, state)
|
|
if err != nil {
|
|
if vsErr, ok := err.(*verifyStateError); ok {
|
|
if vsErr.IsCookieError() {
|
|
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
|
} else {
|
|
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
|
}
|
|
} else {
|
|
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
|
}
|
|
return
|
|
}
|
|
store.ClearRedirectTrack(r, "/callback")
|
|
|
|
switch data {
|
|
case "login":
|
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
defer cancel()
|
|
tx, err := conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI)
|
|
if err != nil {
|
|
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
|
return
|
|
}
|
|
tx.Commit()
|
|
redirect()
|
|
return
|
|
}
|
|
},
|
|
)
|
|
}
|
|
|
|
// verifyStateError is returned by verifyState when reading the OAuth state
// cookie or verifying the state value fails. cookieError distinguishes the
// two cases, which lets Callback choose between an "expired session" (401)
// and a "forbidden" (403) response.
type verifyStateError struct {
	// err is the underlying, already-wrapped cause.
	err error
	// cookieError is true when the failure came from reading the state cookie.
	cookieError bool
}

// Error implements the error interface by returning the cause's message.
func (e *verifyStateError) Error() string {
	return e.err.Error()
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can
// inspect through a *verifyStateError.
func (e *verifyStateError) Unwrap() error {
	return e.err
}

// IsCookieError reports whether the failure originated from reading the OAuth
// state cookie rather than from verifying the state value against it.
func (e *verifyStateError) IsCookieError() bool {
	return e.cookieError
}
|
|
|
|
func verifyState(
|
|
cfg *oauth.Config,
|
|
w http.ResponseWriter,
|
|
r *http.Request,
|
|
state string,
|
|
) (string, error) {
|
|
if r == nil {
|
|
return "", errors.New("request cannot be nil")
|
|
}
|
|
if state == "" {
|
|
return "", errors.New("state param field is empty")
|
|
}
|
|
|
|
uak, err := oauth.GetStateCookie(r)
|
|
if err != nil {
|
|
return "", &verifyStateError{
|
|
err: errors.Wrap(err, "oauth.GetStateCookie"),
|
|
cookieError: true,
|
|
}
|
|
}
|
|
|
|
data, err := oauth.VerifyState(cfg, state, uak)
|
|
if err != nil {
|
|
return "", &verifyStateError{
|
|
err: errors.Wrap(err, "oauth.VerifyState"),
|
|
cookieError: false,
|
|
}
|
|
}
|
|
|
|
oauth.DeleteStateCookie(w)
|
|
return data, nil
|
|
}
|
|
|
|
func login(
|
|
ctx context.Context,
|
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
|
tx bun.Tx,
|
|
cfg *config.Config,
|
|
w http.ResponseWriter,
|
|
r *http.Request,
|
|
code string,
|
|
store *store.Store,
|
|
discordAPI *discord.APIClient,
|
|
) (func(), error) {
|
|
token, err := discordAPI.AuthorizeWithCode(code)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "discordAPI.AuthorizeWithCode")
|
|
}
|
|
session, err := discord.NewOAuthSession(token)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "discord.NewOAuthSession")
|
|
}
|
|
discorduser, err := session.GetUser()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "session.GetUser")
|
|
}
|
|
|
|
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "db.GetUserByDiscordID")
|
|
}
|
|
var redirect string
|
|
if user == nil {
|
|
sessionID, err := store.CreateRegistrationSession(discorduser, token)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "store.CreateRegistrationSession")
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "registration_session",
|
|
Path: "/",
|
|
Value: sessionID,
|
|
MaxAge: 300, // 5 minutes
|
|
HttpOnly: true,
|
|
Secure: cfg.HWSAuth.SSL,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
redirect = "/register"
|
|
} else {
|
|
err = user.UpdateDiscordToken(ctx, tx, token)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "user.UpdateDiscordToken")
|
|
}
|
|
|
|
// Check if user should be granted admin role (environment-based)
|
|
if shouldGrantAdmin(user, cfg.RBAC) {
|
|
err := ensureUserHasAdminRole(ctx, tx, user)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "ensureUserHasAdminRole")
|
|
}
|
|
}
|
|
|
|
err := auth.Login(w, r, user, true)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "auth.Login")
|
|
}
|
|
redirect = cookies.CheckPageFrom(w, r)
|
|
}
|
|
return func() {
|
|
http.Redirect(w, r, redirect, http.StatusSeeOther)
|
|
}, nil
|
|
}
|