diff --git a/cmd/oslstats/auth.go b/cmd/oslstats/auth.go index acb3421..d2b6ff5 100644 --- a/cmd/oslstats/auth.go +++ b/cmd/oslstats/auth.go @@ -30,6 +30,7 @@ func setupAuth( beginTx, logger, handlers.ErrorPage, + conn.DB, ) if err != nil { return nil, errors.Wrap(err, "hwsauth.NewAuthenticator") diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index 6718d30..922b928 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -31,6 +31,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func func loadModels(ctx context.Context, conn *bun.DB, resetDB bool) error { models := []any{ (*db.User)(nil), + (*db.DiscordToken)(nil), } for _, model := range models { diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index e7161b7..e2942c5 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -44,17 +44,30 @@ func addRoutes( { Path: "/auth/callback", Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store, discordAPI)), + Handler: auth.LogoutReq(handlers.Callback(server, auth, conn, cfg, store, discordAPI)), }, { Path: "/register", Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Register(server, conn, cfg, store)), + Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), + }, + { + Path: "/register", + Method: hws.MethodPOST, + Handler: auth.LogoutReq(handlers.Register(server, auth, conn, cfg, store)), + }, + } + + htmxRoutes := []hws.Route{ + { + Path: "/htmx/isusernameunique", + Method: hws.MethodPOST, + Handler: handlers.IsUsernameUnique(server, conn, cfg, store), }, } // Register the routes with the server - err := server.AddRoutes(routes...) + err := server.AddRoutes(append(routes, htmxRoutes...)...) if err != nil { return errors.Wrap(err, "server.AddRoutes") } diff --git a/go.mod b/go.mod index 0be7c08..385f76a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 git.haelnorr.com/h/golib/hws v0.3.0 - git.haelnorr.com/h/golib/hwsauth v0.4.0 + git.haelnorr.com/h/golib/hwsauth v0.5.0 github.com/a-h/templ v0.3.977 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 @@ -22,7 +22,7 @@ require ( ) require ( - git.haelnorr.com/h/golib/cookies v0.9.0 // indirect + git.haelnorr.com/h/golib/cookies v0.9.0 git.haelnorr.com/h/golib/jwt v0.10.1 // indirect github.com/bwmarrin/discordgo v0.29.0 github.com/go-logr/logr v1.4.3 // indirect diff --git a/go.sum b/go.sum index c0f475e..795badc 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4V git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/hwsauth v0.4.0 h1:femjTuiaE8ye4BgC1xH1r6rC7PAhuhMmhcn1FBFZLN0= -git.haelnorr.com/h/golib/hwsauth v0.4.0/go.mod h1:aHY2u3b+dhoymszd/keii5HX9ZWpHU3v8gQqvTb/yKc= +git.haelnorr.com/h/golib/hwsauth v0.5.0 h1:RAr7cdMe2aden50n7d9m5R4josZZ8ikNfWGMAEGnJbo= +git.haelnorr.com/h/golib/hwsauth v0.5.0/go.mod h1:NOonrVU/lX8lzuV77eDEiTwBjn7RrzYVcSdXUJWeHmQ= git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= diff --git a/internal/db/discord_tokens.go b/internal/db/discord_tokens.go index 7f6c2c3..ae3f6fc 100644 --- a/internal/db/discord_tokens.go +++ b/internal/db/discord_tokens.go @@ -5,7 +5,6 @@ import ( "time" "git.haelnorr.com/h/oslstats/internal/discord" - "github.com/bwmarrin/discordgo" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -19,10 +18,7 @@ type DiscordToken struct { ExpiresAt int64 `bun:"expires_at,notnull"` } -func UpdateDiscordToken(ctx context.Context, db *bun.DB, user *discordgo.User, token *discord.Token) error { - if db == nil { - return errors.New("db cannot be nil") - } +func UpdateDiscordToken(ctx context.Context, tx bun.Tx, user *User, token *discord.Token) error { if user == nil { return errors.New("user cannot be nil") } @@ -32,13 +28,13 @@ func UpdateDiscordToken(ctx context.Context, db *bun.DB, user *discordgo.User, t expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() discordToken := &DiscordToken{ - DiscordID: user.ID, + DiscordID: user.DiscordID, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresAt: expiresAt, } - _, err := db.NewInsert(). + _, err := tx.NewInsert(). Model(discordToken). On("CONFLICT (discord_id) DO UPDATE"). Set("access_token = EXCLUDED.access_token"). @@ -46,5 +42,8 @@ func UpdateDiscordToken(ctx context.Context, db *bun.DB, user *discordgo.User, t Set("expires_at = EXCLUDED.expires_at"). Exec(ctx) - return err + if err != nil { + return errors.Wrap(err, "tx.NewInsert") + } + return nil } diff --git a/internal/db/user.go b/internal/db/user.go index 866fb44..40c5ca8 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,6 +2,7 @@ package db import ( "context" + "fmt" "time" "git.haelnorr.com/h/golib/hwsauth" @@ -63,6 +64,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di // GetUserByID queries the database for a user matching the given ID // Returns nil, nil if no user is found func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) { + fmt.Printf("user id requested: %v", id) user := new(User) err := tx.NewSelect(). Model(user). diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index 47239c5..9093352 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -5,7 +5,9 @@ import ( "net/http" "time" + "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -16,7 +18,14 @@ import ( "git.haelnorr.com/h/oslstats/pkg/oauth" ) -func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *store.Store, discordAPI *discord.APIClient) http.Handler { +func Callback( + server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], + conn *bun.DB, + cfg *config.Config, + store *store.Store, + discordAPI *discord.APIClient, +) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { // Track callback redirect attempts @@ -86,7 +95,7 @@ func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *store return } defer tx.Rollback() - redirect, err := login(ctx, tx, cfg, w, r, code, store, discordAPI) + redirect, err := login(ctx, auth, tx, cfg, w, r, code, store, discordAPI) if err != nil { throwInternalServiceError(server, w, r, "OAuth login failed", err) return @@ -152,6 +161,7 @@ func verifyState( func login( ctx context.Context, + auth *hwsauth.Authenticator[*db.User, bun.Tx], tx bun.Tx, cfg *config.Config, w http.ResponseWriter, @@ -194,7 +204,11 @@ func login( }) redirect = "/register" } else { - // TODO: log them in + err := auth.Login(w, r, user, true) + if err != nil { + return nil, errors.Wrap(err, "auth.Login") + } + redirect = cookies.CheckPageFrom(w, r) } return func() { http.Redirect(w, r, redirect, http.StatusSeeOther) diff --git a/internal/handlers/login.go b/internal/handlers/login.go index e46312a..db44584 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -3,6 +3,7 @@ package handlers import ( "net/http" + "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" "github.com/pkg/errors" @@ -15,6 +16,7 @@ import ( func Login(server *hws.Server, cfg *config.Config, st *store.Store, discordAPI *discord.APIClient) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + cookies.SetPageFrom(w, r, cfg.HWSAuth.TrustedHost) // Track login redirect attempts attempts, exceeded, track := st.TrackRedirect(r, "/login", 5) diff --git a/internal/handlers/register.go b/internal/handlers/register.go index d70691f..bfeeccd 100644 --- a/internal/handlers/register.go +++ b/internal/handlers/register.go @@ -5,7 +5,9 @@ import ( "net/http" "time" + "git.haelnorr.com/h/golib/cookies" "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/golib/hwsauth" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -17,6 +19,7 @@ import ( func Register( server *hws.Server, + auth *hwsauth.Authenticator[*db.User, bun.Tx], conn *bun.DB, cfg *config.Config, store *store.Store, @@ -56,7 +59,6 @@ func Register( return } details, ok := store.GetRegistrationSession(sessionCookie.Value) - ok = false if !ok { http.Redirect(w, r, "/login", http.StatusSeeOther) return @@ -73,20 +75,27 @@ func Register( defer tx.Rollback() method := r.Method if method == "GET" { - unique, err := db.IsUsernameUnique(ctx, tx, details.DiscordUser.Username) - if err != nil { - throwInternalServiceError(server, w, r, "Database query failed", err) - return - } tx.Commit() - page.Register(details.DiscordUser.Username, unique).Render(r.Context(), w) + page.Register(details.DiscordUser.Username).Render(r.Context(), w) return } if method == "POST" { - // TODO: register the user - - // get the form data - // + username := r.FormValue("username") + user, err := registerUser(ctx, tx, username, details) + if err != nil { + throwInternalServiceError(server, w, r, "Registration failed", err) + } + tx.Commit() + if user == nil { + w.WriteHeader(http.StatusConflict) + } else { + err = auth.Login(w, r, user, true) + if err != nil { + throwInternalServiceError(server, w, r, "Login failed", err) + } + pageFrom := cookies.CheckPageFrom(w, r) + w.Header().Set("HX-Redirect", pageFrom) + } return } }, @@ -124,3 +133,27 @@ func IsUsernameUnique( }, ) } + +func registerUser( + ctx context.Context, + tx bun.Tx, + username string, + details *store.RegistrationSession, +) (*db.User, error) { + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + return nil, errors.Wrap(err, "db.IsUsernameUnique") + } + if !unique { + return nil, nil + } + user, err := db.CreateUser(ctx, tx, username, details.DiscordUser) + if err != nil { + return nil, errors.Wrap(err, "db.CreateUser") + } + err = db.UpdateDiscordToken(ctx, tx, user, details.Token) + if err != nil { + return nil, errors.Wrap(err, "db.UpdateDiscordToken") + } + return user, nil +} diff --git a/internal/view/component/form/register.templ b/internal/view/component/form/register.templ index e4c49f1..8f81b63 100644 --- a/internal/view/component/form/register.templ +++ b/internal/view/component/form/register.templ @@ -1,25 +1,42 @@ package form -templ RegisterForm(username, registerError string) { - {{ usernameErr := "Username is taken" }} +templ RegisterForm(username string) {