From b810b750117ea1841d2ba31f9af12d347dca1ff2 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Fri, 23 Jan 2026 19:07:05 +1100 Subject: [PATCH] updated stuff --- .gitignore | 1 + Makefile | 3 + cmd/oslstats/auth.go | 5 +- cmd/oslstats/db.go | 2 +- cmd/oslstats/httpserver.go | 4 +- cmd/oslstats/main.go | 10 ++ cmd/oslstats/routes.go | 9 +- cmd/oslstats/run.go | 7 +- go.mod | 13 +- go.sum | 22 ++- internal/config/flags.go | 22 +-- internal/db/discord_tokens.go | 50 ++++++ internal/db/user.go | 103 ++++------- internal/discord/api.go | 26 +++ internal/discord/oauth.go | 158 +++++++++++++++- internal/handlers/callback.go | 133 ++++++++++++-- internal/handlers/errorpage.go | 87 ++++++++- internal/handlers/errors.go | 109 +++++++++++ internal/handlers/index.go | 29 +-- internal/handlers/login.go | 22 +-- internal/handlers/register.go | 95 ++++++++++ internal/handlers/static.go | 11 +- internal/session/newlogin.go | 46 +++++ internal/session/store.go | 46 +++++ internal/view/component/form/register.templ | 86 +++++++++ internal/view/component/nav/navbarright.templ | 4 +- internal/view/component/nav/sidenav.templ | 4 +- internal/view/page/error.templ | 80 ++++++--- internal/view/page/register.templ | 35 ++++ pkg/contexts/currentuser.go | 8 - pkg/contexts/keys.go | 4 +- pkg/embedfs/files/css/output.css | 170 ++++++++++++++++++ pkg/oauth/state_test.go | 4 +- 33 files changed, 1186 insertions(+), 222 deletions(-) create mode 100644 internal/db/discord_tokens.go create mode 100644 internal/discord/api.go create mode 100644 internal/handlers/errors.go create mode 100644 internal/handlers/register.go create mode 100644 internal/session/newlogin.go create mode 100644 internal/session/store.go create mode 100644 internal/view/component/form/register.templ create mode 100644 internal/view/page/register.templ delete mode 100644 pkg/contexts/currentuser.go diff --git a/.gitignore b/.gitignore index 33692cc..9d37f15 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.db* .logs/ server.log +keys/ bin/ tmp/ static/css/output.css diff --git a/Makefile b/Makefile index 07bab50..8f59266 100644 --- a/Makefile +++ b/Makefile @@ -34,3 +34,6 @@ showenv: make build ./bin/${BINARY_NAME} --showenv +migrate: + make build + ./bin/${BINARY_NAME}${SUFFIX} --migrate diff --git a/cmd/oslstats/auth.go b/cmd/oslstats/auth.go index 4f57359..acb3421 100644 --- a/cmd/oslstats/auth.go +++ b/cmd/oslstats/auth.go @@ -8,7 +8,6 @@ import ( "git.haelnorr.com/h/golib/hwsauth" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/handlers" - "git.haelnorr.com/h/oslstats/pkg/contexts" "github.com/pkg/errors" "github.com/uptrace/bun" ) @@ -38,7 +37,9 @@ func setupAuth( auth.IgnorePaths(ignoredPaths...) - contexts.CurrentUser = auth.CurrentModel + db.CurrentUser = auth.CurrentModel return auth, nil } + +// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh diff --git a/cmd/oslstats/db.go b/cmd/oslstats/db.go index 1df8f57..6718d30 100644 --- a/cmd/oslstats/db.go +++ b/cmd/oslstats/db.go @@ -20,7 +20,7 @@ func setupBun(ctx context.Context, cfg *config.Config) (conn *bun.DB, close func conn = bun.NewDB(sqldb, pgdialect.New()) close = sqldb.Close - err = loadModels(ctx, conn, cfg.Flags.ResetDB) + err = loadModels(ctx, conn, cfg.Flags.MigrateDB) if err != nil { return nil, nil, errors.Wrap(err, "loadModels") } diff --git a/cmd/oslstats/httpserver.go b/cmd/oslstats/httpserver.go index 9a50c78..8c0ff26 100644 --- a/cmd/oslstats/httpserver.go +++ b/cmd/oslstats/httpserver.go @@ -7,6 +7,7 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/session" "git.haelnorr.com/h/golib/hlog" "github.com/pkg/errors" @@ -18,6 +19,7 @@ func setupHttpServer( config *config.Config, logger *hlog.Logger, bun *bun.DB, + store *session.Store, ) (server *hws.Server, err error) { if staticFS == nil { return nil, errors.New("No filesystem provided") @@ -53,7 +55,7 @@ func setupHttpServer( return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths") } - err = addRoutes(httpServer, &fs, config, bun, auth) + err = addRoutes(httpServer, &fs, config, bun, auth, store) if err != nil { return nil, errors.Wrap(err, "addRoutes") } diff --git a/cmd/oslstats/main.go b/cmd/oslstats/main.go index e92d126..38e0eea 100644 --- a/cmd/oslstats/main.go +++ b/cmd/oslstats/main.go @@ -29,6 +29,16 @@ func main() { return } + if flags.MigrateDB { + _, closedb, err := setupBun(ctx, cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err) + os.Exit(1) + } + closedb() + return + } + if err := run(ctx, os.Stdout, cfg); err != nil { fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) diff --git a/cmd/oslstats/routes.go b/cmd/oslstats/routes.go index 993e135..e247234 100644 --- a/cmd/oslstats/routes.go +++ b/cmd/oslstats/routes.go @@ -8,6 +8,7 @@ import ( "git.haelnorr.com/h/oslstats/internal/config" "git.haelnorr.com/h/oslstats/internal/db" "git.haelnorr.com/h/oslstats/internal/handlers" + "git.haelnorr.com/h/oslstats/internal/session" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -19,6 +20,7 @@ func addRoutes( cfg *config.Config, conn *bun.DB, auth *hwsauth.Authenticator[*db.User, bun.Tx], + store *session.Store, ) error { // Create the routes routes := []hws.Route{ @@ -40,7 +42,12 @@ func addRoutes( { Path: "/auth/callback", Method: hws.MethodGET, - Handler: auth.LogoutReq(handlers.Callback(server, cfg)), + Handler: auth.LogoutReq(handlers.Callback(server, conn, cfg, store)), + }, + { + Path: "/register", + Method: hws.MethodGET, + Handler: auth.LogoutReq(handlers.Register(server, conn, cfg, store)), }, } diff --git a/cmd/oslstats/run.go b/cmd/oslstats/run.go index 7906c26..5dd1922 100644 --- a/cmd/oslstats/run.go +++ b/cmd/oslstats/run.go @@ -10,6 +10,7 @@ import ( "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/session" "git.haelnorr.com/h/oslstats/pkg/embedfs" "github.com/pkg/errors" ) @@ -41,8 +42,12 @@ func run(ctx context.Context, w io.Writer, config *config.Config) error { return errors.Wrap(err, "getStaticFiles") } + // Setup session store + logger.Debug().Msg("Setting up session store") + store := session.NewStore() + logger.Debug().Msg("Setting up HTTP server") - httpServer, err := setupHttpServer(&staticFS, config, logger, bun) + httpServer, err := setupHttpServer(&staticFS, config, logger, bun, store) if err != nil { return errors.Wrap(err, "setupHttpServer") } diff --git a/go.mod b/go.mod index 26e14e9..0be7c08 100644 --- a/go.mod +++ b/go.mod @@ -6,20 +6,25 @@ require ( git.haelnorr.com/h/golib/env v0.9.1 git.haelnorr.com/h/golib/ezconf v0.1.1 git.haelnorr.com/h/golib/hlog v0.10.4 - git.haelnorr.com/h/golib/hws v0.2.3 - git.haelnorr.com/h/golib/hwsauth v0.3.4 + git.haelnorr.com/h/golib/hws v0.3.0 + git.haelnorr.com/h/golib/hwsauth v0.4.0 github.com/a-h/templ v0.3.977 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/uptrace/bun v1.2.16 github.com/uptrace/bun/dialect/pgdialect v1.2.16 github.com/uptrace/bun/driver/pgdriver v1.2.16 - golang.org/x/crypto v0.45.0 +) + +require ( + github.com/gorilla/websocket v1.4.2 // indirect + golang.org/x/crypto v0.45.0 // indirect ) require ( git.haelnorr.com/h/golib/cookies v0.9.0 // indirect - git.haelnorr.com/h/golib/jwt v0.10.0 // indirect + git.haelnorr.com/h/golib/jwt v0.10.1 // indirect + github.com/bwmarrin/discordgo v0.29.0 github.com/go-logr/logr v1.4.3 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 78c8ff7..c0f475e 100644 --- a/go.sum +++ b/go.sum @@ -6,16 +6,18 @@ git.haelnorr.com/h/golib/ezconf v0.1.1 h1:4euTSDb9jvuQQkVq+x5gHoYPYyUZPWxoOSlWCI git.haelnorr.com/h/golib/ezconf v0.1.1/go.mod h1:rETDcjpcEyyeBgCiZSU617wc0XycwZSC5+IAOtXmwP8= git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4ViSRQ= git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc= -git.haelnorr.com/h/golib/hws v0.2.3 h1:gZQkBciXKh3jYw05vZncSR2lvIqi0H2MVfIWySySsmw= -git.haelnorr.com/h/golib/hws v0.2.3/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= -git.haelnorr.com/h/golib/hwsauth v0.3.4 h1:wwYBb6cQQ+x9hxmYuZBF4mVmCv/n4PjJV//e1+SgPOo= -git.haelnorr.com/h/golib/hwsauth v0.3.4/go.mod h1:LI7Qz68GPNIW8732Zwptb//ybjiFJOoXf4tgUuUEqHI= -git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E= -git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= +git.haelnorr.com/h/golib/hws v0.3.0 h1:/YGzxd3sRR3DFU6qVZxpJMKV3W2wCONqZKYUDIercCo= +git.haelnorr.com/h/golib/hws v0.3.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo= +git.haelnorr.com/h/golib/hwsauth v0.4.0 h1:femjTuiaE8ye4BgC1xH1r6rC7PAhuhMmhcn1FBFZLN0= +git.haelnorr.com/h/golib/hwsauth v0.4.0/go.mod h1:aHY2u3b+dhoymszd/keii5HX9ZWpHU3v8gQqvTb/yKc= +git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI= +git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg= github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo= +github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= +github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -28,6 +30,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -66,13 +70,19 @@ go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= diff --git a/internal/config/flags.go b/internal/config/flags.go index 7d89108..ba87131 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -5,16 +5,16 @@ import ( ) type Flags struct { - ResetDB bool - EnvDoc bool - ShowEnv bool - GenEnv string - EnvFile string + MigrateDB bool + EnvDoc bool + ShowEnv bool + GenEnv string + EnvFile string } func SetupFlags() *Flags { // Parse commandline args - resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models") + migrateDB := flag.Bool("migrate", false, "Reset all the database tables with the updated models") envDoc := flag.Bool("envdoc", false, "Print all environment variables and their documentation") showEnv := flag.Bool("showenv", false, "Print all environment variable values and their documentation") genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)") @@ -22,11 +22,11 @@ func SetupFlags() *Flags { flag.Parse() flags := &Flags{ - ResetDB: *resetDB, - EnvDoc: *envDoc, - ShowEnv: *showEnv, - GenEnv: *genEnv, - EnvFile: *envfile, + MigrateDB: *migrateDB, + EnvDoc: *envDoc, + ShowEnv: *showEnv, + GenEnv: *genEnv, + EnvFile: *envfile, } return flags } diff --git a/internal/db/discord_tokens.go b/internal/db/discord_tokens.go new file mode 100644 index 0000000..7f6c2c3 --- /dev/null +++ b/internal/db/discord_tokens.go @@ -0,0 +1,50 @@ +package db + +import ( + "context" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/bwmarrin/discordgo" + "github.com/pkg/errors" + "github.com/uptrace/bun" +) + +type DiscordToken struct { + bun.BaseModel `bun:"table:discord_tokens,alias:dt"` + + DiscordID string `bun:"discord_id,pk,notnull"` + AccessToken string `bun:"access_token,notnull"` + RefreshToken string `bun:"refresh_token,notnull"` + ExpiresAt int64 `bun:"expires_at,notnull"` +} + +func UpdateDiscordToken(ctx context.Context, db *bun.DB, user *discordgo.User, token *discord.Token) error { + if db == nil { + return errors.New("db cannot be nil") + } + if user == nil { + return errors.New("user cannot be nil") + } + if token == nil { + return errors.New("token cannot be nil") + } + expiresAt := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second).Unix() + + discordToken := &DiscordToken{ + DiscordID: user.ID, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: expiresAt, + } + + _, err := db.NewInsert(). + Model(discordToken). + On("CONFLICT (discord_id) DO UPDATE"). + Set("access_token = EXCLUDED.access_token"). + Set("refresh_token = EXCLUDED.refresh_token"). + Set("expires_at = EXCLUDED.expires_at"). + Exec(ctx) + + return err +} diff --git a/internal/db/user.go b/internal/db/user.go index 01764eb..866fb44 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -2,65 +2,29 @@ package db import ( "context" + "time" + "git.haelnorr.com/h/golib/hwsauth" + "github.com/bwmarrin/discordgo" "github.com/pkg/errors" "github.com/uptrace/bun" - "golang.org/x/crypto/bcrypt" ) +var CurrentUser hwsauth.ContextLoader[*User] + type User struct { bun.BaseModel `bun:"table:users,alias:u"` - ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) - Username string `bun:"username,unique"` // Username (unique) - PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON) - CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database - Bio string `bun:"bio"` // Short byline set by the user + ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key) + Username string `bun:"username,unique"` // Username (unique) + CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database + DiscordID string `bun:"discord_id,unique"` } func (user *User) GetID() int { return user.ID } -// Uses bcrypt to set the users password_hash from the given password -func (user *User) SetPassword(ctx context.Context, tx bun.Tx, password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "bcrypt.GenerateFromPassword") - } - newPassword := string(hashedPassword) - - _, err = tx.NewUpdate(). - Model(user). - Set("password_hash = ?", newPassword). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - return nil -} - -// Uses bcrypt to check if the given password matches the users password_hash -func (user *User) CheckPassword(ctx context.Context, tx bun.Tx, password string) error { - var hashedPassword string - err := tx.NewSelect(). - Table("users"). - Column("password_hash"). - Where("id = ?", user.ID). - Limit(1). - Scan(ctx, &hashedPassword) - if err != nil { - return errors.Wrap(err, "tx.Select") - } - - err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) - if err != nil { - return errors.Wrap(err, "Username or password incorrect") - } - return nil -} - // Change the user's username func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error { _, err := tx.NewUpdate(). @@ -75,35 +39,18 @@ func (user *User) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername str return nil } -// Change the user's bio -func (user *User) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error { - _, err := tx.NewUpdate(). - Model(user). - Set("bio = ?", newBio). - Where("id = ?", user.ID). - Exec(ctx) - if err != nil { - return errors.Wrap(err, "tx.Update") - } - user.Bio = newBio - return nil -} - // CreateUser creates a new user with the given username and password -func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*User, error) { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword") +func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *discordgo.User) (*User, error) { + if discorduser == nil { + return nil, errors.New("user cannot be nil") } - user := &User{ - Username: username, - PasswordHash: string(hashedPassword), - CreatedAt: 0, // You may want to set this to time.Now().Unix() - Bio: "", + Username: username, + CreatedAt: time.Now().Unix(), + DiscordID: discorduser.ID, } - _, err = tx.NewInsert(). + _, err := tx.NewInsert(). Model(user). Exec(ctx) if err != nil { @@ -149,6 +96,24 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User, return user, nil } +// GetUserByDiscordID queries the database for a user matching the given discord id +// Returns nil, nil if no user is found +func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User, error) { + user := new(User) + err := tx.NewSelect(). + Model(user). + Where("discord_id = ?", discordID). + Limit(1). + Scan(ctx) + if err != nil { + if err.Error() == "sql: no rows in result set" { + return nil, nil + } + return nil, errors.Wrap(err, "tx.Select") + } + return user, nil +} + // IsUsernameUnique checks if the given username is unique (not already taken) // Returns true if the username is available, false if it's taken func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) { diff --git a/internal/discord/api.go b/internal/discord/api.go new file mode 100644 index 0000000..7c41257 --- /dev/null +++ b/internal/discord/api.go @@ -0,0 +1,26 @@ +package discord + +import ( + "github.com/bwmarrin/discordgo" + "github.com/pkg/errors" +) + +type OAuthSession struct { + *discordgo.Session +} + +func NewOAuthSession(token *Token) (*OAuthSession, error) { + session, err := discordgo.New("Bearer " + token.AccessToken) + if err != nil { + return nil, errors.Wrap(err, "discordgo.New") + } + return &OAuthSession{Session: session}, nil +} + +func (s *OAuthSession) GetUser() (*discordgo.User, error) { + user, err := s.User("@me") + if err != nil { + return nil, errors.Wrap(err, "s.User") + } + return user, nil +} diff --git a/internal/discord/oauth.go b/internal/discord/oauth.go index 591d0f9..73644ea 100644 --- a/internal/discord/oauth.go +++ b/internal/discord/oauth.go @@ -1,23 +1,28 @@ package discord import ( + "encoding/json" "fmt" + "io" + "net/http" "net/url" + "strings" "github.com/pkg/errors" ) type Token struct { - AccessToken string - TokenType string - ExpiresIn int - RefreshToken string - Scope string + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` } const oauthurl string = "https://discord.com/oauth2/authorize" +const apiurl string = "https://discord.com/api/v10" -func GetOAuthLink(cfg *Config, state string, trustedHost string) (string, error) { +func GetOAuthLink(cfg *Config, state, trustedHost string) (string, error) { if cfg == nil { return "", errors.New("cfg cannot be nil") } @@ -37,3 +42,144 @@ func GetOAuthLink(cfg *Config, state string, trustedHost string) (string, error) return fmt.Sprintf("%s?%s", oauthurl, values.Encode()), nil } + +func AuthorizeWithCode(cfg *Config, code, trustedHost string) (*Token, error) { + if code == "" { + return nil, errors.New("code cannot be empty") + } + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + if trustedHost == "" { + return nil, errors.New("trustedHost cannot be empty") + } + // Prepare form data + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", fmt.Sprintf("%s/%s", trustedHost, cfg.RedirectPath)) + // Create request + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + // Set headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Set basic auth (client_id and client_secret) + req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + // Execute request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + // Check status code + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + // Parse JSON response + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +func RefreshToken(cfg *Config, token *Token) (*Token, error) { + if token == nil { + return nil, errors.New("token cannot be nil") + } + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + // Prepare form data + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", token.RefreshToken) + // Create request + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token", + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, errors.Wrap(err, "failed to create request") + } + // Set headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Set basic auth (client_id and client_secret) + req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + // Execute request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } + // Check status code + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("discord API returned status %d: %s", resp.StatusCode, string(body)) + } + // Parse JSON response + var tokenResp Token + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, errors.Wrap(err, "failed to parse token response") + } + return &tokenResp, nil +} + +func RevokeToken(cfg *Config, token *Token) error { + if token == nil { + return errors.New("token cannot be nil") + } + if cfg == nil { + return errors.New("config cannot be nil") + } + // Prepare form data + data := url.Values{} + data.Set("token", token.AccessToken) + data.Set("token_type_hint", "access_token") + // Create request + req, err := http.NewRequest( + "POST", + apiurl+"/oauth2/token/revoke", + strings.NewReader(data.Encode()), + ) + if err != nil { + return errors.Wrap(err, "failed to create request") + } + // Set headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Set basic auth (client_id and client_secret) + req.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + // Execute request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return errors.Wrap(err, "failed to execute request") + } + defer resp.Body.Close() + // Check status code + if resp.StatusCode != http.StatusOK { + return errors.Errorf("discord API returned status %d", resp.StatusCode) + } + return nil +} diff --git a/internal/handlers/callback.go b/internal/handlers/callback.go index 9d985c7..d0677aa 100644 --- a/internal/handlers/callback.go +++ b/internal/handlers/callback.go @@ -1,15 +1,21 @@ package handlers import ( + "context" "net/http" + "time" "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/discord" + "git.haelnorr.com/h/oslstats/internal/session" "git.haelnorr.com/h/oslstats/pkg/oauth" "github.com/pkg/errors" + "github.com/uptrace/bun" ) -func Callback(server *hws.Server, cfg *config.Config) http.Handler { +func Callback(server *hws.Server, conn *bun.DB, cfg *config.Config, store *session.Store) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { state := r.URL.Query().Get("state") @@ -20,42 +26,141 @@ func Callback(server *hws.Server, cfg *config.Config) http.Handler { } data, err := verifyState(cfg.OAuth, w, r, state) if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusForbidden, - Message: "OAuth state verification failed", - Error: err, - Level: hws.ErrorLevel("debug"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) + // Check if this is a cookie error (401) or signature error (403) + if vsErr, ok := err.(*verifyStateError); ok { + if vsErr.IsCookieError() { + // Cookie missing/expired - normal failed/expired session (DEBUG) + throwUnauthorized(server, w, r, "OAuth session not found or expired", err) + } else { + // Signature verification failed - security violation (WARN) + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) + } + } else { + // Unknown error type - treat as security issue + throwForbiddenSecurity(server, w, r, "OAuth state verification failed", err) } return } switch data { case "login": - w.Write([]byte(code)) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "DB Transaction failed to start", err) + return + } + defer tx.Rollback() + redirect, err := login(ctx, tx, cfg, w, r, code, store) + if err != nil { + throwInternalServiceError(server, w, r, "OAuth login failed", err) + return + } + tx.Commit() + redirect() return } }, ) } -func verifyState(cfg *oauth.Config, w http.ResponseWriter, r *http.Request, state string) (string, error) { +// verifyStateError wraps an error with context about what went wrong +type verifyStateError struct { + err error + cookieError bool // true if cookie missing/invalid, false if signature invalid +} + +func (e *verifyStateError) Error() string { + return e.err.Error() +} + +func (e *verifyStateError) IsCookieError() bool { + return e.cookieError +} + +func verifyState( + cfg *oauth.Config, + w http.ResponseWriter, + r *http.Request, + state string, +) (string, error) { if r == nil { return "", errors.New("request cannot be nil") } if state == "" { return "", errors.New("state param field is empty") } + + // Try to get the cookie uak, err := oauth.GetStateCookie(r) if err != nil { - return "", errors.Wrap(err, "oauth.GetStateCookie") + // Cookie missing or invalid - this is a 401 (not authenticated) + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.GetStateCookie"), + cookieError: true, + } } + + // Verify the state signature data, err := oauth.VerifyState(cfg, state, uak) if err != nil { - return "", errors.Wrap(err, "oauth.VerifyState") + // Signature verification failed - this is a 403 (security violation) + return "", &verifyStateError{ + err: errors.Wrap(err, "oauth.VerifyState"), + cookieError: false, + } } + oauth.DeleteStateCookie(w) return data, nil } + +func login( + ctx context.Context, + tx bun.Tx, + cfg *config.Config, + w http.ResponseWriter, + r *http.Request, + code string, + store *session.Store, +) (func(), error) { + token, err := discord.AuthorizeWithCode(cfg.Discord, code, cfg.HWSAuth.TrustedHost) + if err != nil { + return nil, errors.Wrap(err, "discord.AuthorizeWithCode") + } + session, err := discord.NewOAuthSession(token) + if err != nil { + return nil, errors.Wrap(err, "discord.NewOAuthSession") + } + discorduser, err := session.GetUser() + if err != nil { + return nil, errors.Wrap(err, "session.GetUser") + } + + user, err := db.GetUserByDiscordID(ctx, tx, discorduser.ID) + if err != nil { + return nil, errors.Wrap(err, "db.GetUserByDiscordID") + } + var redirect string + if user == nil { + sessionID, err := store.CreateRegistrationSession(discorduser, token) + if err != nil { + return nil, errors.Wrap(err, "store.CreateRegistrationSession") + } + http.SetCookie(w, &http.Cookie{ + Name: "registration_session", + Path: "/", + Value: sessionID, + MaxAge: 300, // 5 minutes + HttpOnly: true, + Secure: cfg.HWSAuth.SSL, + SameSite: http.SameSiteLaxMode, + }) + redirect = "/register" + } else { + // TODO: log them in + } + return func() { + http.Redirect(w, r, redirect, http.StatusSeeOther) + }, nil +} diff --git a/internal/handlers/errorpage.go b/internal/handlers/errorpage.go index 2fa7699..d6e4a25 100644 --- a/internal/handlers/errorpage.go +++ b/internal/handlers/errorpage.go @@ -5,24 +5,93 @@ import ( "git.haelnorr.com/h/golib/hws" "git.haelnorr.com/h/oslstats/internal/view/page" - "github.com/pkg/errors" ) -func ErrorPage( - errorCode int, -) (hws.ErrorPage, error) { +// func ErrorPage( +// error hws.HWSError, +// ) (hws.ErrorPage, error) { +// messages := map[int]string{ +// 400: "The request you made was malformed or unexpected.", +// 401: "You need to login to view this page.", +// 403: "You do not have permission to view this page.", +// 404: "The page or resource you have requested does not exist.", +// 500: `An error occured on the server. Please try again, and if this +// continues to happen contact an administrator.`, +// 503: "The server is currently down for maintenance and should be back soon. =)", +// } +// msg, exists := messages[error.StatusCode] +// if !exists { +// return nil, errors.New("No valid message for the given code") +// } +// return page.Error(error.StatusCode, http.StatusText(error.StatusCode), msg), nil +// } + +func ErrorPage(hwsError hws.HWSError) (hws.ErrorPage, error) { + // Determine if this status code should show technical details + showDetails := shouldShowDetails(hwsError.StatusCode) + + // Get the user-friendly message + message := hwsError.Message + if message == "" { + // Fallback to default messages if no custom message provided + message = getDefaultMessage(hwsError.StatusCode) + } + + // Get technical details if applicable + var details string + if showDetails && hwsError.Error != nil { + details = hwsError.Error.Error() + } + + // Render appropriate template + if details != "" { + return page.ErrorWithDetails( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + details, + ), nil + } + + return page.Error( + hwsError.StatusCode, + http.StatusText(hwsError.StatusCode), + message, + ), nil +} + +// shouldShowDetails determines if a status code should display technical details +func shouldShowDetails(statusCode int) bool { + switch statusCode { + case 400, 500, 503: // Bad Request, Internal Server Error, Service Unavailable + return true + case 401, 403, 404: // Unauthorized, Forbidden, Not Found + return false + default: + // For unknown codes, show details for 5xx errors + return statusCode >= 500 + } +} + +// getDefaultMessage provides fallback messages for status codes +func getDefaultMessage(statusCode int) string { messages := map[int]string{ 400: "The request you made was malformed or unexpected.", 401: "You need to login to view this page.", 403: "You do not have permission to view this page.", 404: "The page or resource you have requested does not exist.", - 500: `An error occured on the server. Please try again, and if this - continues to happen contact an administrator.`, + 500: `An error occurred on the server. Please try again, and if this + continues to happen contact an administrator.`, 503: "The server is currently down for maintenance and should be back soon. =)", } - msg, exists := messages[errorCode] + + msg, exists := messages[statusCode] if !exists { - return nil, errors.New("No valid message for the given code") + if statusCode >= 500 { + return "A server error occurred. Please try again later." + } + return "An error occurred while processing your request." } - return page.Error(errorCode, http.StatusText(errorCode), msg), nil + + return msg } diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go new file mode 100644 index 0000000..4e26611 --- /dev/null +++ b/internal/handlers/errors.go @@ -0,0 +1,109 @@ +package handlers + +import ( + "fmt" + "net/http" + + "git.haelnorr.com/h/golib/hws" + "github.com/pkg/errors" +) + +// throwError is a generic helper that all throw* functions use internally +func throwError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + statusCode int, + msg string, + err error, + level string, +) { + err = s.ThrowError(w, r, hws.HWSError{ + StatusCode: statusCode, + Message: msg, + Error: err, + Level: hws.ErrorLevel(level), + RenderErrorPage: true, // throw* family always renders error pages + }) + if err != nil { + s.ThrowFatal(w, err) + } +} + +// throwInternalServiceError handles 500 errors (server failures) +func throwInternalServiceError( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusInternalServerError, msg, err, "error") +} + +// throwBadRequest handles 400 errors (malformed requests) +func throwBadRequest( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusBadRequest, msg, err, "debug") +} + +// throwForbidden handles 403 errors (normal permission denials) +func throwForbidden( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "debug") +} + +// throwForbiddenSecurity handles 403 errors for security events (uses WARN level) +func throwForbiddenSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusForbidden, msg, err, "warn") +} + +// throwUnauthorized handles 401 errors (not authenticated) +func throwUnauthorized( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "debug") +} + +// throwUnauthorizedSecurity handles 401 errors for security events (uses WARN level) +func throwUnauthorizedSecurity( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + throwError(s, w, r, http.StatusUnauthorized, msg, err, "warn") +} + +// throwNotFound handles 404 errors +func throwNotFound( + s *hws.Server, + w http.ResponseWriter, + r *http.Request, + path string, +) { + msg := fmt.Sprintf("The requested resource was not found: %s", path) + err := errors.New("Resource not found") + throwError(s, w, r, http.StatusNotFound, msg, err, "debug") +} diff --git a/internal/handlers/index.go b/internal/handlers/index.go index 3274302..25e7ac5 100644 --- a/internal/handlers/index.go +++ b/internal/handlers/index.go @@ -14,34 +14,7 @@ func Index(server *hws.Server) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { - page, err := ErrorPage(http.StatusNotFound) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to generate the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } - err = page.Render(r.Context(), w) - if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to render the error page", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: false, - }) - if err != nil { - server.ThrowFatal(w, err) - } - return - } + throwNotFound(server, w, r, r.URL.Path) } page.Index().Render(r.Context(), w) }, diff --git a/internal/handlers/login.go b/internal/handlers/login.go index a64e9d3..3ff126f 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -14,32 +14,14 @@ func Login(server *hws.Server, cfg *config.Config) http.Handler { func(w http.ResponseWriter, r *http.Request) { state, uak, err := oauth.GenerateState(cfg.OAuth, "login") if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "Failed to generate state token", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) - } + throwInternalServiceError(server, w, r, "Failed to generate state token", err) return } oauth.SetStateCookie(w, uak, cfg.HWSAuth.SSL) link, err := discord.GetOAuthLink(cfg.Discord, state, cfg.HWSAuth.TrustedHost) if err != nil { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to generate the login link", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) - } + throwInternalServiceError(server, w, r, "An error occurred trying to generate the login link", err) return } http.Redirect(w, r, link, http.StatusSeeOther) diff --git a/internal/handlers/register.go b/internal/handlers/register.go new file mode 100644 index 0000000..b871e45 --- /dev/null +++ b/internal/handlers/register.go @@ -0,0 +1,95 @@ +package handlers + +import ( + "context" + "net/http" + "time" + + "git.haelnorr.com/h/golib/hws" + "git.haelnorr.com/h/oslstats/internal/config" + "git.haelnorr.com/h/oslstats/internal/db" + "git.haelnorr.com/h/oslstats/internal/session" + "git.haelnorr.com/h/oslstats/internal/view/page" + "github.com/uptrace/bun" +) + +func Register( + server *hws.Server, + conn *bun.DB, + cfg *config.Config, + store *session.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + sessionCookie, err := r.Cookie("registration_session") + if err != nil { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + details, ok := store.GetRegistrationSession(sessionCookie.Value) + if !ok { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + method := r.Method + if method == "GET" { + unique, err := db.IsUsernameUnique(ctx, tx, details.DiscordUser.Username) + if err != nil { + throwInternalServiceError(server, w, r, "Database query failed", err) + return + } + tx.Commit() + page.Register(details.DiscordUser.Username, unique).Render(r.Context(), w) + return + } + if method == "POST" { + // TODO: register the user + + // get the form data + // + return + } + }, + ) +} + +func IsUsernameUnique( + server *hws.Server, + conn *bun.DB, + cfg *config.Config, + store *session.Store, +) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + username := r.FormValue("username") + // check if its unique + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + throwInternalServiceError(server, w, r, "Database transaction failed", err) + return + } + defer tx.Rollback() + unique, err := db.IsUsernameUnique(ctx, tx, username) + if err != nil { + throwInternalServiceError(server, w, r, "Database query failed", err) + return + } + tx.Commit() + if !unique { + w.WriteHeader(http.StatusConflict) + } else { + w.WriteHeader(http.StatusOK) + } + }, + ) +} diff --git a/internal/handlers/static.go b/internal/handlers/static.go index fb444c9..3176129 100644 --- a/internal/handlers/static.go +++ b/internal/handlers/static.go @@ -15,16 +15,7 @@ func StaticFS(staticFS *http.FileSystem, server *hws.Server) http.Handler { if err != nil { // If we can't create the file server, return a handler that always errors return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err = server.ThrowError(w, r, hws.HWSError{ - StatusCode: http.StatusInternalServerError, - Message: "An error occured trying to load the file system", - Error: err, - Level: hws.ErrorLevel("error"), - RenderErrorPage: true, - }) - if err != nil { - server.ThrowFatal(w, err) - } + throwInternalServiceError(server, w, r, "An error occurred trying to load the file system", err) }) } diff --git a/internal/session/newlogin.go b/internal/session/newlogin.go new file mode 100644 index 0000000..d1c60ba --- /dev/null +++ b/internal/session/newlogin.go @@ -0,0 +1,46 @@ +package session + +import ( + "errors" + "time" + + "git.haelnorr.com/h/oslstats/internal/discord" + "github.com/bwmarrin/discordgo" +) + +type RegistrationSession struct { + DiscordUser *discordgo.User + Token *discord.Token + ExpiresAt time.Time +} + +func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) { + if user == nil { + return "", errors.New("user cannot be nil") + } + if token == nil { + return "", errors.New("token cannot be nil") + } + id := generateID() + s.sessions.Store(id, &RegistrationSession{ + DiscordUser: user, + Token: token, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + return id, nil +} + +func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) { + val, ok := s.sessions.Load(id) + if !ok { + return nil, false + } + + session := val.(*RegistrationSession) + if time.Now().After(session.ExpiresAt) { + s.sessions.Delete(id) + return nil, false + } + + return session, true +} diff --git a/internal/session/store.go b/internal/session/store.go new file mode 100644 index 0000000..fe77e62 --- /dev/null +++ b/internal/session/store.go @@ -0,0 +1,46 @@ +package session + +import ( + "crypto/rand" + "encoding/base64" + "sync" + "time" +) + +type Store struct { + sessions sync.Map + cleanup *time.Ticker +} + +func NewStore() *Store { + s := &Store{ + cleanup: time.NewTicker(1 * time.Minute), + } + + // Background cleanup of expired sessions + go func() { + for range s.cleanup.C { + s.cleanupExpired() + } + }() + + return s +} + +func (s *Store) Delete(id string) { + s.sessions.Delete(id) +} +func (s *Store) cleanupExpired() { + s.sessions.Range(func(key, value any) bool { + session := value.(*RegistrationSession) + if time.Now().After(session.ExpiresAt) { + s.sessions.Delete(key) + } + return true + }) +} +func generateID() string { + b := make([]byte, 32) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/internal/view/component/form/register.templ b/internal/view/component/form/register.templ new file mode 100644 index 0000000..e4c49f1 --- /dev/null +++ b/internal/view/component/form/register.templ @@ -0,0 +1,86 @@ +package form + +templ RegisterForm(username, registerError string) { + {{ usernameErr := "Username is taken" }} +
+ +
+
+
+ +
+ +
+

+
+
+ +
+
+} diff --git a/internal/view/component/nav/navbarright.templ b/internal/view/component/nav/navbarright.templ index 5f1fc22..88a5531 100644 --- a/internal/view/component/nav/navbarright.templ +++ b/internal/view/component/nav/navbarright.templ @@ -1,6 +1,6 @@ package nav -import "git.haelnorr.com/h/oslstats/pkg/contexts" +import "git.haelnorr.com/h/oslstats/internal/db" type ProfileItem struct { name string // Label to display @@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem { // Returns the right portion of the navbar templ navRight() { - {{ user := contexts.CurrentUser(ctx) }} + {{ user := db.CurrentUser(ctx) }} {{ items := getProfileItems() }}
diff --git a/internal/view/component/nav/sidenav.templ b/internal/view/component/nav/sidenav.templ index 23fb9f0..5b50e70 100644 --- a/internal/view/component/nav/sidenav.templ +++ b/internal/view/component/nav/sidenav.templ @@ -1,10 +1,10 @@ package nav -import "git.haelnorr.com/h/oslstats/pkg/contexts" +import "git.haelnorr.com/h/oslstats/internal/db" // Returns the mobile version of the navbar thats only visible when activated templ sideNav(navItems []NavItem) { - {{ user := contexts.CurrentUser(ctx) }} + {{ user := db.CurrentUser(ctx) }}
-
-

{ strconv.Itoa(code) }

-

{ err }

-

{ message }

- Go to homepage +
+
+

{ strconv.Itoa(code) }

+

{ err }

+ // Always show the message from hws.HWSError.Message +

{ message }

+ // Conditionally show technical details in dropdown + if details != "" { +
+
+ + Details + (click to expand) + +
+
{ details }
+
+ +
+
+ } + + Go to homepage +
+ if details != "" { + + } } } diff --git a/internal/view/page/register.templ b/internal/view/page/register.templ new file mode 100644 index 0000000..45c69f8 --- /dev/null +++ b/internal/view/page/register.templ @@ -0,0 +1,35 @@ +package page + +import "git.haelnorr.com/h/oslstats/internal/view/layout" +import "git.haelnorr.com/h/oslstats/internal/view/component/form" + +// Returns the login page +templ Register(username string, unique bool) { + {{ + err := "" + if !unique { + err = "Username is taken" + } + }} + @layout.Global("Register") { +
+
+
+
+

Set your display name

+

+ Select your display name. This must be unique, and cannot be changed. +

+
+
+ @form.RegisterForm(username, err) +
+
+
+
+ } +} diff --git a/pkg/contexts/currentuser.go b/pkg/contexts/currentuser.go deleted file mode 100644 index f522563..0000000 --- a/pkg/contexts/currentuser.go +++ /dev/null @@ -1,8 +0,0 @@ -package contexts - -import ( - "git.haelnorr.com/h/golib/hwsauth" - "git.haelnorr.com/h/oslstats/internal/db" -) - -var CurrentUser hwsauth.ContextLoader[*db.User] diff --git a/pkg/contexts/keys.go b/pkg/contexts/keys.go index 996cba7..e089934 100644 --- a/pkg/contexts/keys.go +++ b/pkg/contexts/keys.go @@ -1,7 +1,7 @@ package contexts -type contextKey string +type Key string -func (c contextKey) String() string { +func (c Key) String() string { return "oslstats context key " + string(c) } diff --git a/pkg/embedfs/files/css/output.css b/pkg/embedfs/files/css/output.css index 8876f5c..022341c 100644 --- a/pkg/embedfs/files/css/output.css +++ b/pkg/embedfs/files/css/output.css @@ -11,7 +11,10 @@ --spacing: 0.25rem; --breakpoint-xl: 80rem; --container-md: 28rem; + --container-2xl: 42rem; --container-7xl: 80rem; + --text-xs: 0.75rem; + --text-xs--line-height: calc(1 / 0.75); --text-sm: 0.875rem; --text-sm--line-height: calc(1.25 / 0.875); --text-lg: 1.125rem; @@ -29,11 +32,13 @@ --text-9xl: 8rem; --text-9xl--line-height: 1; --font-weight-medium: 500; + --font-weight-semibold: 600; --font-weight-bold: 700; --tracking-tight: -0.025em; --leading-relaxed: 1.625; --radius-sm: 0.25rem; --radius-lg: 0.5rem; + --radius-xl: 0.75rem; --default-transition-duration: 150ms; --default-transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); --default-font-family: var(--font-sans); @@ -189,6 +194,9 @@ } } @layer utilities { + .pointer-events-none { + pointer-events: none; + } .visible { visibility: visible; } @@ -212,6 +220,9 @@ .static { position: static; } + .inset-y-0 { + inset-block: calc(var(--spacing) * 0); + } .end-0 { inset-inline-end: calc(var(--spacing) * 0); } @@ -221,12 +232,18 @@ .top-0 { top: calc(var(--spacing) * 0); } + .top-2 { + top: calc(var(--spacing) * 2); + } .top-4 { top: calc(var(--spacing) * 4); } .right-0 { right: calc(var(--spacing) * 0); } + .right-2 { + right: calc(var(--spacing) * 2); + } .bottom-0 { bottom: calc(var(--spacing) * 0); } @@ -236,9 +253,18 @@ .z-10 { z-index: 10; } + .float-left { + float: left; + } + .m-0 { + margin: calc(var(--spacing) * 0); + } .mx-auto { margin-inline: auto; } + .mt-1 { + margin-top: calc(var(--spacing) * 1); + } .mt-1\.5 { margin-top: calc(var(--spacing) * 1.5); } @@ -248,9 +274,18 @@ .mt-4 { margin-top: calc(var(--spacing) * 4); } + .mt-5 { + margin-top: calc(var(--spacing) * 5); + } .mt-6 { margin-top: calc(var(--spacing) * 6); } + .mt-7 { + margin-top: calc(var(--spacing) * 7); + } + .mt-8 { + margin-top: calc(var(--spacing) * 8); + } .mt-10 { margin-top: calc(var(--spacing) * 10); } @@ -263,12 +298,24 @@ .mt-24 { margin-top: calc(var(--spacing) * 24); } + .mr-0 { + margin-right: calc(var(--spacing) * 0); + } + .mr-2 { + margin-right: calc(var(--spacing) * 2); + } .mr-5 { margin-right: calc(var(--spacing) * 5); } .mb-auto { margin-bottom: auto; } + .ml-0 { + margin-left: calc(var(--spacing) * 0); + } + .ml-2 { + margin-left: calc(var(--spacing) * 2); + } .ml-auto { margin-left: auto; } @@ -331,9 +378,15 @@ .w-full { width: 100%; } + .max-w-2xl { + max-width: var(--container-2xl); + } .max-w-7xl { max-width: var(--container-7xl); } + .max-w-100 { + max-width: calc(var(--spacing) * 100); + } .max-w-md { max-width: var(--container-md); } @@ -343,6 +396,9 @@ .flex-1 { flex: 1; } + .border-collapse { + border-collapse: collapse; + } .translate-x-0 { --tw-translate-x: calc(var(--spacing) * 0); translate: var(--tw-translate-x) var(--tw-translate-y); @@ -354,6 +410,12 @@ .transform { transform: var(--tw-rotate-x,) var(--tw-rotate-y,) var(--tw-rotate-z,) var(--tw-skew-x,) var(--tw-skew-y,); } + .cursor-pointer { + cursor: pointer; + } + .resize { + resize: both; + } .flex-col { flex-direction: column; } @@ -391,6 +453,12 @@ margin-block-end: calc(calc(var(--spacing) * 1) * calc(1 - var(--tw-space-y-reverse))); } } + .gap-x-2 { + column-gap: calc(var(--spacing) * 2); + } + .gap-y-4 { + row-gap: calc(var(--spacing) * 4); + } .divide-y { :where(& > :not(:last-child)) { --tw-divide-y-reverse: 0; @@ -408,9 +476,15 @@ .overflow-hidden { overflow: hidden; } + .overflow-x-auto { + overflow-x: auto; + } .overflow-x-hidden { overflow-x: hidden; } + .rounded { + border-radius: 0.25rem; + } .rounded-full { border-radius: calc(infinity * 1px); } @@ -420,6 +494,9 @@ .rounded-sm { border-radius: var(--radius-sm); } + .rounded-xl { + border-radius: var(--radius-xl); + } .border { border-style: var(--tw-border-style); border-width: 1px; @@ -427,6 +504,9 @@ .border-surface1 { border-color: var(--surface1); } + .border-transparent { + border-color: transparent; + } .bg-base { background-color: var(--base); } @@ -463,12 +543,21 @@ .p-4 { padding: calc(var(--spacing) * 4); } + .px-2 { + padding-inline: calc(var(--spacing) * 2); + } + .px-3 { + padding-inline: calc(var(--spacing) * 3); + } .px-4 { padding-inline: calc(var(--spacing) * 4); } .px-5 { padding-inline: calc(var(--spacing) * 5); } + .py-1 { + padding-block: calc(var(--spacing) * 1); + } .py-2 { padding-block: calc(var(--spacing) * 2); } @@ -481,12 +570,27 @@ .py-8 { padding-block: calc(var(--spacing) * 8); } + .pe-3 { + padding-inline-end: calc(var(--spacing) * 3); + } + .pt-3 { + padding-top: calc(var(--spacing) * 3); + } .pb-6 { padding-bottom: calc(var(--spacing) * 6); } .text-center { text-align: center; } + .text-left { + text-align: left; + } + .text-right { + text-align: right; + } + .font-mono { + font-family: var(--font-mono); + } .text-2xl { font-size: var(--text-2xl); line-height: var(--tw-leading, var(--text-2xl--line-height)); @@ -515,6 +619,10 @@ font-size: var(--text-xl); line-height: var(--tw-leading, var(--text-xl--line-height)); } + .text-xs { + font-size: var(--text-xs); + line-height: var(--tw-leading, var(--text-xs--line-height)); + } .leading-relaxed { --tw-leading: var(--leading-relaxed); line-height: var(--leading-relaxed); @@ -527,10 +635,20 @@ --tw-font-weight: var(--font-weight-medium); font-weight: var(--font-weight-medium); } + .font-semibold { + --tw-font-weight: var(--font-weight-semibold); + font-weight: var(--font-weight-semibold); + } .tracking-tight { --tw-tracking: var(--tracking-tight); letter-spacing: var(--tracking-tight); } + .break-all { + word-break: break-all; + } + .whitespace-pre-wrap { + white-space: pre-wrap; + } .text-crust { color: var(--crust); } @@ -552,6 +670,9 @@ .text-text { color: var(--text); } + .underline { + text-decoration-line: underline; + } .opacity-0 { opacity: 0%; } @@ -579,6 +700,10 @@ --tw-duration: 200ms; transition-duration: 200ms; } + .select-none { + -webkit-user-select: none; + user-select: none; + } .hover\:cursor-pointer { &:hover { @media (hover: hover) { @@ -674,6 +799,46 @@ } } } + .hover\:text-text { + &:hover { + @media (hover: hover) { + color: var(--text); + } + } + } + .focus\:border-blue { + &:focus { + border-color: var(--blue); + } + } + .focus\:ring-blue { + &:focus { + --tw-ring-color: var(--blue); + } + } + .disabled\:pointer-events-none { + &:disabled { + pointer-events: none; + } + } + .disabled\:cursor-default { + &:disabled { + cursor: default; + } + } + .disabled\:bg-green\/60 { + &:disabled { + background-color: var(--green); + @supports (color: color-mix(in lab, red, red)) { + background-color: color-mix(in oklab, var(--green) 60%, transparent); + } + } + } + .disabled\:opacity-50 { + &:disabled { + opacity: 50%; + } + } .sm\:end-6 { @media (width >= 40rem) { inset-inline-end: calc(var(--spacing) * 6); @@ -704,6 +869,11 @@ gap: calc(var(--spacing) * 2); } } + .sm\:p-7 { + @media (width >= 40rem) { + padding: calc(var(--spacing) * 7); + } + } .sm\:px-6 { @media (width >= 40rem) { padding-inline: calc(var(--spacing) * 6); diff --git a/pkg/oauth/state_test.go b/pkg/oauth/state_test.go index fb0672d..c647e66 100644 --- a/pkg/oauth/state_test.go +++ b/pkg/oauth/state_test.go @@ -465,7 +465,7 @@ func TestConcurrentGeneration(t *testing.T) { errors := make(chan error, numGoroutines) // Generate states concurrently - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { go func() { state, userAgentKey, err := GenerateState(cfg, data) if err != nil { @@ -486,7 +486,7 @@ func TestConcurrentGeneration(t *testing.T) { // Collect results states := make(map[string]bool) - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { select { case state := <-results: if states[state] {