updated stuff
This commit is contained in:
@@ -1,15 +1,21 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"git.haelnorr.com/h/oslstats/internal/session"
|
||||
"git.haelnorr.com/h/oslstats/pkg/oauth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func Callback(server *hws.Server, cfg *config.Config) http.Handler {
|
||||
func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *session.Store) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
state := r.URL.Query().Get("state")
|
||||
@@ -20,42 +26,141 @@ func Callback(server *hws.Server, cfg *config.Config) http.Handler {
|
||||
}
|
||||
data, err := verifyState(cfg.OAuth, w, r, state)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "OAuth state verification failed",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("debug"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
// Check if this is a cookie error (401) or signature error (403)
|
||||
if vsErr, ok := err.(*verifyStateError); ok {
|
||||
if vsErr.IsCookieError() {
|
||||
// Cookie missing/expired - normal failed/expired session (DEBUG)
|
||||
throwUnauthorized(server, w, r, "OAuth session not found or expired", err)
|
||||
} else {
|
||||
// Signature verification failed - security violation (WARN)
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
} else {
|
||||
// Unknown error type - treat as security issue
|
||||
throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
switch data {
|
||||
case "login":
|
||||
w.Write([]byte(code))
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "DB Transaction failed to start", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
redirect, err := login(ctx, tx, cfg, w, r, code, store)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "OAuth login failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
redirect()
|
||||
return
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func verifyState(cfg *oauth.Config, w http.ResponseWriter, r *http.Request, state string) (string, error) {
|
||||
// verifyStateError wraps an error with context about what went wrong
|
||||
type verifyStateError struct {
|
||||
err error
|
||||
cookieError bool // true if cookie missing/invalid, false if signature invalid
|
||||
}
|
||||
|
||||
func (e *verifyStateError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *verifyStateError) IsCookieError() bool {
|
||||
return e.cookieError
|
||||
}
|
||||
|
||||
func verifyState(
|
||||
cfg *oauth.Config,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
state string,
|
||||
) (string, error) {
|
||||
if r == nil {
|
||||
return "", errors.New("request cannot be nil")
|
||||
}
|
||||
if state == "" {
|
||||
return "", errors.New("state param field is empty")
|
||||
}
|
||||
|
||||
// Try to get the cookie
|
||||
uak, err := oauth.GetStateCookie(r)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "oauth.GetStateCookie")
|
||||
// Cookie missing or invalid - this is a 401 (not authenticated)
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.GetStateCookie"),
|
||||
cookieError: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the state signature
|
||||
data, err := oauth.VerifyState(cfg, state, uak)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "oauth.VerifyState")
|
||||
// Signature verification failed - this is a 403 (security violation)
|
||||
return "", &verifyStateError{
|
||||
err: errors.Wrap(err, "oauth.VerifyState"),
|
||||
cookieError: false,
|
||||
}
|
||||
}
|
||||
|
||||
oauth.DeleteStateCookie(w)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func login(
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
cfg *config.Config,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
code string,
|
||||
store *session.Store,
|
||||
) (func(), error) {
|
||||
token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discord.AuthorizeWithCode")
|
||||
}
|
||||
session, err := discord.NewOAuthSession(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "discord.NewOAuthSession")
|
||||
}
|
||||
discorduser, err := session.GetUser()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "session.GetUser")
|
||||
}
|
||||
|
||||
user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserByDiscordID")
|
||||
}
|
||||
var redirect string
|
||||
if user == nil {
|
||||
sessionID, err := store.CreateRegistrationSession(discorduser, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "store.CreateRegistrationSession")
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "registration_session",
|
||||
Path: "/",
|
||||
Value: sessionID,
|
||||
MaxAge: 300, // 5 minutes
|
||||
HttpOnly: true,
|
||||
Secure: cfg.HWSAuth.SSL,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
redirect = "/register"
|
||||
} else {
|
||||
// TODO: log them in
|
||||
}
|
||||
return func() {
|
||||
http.Redirect(w, r, redirect, http.StatusSeeOther)
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -5,24 +5,93 @@ import (
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func ErrorPage(
|
||||
errorCode int,
|
||||
) (hws.ErrorPage, error) {
|
||||
// func ErrorPage(
|
||||
// error hws.HWSError,
|
||||
// ) (hws.ErrorPage, error) {
|
||||
// messages := map[int]string{
|
||||
// 400: "The request you made was malformed or unexpected.",
|
||||
// 401: "You need to login to view this page.",
|
||||
// 403: "You do not have permission to view this page.",
|
||||
// 404: "The page or resource you have requested does not exist.",
|
||||
// 500: `An error occured on the server. Please try again, and if this
|
||||
// continues to happen contact an administrator.`,
|
||||
// 503: "The server is currently down for maintenance and should be back soon. =)",
|
||||
// }
|
||||
// msg, exists := messages[error.StatusCode]
|
||||
// if !exists {
|
||||
// return nil, errors.New("No valid message for the given code")
|
||||
// }
|
||||
// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil
|
||||
// }
|
||||
|
||||
func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) {
|
||||
// Determine if this status code should show technical details
|
||||
showDetails := shouldShowDetails(hwsError.StatusCode)
|
||||
|
||||
// Get the user-friendly message
|
||||
message := hwsError.Message
|
||||
if message == "" {
|
||||
// Fallback to default messages if no custom message provided
|
||||
message = getDefaultMessage(hwsError.StatusCode)
|
||||
}
|
||||
|
||||
// Get technical details if applicable
|
||||
var details string
|
||||
if showDetails && hwsError.Error != nil {
|
||||
details = hwsError.Error.Error()
|
||||
}
|
||||
|
||||
// Render appropriate template
|
||||
if details != "" {
|
||||
return page.ErrorWithDetails(
|
||||
hwsError.StatusCode,
|
||||
http.StatusText(hwsError.StatusCode),
|
||||
message,
|
||||
details,
|
||||
), nil
|
||||
}
|
||||
|
||||
return page.Error(
|
||||
hwsError.StatusCode,
|
||||
http.StatusText(hwsError.StatusCode),
|
||||
message,
|
||||
), nil
|
||||
}
|
||||
|
||||
// shouldShowDetails determines if a status code should display technical details
|
||||
func shouldShowDetails(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable
|
||||
return true
|
||||
case 401, 403, 404: // Unauthorized, Forbidden, Not Found
|
||||
return false
|
||||
default:
|
||||
// For unknown codes, show details for 5xx errors
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultMessage provides fallback messages for status codes
|
||||
func getDefaultMessage(statusCode int) string {
|
||||
messages := map[int]string{
|
||||
400: "The request you made was malformed or unexpected.",
|
||||
401: "You need to login to view this page.",
|
||||
403: "You do not have permission to view this page.",
|
||||
404: "The page or resource you have requested does not exist.",
|
||||
500: `An error occured on the server. Please try again, and if this
|
||||
continues to happen contact an administrator.`,
|
||||
500: `An error occurred on the server. Please try again, and if this
|
||||
continues to happen contact an administrator.`,
|
||||
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||
}
|
||||
msg, exists := messages[errorCode]
|
||||
|
||||
msg, exists := messages[statusCode]
|
||||
if !exists {
|
||||
return nil, errors.New("No valid message for the given code")
|
||||
if statusCode >= 500 {
|
||||
return "A server error occurred. Please try again later."
|
||||
}
|
||||
return "An error occurred while processing your request."
|
||||
}
|
||||
return page.Error(errorCode, http.StatusText(errorCode), msg), nil
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
109
internal/handlers/errors.go
Normal file
109
internal/handlers/errors.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// throwError is a generic helper that all throw* functions use internally
|
||||
func throwError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
statusCode int,
|
||||
msg string,
|
||||
err error,
|
||||
level string,
|
||||
) {
|
||||
err = s.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: statusCode,
|
||||
Message: msg,
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel(level),
|
||||
RenderErrorPage: true, // throw* family always renders error pages
|
||||
})
|
||||
if err != nil {
|
||||
s.ThrowFatal(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
// throwInternalServiceError handles 500 errors (server failures)
|
||||
func throwInternalServiceError(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusInternalServerError, msg, err, "error")
|
||||
}
|
||||
|
||||
// throwBadRequest handles 400 errors (malformed requests)
|
||||
func throwBadRequest(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusBadRequest, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwForbidden handles 403 errors (normal permission denials)
|
||||
func throwForbidden(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwForbiddenSecurity handles 403 errors for security events (uses WARN level)
|
||||
func throwForbiddenSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusForbidden, msg, err, "warn")
|
||||
}
|
||||
|
||||
// throwUnauthorized handles 401 errors (not authenticated)
|
||||
func throwUnauthorized(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug")
|
||||
}
|
||||
|
||||
// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level)
|
||||
func throwUnauthorizedSecurity(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
msg string,
|
||||
err error,
|
||||
) {
|
||||
throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn")
|
||||
}
|
||||
|
||||
// throwNotFound handles 404 errors
|
||||
func throwNotFound(
|
||||
s *hws.Server,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
path string,
|
||||
) {
|
||||
msg := fmt.Sprintf("The requested resource was not found: %s", path)
|
||||
err := errors.New("Resource not found")
|
||||
throwError(s, w, r, http.StatusNotFound, msg, err, "debug")
|
||||
}
|
||||
@@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
page, err := ErrorPage(http.StatusNotFound)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to generate the error page",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: false,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = page.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to render the error page",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: false,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
throwNotFound(server, w, r, r.URL.Path)
|
||||
}
|
||||
page.Index().Render(r.Context(), w)
|
||||
},
|
||||
|
||||
@@ -14,32 +14,14 @@ func Login(server *hws.Server, cfg *config.Config) http.Handler {
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
state, uak, err := oauth.GenerateState(cfg.OAuth, "login")
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Failed to generate state token",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
throwInternalServiceError(server, w, r, "Failed to generate state token", err)
|
||||
return
|
||||
}
|
||||
oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL)
|
||||
|
||||
link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost)
|
||||
if err != nil {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to generate the login link",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, link, http.StatusSeeOther)
|
||||
|
||||
95
internal/handlers/register.go
Normal file
95
internal/handlers/register.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/oslstats/internal/config"
|
||||
"git.haelnorr.com/h/oslstats/internal/db"
|
||||
"git.haelnorr.com/h/oslstats/internal/session"
|
||||
"git.haelnorr.com/h/oslstats/internal/view/page"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func Register(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *session.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionCookie, err := r.Cookie("registration_session")
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
details, ok := store.GetRegistrationSession(sessionCookie.Value)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
method := r.Method
|
||||
if method == "GET" {
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, details.DiscordUser.Username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
page.Register(details.DiscordUser.Username, unique).Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
if method == "POST" {
|
||||
// TODO: register the user
|
||||
|
||||
// get the form data
|
||||
//
|
||||
return
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func IsUsernameUnique(
|
||||
server *hws.Server,
|
||||
conn *bun.DB,
|
||||
cfg *config.Config,
|
||||
store *session.Store,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
username := r.FormValue("username")
|
||||
// check if its unique
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database transaction failed", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
unique, err := db.IsUsernameUnique(ctx, tx, username)
|
||||
if err != nil {
|
||||
throwInternalServiceError(server, w, r, "Database query failed", err)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
if !unique {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler {
|
||||
if err != nil {
|
||||
// If we can't create the file server, return a handler that always errors
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err = server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "An error occured trying to load the file system",
|
||||
Error: err,
|
||||
Level: hws.ErrorLevel("error"),
|
||||
RenderErrorPage: true,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user