Compare commits
2 Commits
4c31c24069
...
7125683e6a
| Author | SHA1 | Date | |
|---|---|---|---|
| 7125683e6a | |||
| 697bef80e9 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@ tmp/
|
|||||||
static/css/output.css
|
static/css/output.css
|
||||||
internal/view/**/*_templ.go
|
internal/view/**/*_templ.go
|
||||||
internal/view/**/*_templ.txt
|
internal/view/**/*_templ.txt
|
||||||
|
cmd/test/*
|
||||||
|
|
||||||
# Database backups (compressed)
|
# Database backups (compressed)
|
||||||
backups/*.sql.gz
|
backups/*.sql.gz
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func setupAuth(
|
func setupAuth(
|
||||||
config *hwsauth.Config,
|
cfg *hwsauth.Config,
|
||||||
logger *hlog.Logger,
|
logger *hlog.Logger,
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
server *hws.Server,
|
server *hws.Server,
|
||||||
@@ -24,7 +24,7 @@ func setupAuth(
|
|||||||
return tx, err
|
return tx, err
|
||||||
}
|
}
|
||||||
auth, err := hwsauth.NewAuthenticator(
|
auth, err := hwsauth.NewAuthenticator(
|
||||||
config,
|
cfg,
|
||||||
db.GetUserByID,
|
db.GetUserByID,
|
||||||
server,
|
server,
|
||||||
beginTx,
|
beginTx,
|
||||||
@@ -36,11 +36,12 @@ func setupAuth(
|
|||||||
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.IgnorePaths(ignoredPaths...)
|
err = auth.IgnorePaths(ignoredPaths...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "auth.IgnorePaths")
|
||||||
|
}
|
||||||
|
|
||||||
db.CurrentUser = auth.CurrentModel
|
db.CurrentUser = auth.CurrentModel
|
||||||
|
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: make a new getuser function that wraps db.GetUserByID and does OAuth refresh
|
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ func setupHTTPServer(
|
|||||||
"/ws/notifications",
|
"/ws/notifications",
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := setupAuth(cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
|
auth, err := setupAuth(
|
||||||
|
cfg.HWSAuth, logger, bun, httpServer, ignoredPaths)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "setupAuth")
|
return nil, errors.Wrap(err, "setupAuth")
|
||||||
}
|
}
|
||||||
@@ -74,7 +75,7 @@ func setupHTTPServer(
|
|||||||
return nil, errors.Wrap(err, "addRoutes")
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addMiddleware(httpServer, auth, cfg, perms)
|
err = addMiddleware(httpServer, auth, cfg, perms, discordAPI, store)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "addMiddleware")
|
return nil, errors.Wrap(err, "addMiddleware")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// Setup the logger
|
// Setup the logger
|
||||||
logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
|
logger, err := hlog.NewLogger(cfg.HLOG, os.Stdout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,13 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hws"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"git.haelnorr.com/h/golib/hwsauth"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"git.haelnorr.com/h/oslstats/internal/config"
|
"git.haelnorr.com/h/oslstats/internal/config"
|
||||||
"git.haelnorr.com/h/oslstats/internal/contexts"
|
"git.haelnorr.com/h/oslstats/internal/contexts"
|
||||||
"git.haelnorr.com/h/oslstats/internal/db"
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
"git.haelnorr.com/h/oslstats/internal/rbac"
|
"git.haelnorr.com/h/oslstats/internal/rbac"
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/store"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -21,9 +25,11 @@ func addMiddleware(
|
|||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
perms *rbac.Checker,
|
perms *rbac.Checker,
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
|
store *store.Store,
|
||||||
) error {
|
) error {
|
||||||
err := server.AddMiddleware(
|
err := server.AddMiddleware(
|
||||||
auth.Authenticate(),
|
auth.Authenticate(tokenRefresh(auth, discordAPI, store)),
|
||||||
perms.LoadPermissionsMiddleware(),
|
perms.LoadPermissionsMiddleware(),
|
||||||
devMode(cfg),
|
devMode(cfg),
|
||||||
)
|
)
|
||||||
@@ -51,3 +57,111 @@ func devMode(cfg *config.Config) hws.Middleware {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tokenRefresh(
|
||||||
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
|
store *store.Store,
|
||||||
|
) func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
||||||
|
return func(ctx context.Context, user *db.User, tx bun.Tx, w http.ResponseWriter, r *http.Request) (bool, *hws.HWSError) {
|
||||||
|
success, err := refreshToken(ctx, store, discordAPI, user, tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, &hws.HWSError{
|
||||||
|
Error: errors.Wrap(err, "refreshToken"),
|
||||||
|
Message: "Error refreshing discord token",
|
||||||
|
Level: hws.ErrorERROR,
|
||||||
|
RenderErrorPage: true,
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
err = auth.Logout(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
return false, &hws.HWSError{
|
||||||
|
Error: errors.Wrap(err, "auth.Logout"),
|
||||||
|
Message: "Logout failed",
|
||||||
|
Level: hws.ErrorERROR,
|
||||||
|
RenderErrorPage: true,
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func refreshToken(
|
||||||
|
ctx context.Context,
|
||||||
|
store *store.Store,
|
||||||
|
discordAPI *discord.APIClient,
|
||||||
|
user *db.User,
|
||||||
|
tx bun.Tx,
|
||||||
|
) (bool, error) {
|
||||||
|
token := store.CheckToken(user)
|
||||||
|
if token != nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
// Get the token
|
||||||
|
token, err := user.GetDiscordToken(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "user.GetDiscordToken")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenstatus, err := tokenStatus(token)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tokenStatus")
|
||||||
|
}
|
||||||
|
switch tokenstatus {
|
||||||
|
case "revoked":
|
||||||
|
return false, nil
|
||||||
|
case "expired", "expiring":
|
||||||
|
newtoken, err := discordAPI.RefreshToken(token.Convert())
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "discordAPI.RefreshToken")
|
||||||
|
}
|
||||||
|
err = user.UpdateDiscordToken(ctx, tx, newtoken)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "user.UpdateDiscordToken")
|
||||||
|
}
|
||||||
|
err = store.NewTokenCheck(user, token)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "store.NewTokenCheck")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
case "valid":
|
||||||
|
err = store.NewTokenCheck(user, token)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "store.NewTokenCheck")
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, errors.New("unexpected error occured validating discord token for user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenStatus(token *db.DiscordToken) (string, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
dayfromnow := now + int64(24*time.Hour/time.Second)
|
||||||
|
oauthtoken := token.Convert()
|
||||||
|
session, err := discord.NewOAuthSession(oauthtoken)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "discord.NewOAuthSession")
|
||||||
|
}
|
||||||
|
_, err = session.GetUser()
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "HTTP 401") {
|
||||||
|
// Error not related to token status
|
||||||
|
return "", errors.Wrap(err, "session.GetUser")
|
||||||
|
}
|
||||||
|
// Token not valid
|
||||||
|
if token.ExpiresAt < now {
|
||||||
|
return "expired", nil
|
||||||
|
}
|
||||||
|
return "revoked", nil
|
||||||
|
}
|
||||||
|
if token.ExpiresAt < dayfromnow {
|
||||||
|
return "expiring", nil
|
||||||
|
}
|
||||||
|
return "valid", nil
|
||||||
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -7,7 +7,7 @@ require (
|
|||||||
git.haelnorr.com/h/golib/ezconf v0.1.1
|
git.haelnorr.com/h/golib/ezconf v0.1.1
|
||||||
git.haelnorr.com/h/golib/hlog v0.10.4
|
git.haelnorr.com/h/golib/hlog v0.10.4
|
||||||
git.haelnorr.com/h/golib/hws v0.5.0
|
git.haelnorr.com/h/golib/hws v0.5.0
|
||||||
git.haelnorr.com/h/golib/hwsauth v0.5.5
|
git.haelnorr.com/h/golib/hwsauth v0.6.1
|
||||||
git.haelnorr.com/h/golib/notify v0.1.0
|
git.haelnorr.com/h/golib/notify v0.1.0
|
||||||
github.com/a-h/templ v0.3.977
|
github.com/a-h/templ v0.3.977
|
||||||
github.com/coder/websocket v1.8.14
|
github.com/coder/websocket v1.8.14
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -8,8 +8,10 @@ git.haelnorr.com/h/golib/hlog v0.10.4 h1:vpCsV/OddjIYx8F48U66WxojjmhEbeLGQAOBG4V
|
|||||||
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
git.haelnorr.com/h/golib/hlog v0.10.4/go.mod h1:+wJ8vecQY/JITTXKmI3JfkHiUGyMs7N6wooj2wuWZbc=
|
||||||
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
git.haelnorr.com/h/golib/hws v0.5.0 h1:0CSv2f+dm/KzB/o5o6uXCyvN74iBdMTImhkyAZzU52c=
|
||||||
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
git.haelnorr.com/h/golib/hws v0.5.0/go.mod h1:dxAbbGGNzqLXhZXwgt091QsvsPBdrS+1YsNQNldNVoM=
|
||||||
git.haelnorr.com/h/golib/hwsauth v0.5.5 h1:w1qssktq0zYo5cC/xa44h/ZE5G5r7rIsJ4QQWq2Jeoo=
|
git.haelnorr.com/h/golib/hwsauth v0.6.0 h1:qaRsgfcD7M6xCasfXwP6Ww9RM4TwDqYMFK2YtO6nt6c=
|
||||||
git.haelnorr.com/h/golib/hwsauth v0.5.5/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
git.haelnorr.com/h/golib/hwsauth v0.6.0/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.6.1 h1:3BiM6hwuYDjgfu02hshvUtr592DnWi9Epj//3N13ti0=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.6.1/go.mod h1:xPdxqHzr1ZU0MHlG4o8r1zEstBu4FJCdaA0ZHSFxmKA=
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
git.haelnorr.com/h/golib/jwt v0.10.1 h1:1Adxt9H3Y4fWFvFjWpvg/vSFhbgCMDMxgiE3m7KvDMI=
|
||||||
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
git.haelnorr.com/h/golib/jwt v0.10.1/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
git.haelnorr.com/h/golib/notify v0.1.0 h1:xdf6zd21F6n+SuGTeJiuLNMf6zFXMvwpKD0gmNq8N10=
|
||||||
|
|||||||
@@ -59,11 +59,9 @@ type AuditLogFilters struct {
|
|||||||
|
|
||||||
// GetAuditLogs retrieves audit logs with optional filters and pagination
|
// GetAuditLogs retrieves audit logs with optional filters and pagination
|
||||||
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
|
func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *AuditLogFilters) (*AuditLogs, error) {
|
||||||
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderDesc, "created_at")
|
|
||||||
query := tx.NewSelect().
|
query := tx.NewSelect().
|
||||||
Model((*AuditLog)(nil)).
|
Model((*AuditLog)(nil)).
|
||||||
Relation("User").
|
Relation("User")
|
||||||
OrderBy(pageOpts.OrderBy, pageOpts.Order)
|
|
||||||
|
|
||||||
// Apply filters if provided
|
// Apply filters if provided
|
||||||
if filters != nil {
|
if filters != nil {
|
||||||
@@ -88,11 +86,9 @@ func GetAuditLogs(ctx context.Context, tx bun.Tx, pageOpts *PageOpts, filters *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get paginated results
|
// Get paginated results
|
||||||
|
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderDesc, "created_at")
|
||||||
logs := new([]*AuditLog)
|
logs := new([]*AuditLog)
|
||||||
err = query.
|
err = query.Scan(ctx, &logs)
|
||||||
Offset(pageOpts.PerPage*(pageOpts.Page-1)).
|
|
||||||
Limit(pageOpts.PerPage).
|
|
||||||
Scan(ctx, &logs)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return nil, errors.Wrap(err, "query.Scan")
|
return nil, errors.Wrap(err, "query.Scan")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||||
@@ -57,6 +58,9 @@ func (u *User) DeleteDiscordTokens(ctx context.Context, tx bun.Tx) (*DiscordToke
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "user.GetDiscordToken")
|
return nil, errors.Wrap(err, "user.GetDiscordToken")
|
||||||
}
|
}
|
||||||
|
if token == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
_, err = tx.NewDelete().
|
_, err = tx.NewDelete().
|
||||||
Model((*DiscordToken)(nil)).
|
Model((*DiscordToken)(nil)).
|
||||||
Where("discord_id = ?", u.DiscordID).
|
Where("discord_id = ?", u.DiscordID).
|
||||||
@@ -76,6 +80,9 @@ func (u *User) GetDiscordToken(ctx context.Context, tx bun.Tx) (*DiscordToken, e
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import "github.com/uptrace/bun"
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
type PageOpts struct {
|
type PageOpts struct {
|
||||||
Page int
|
Page int
|
||||||
@@ -15,7 +19,7 @@ type OrderOpts struct {
|
|||||||
Label string
|
Label string
|
||||||
}
|
}
|
||||||
|
|
||||||
func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby string) *PageOpts {
|
func setPageOpts(q *bun.SelectQuery, p *PageOpts, page, perpage int, order bun.Order, orderby string) (*bun.SelectQuery, *PageOpts) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = new(PageOpts)
|
p = new(PageOpts)
|
||||||
}
|
}
|
||||||
@@ -31,7 +35,46 @@ func setDefaultPageOpts(p *PageOpts, page, perpage int, order bun.Order, orderby
|
|||||||
if p.OrderBy == "" {
|
if p.OrderBy == "" {
|
||||||
p.OrderBy = orderby
|
p.OrderBy = orderby
|
||||||
}
|
}
|
||||||
return p
|
p.OrderBy = sanitiseOrderBy(p.OrderBy)
|
||||||
|
q = q.OrderBy(p.OrderBy, p.Order).
|
||||||
|
Limit(p.PerPage).
|
||||||
|
Offset(p.PerPage * (p.Page - 1))
|
||||||
|
return q, p
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitiseOrderBy(orderby string) string {
|
||||||
|
result := strings.ToLower(orderby)
|
||||||
|
var builder strings.Builder
|
||||||
|
for _, r := range result {
|
||||||
|
if isValidChar(r) {
|
||||||
|
builder.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sanitized := builder.String()
|
||||||
|
|
||||||
|
if sanitized == "" {
|
||||||
|
return "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidFirstChar(rune(sanitized[0])) {
|
||||||
|
sanitized = "_" + sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sanitized) > 63 {
|
||||||
|
sanitized = sanitized[:63]
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidChar(r rune) bool {
|
||||||
|
return (r >= 'a' && r <= 'z') ||
|
||||||
|
(r >= '0' && r <= '9') ||
|
||||||
|
r == '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidFirstChar(r rune) bool {
|
||||||
|
return (r >= 'a' && r <= 'z') || r == '_'
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotalPages calculates the total number of pages
|
// TotalPages calculates the total number of pages
|
||||||
|
|||||||
@@ -37,7 +37,10 @@ func GetPermissionByName(ctx context.Context, tx bun.Tx, name permissions.Permis
|
|||||||
Where("name = ?", name).
|
Where("name = ?", name).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return perm, nil
|
return perm, nil
|
||||||
@@ -56,7 +59,10 @@ func GetPermissionByID(ctx context.Context, tx bun.Tx, id int) (*Permission, err
|
|||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return perm, nil
|
return perm, nil
|
||||||
@@ -115,9 +121,22 @@ func CreatePermission(ctx context.Context, tx bun.Tx, perm *Permission) error {
|
|||||||
if perm == nil {
|
if perm == nil {
|
||||||
return errors.New("permission cannot be nil")
|
return errors.New("permission cannot be nil")
|
||||||
}
|
}
|
||||||
|
if perm.Name == "" {
|
||||||
|
return errors.New("name cannot be empty")
|
||||||
|
}
|
||||||
|
if perm.DisplayName == "" {
|
||||||
|
return errors.New("display name cannot be empty")
|
||||||
|
}
|
||||||
|
if perm.Resource == "" {
|
||||||
|
return errors.New("resource cannot be empty")
|
||||||
|
}
|
||||||
|
if perm.Action == "" {
|
||||||
|
return errors.New("action cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
_, err := tx.NewInsert().
|
_, err := tx.NewInsert().
|
||||||
Model(perm).
|
Model(perm).
|
||||||
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.NewInsert")
|
return errors.Wrap(err, "tx.NewInsert")
|
||||||
|
|||||||
@@ -46,7 +46,10 @@ func GetRoleByName(ctx context.Context, tx bun.Tx, name roles.Role) (*Role, erro
|
|||||||
Where("name = ?", name).
|
Where("name = ?", name).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return role, nil
|
return role, nil
|
||||||
@@ -65,7 +68,10 @@ func GetRoleByID(ctx context.Context, tx bun.Tx, id int) (*Role, error) {
|
|||||||
Where("id = ?", id).
|
Where("id = ?", id).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return role, nil
|
return role, nil
|
||||||
@@ -84,7 +90,10 @@ func GetRoleWithPermissions(ctx context.Context, tx bun.Tx, id int) (*Role, erro
|
|||||||
Relation("Permissions").
|
Relation("Permissions").
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return role, nil
|
return role, nil
|
||||||
@@ -112,6 +121,7 @@ func CreateRole(ctx context.Context, tx bun.Tx, role *Role) error {
|
|||||||
|
|
||||||
_, err := tx.NewInsert().
|
_, err := tx.NewInsert().
|
||||||
Model(role).
|
Model(role).
|
||||||
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "tx.NewInsert")
|
return errors.Wrap(err, "tx.NewInsert")
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
|
|||||||
}
|
}
|
||||||
_, err := tx.NewInsert().
|
_, err := tx.NewInsert().
|
||||||
Model(season).
|
Model(season).
|
||||||
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.NewInsert")
|
return nil, errors.Wrap(err, "tx.NewInsert")
|
||||||
@@ -50,22 +51,18 @@ func NewSeason(ctx context.Context, tx bun.Tx, name, shortname string, start tim
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
|
func ListSeasons(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*SeasonList, error) {
|
||||||
pageOpts = setDefaultPageOpts(pageOpts, 1, 10, bun.OrderDesc, "start_date")
|
|
||||||
seasons := new([]*Season)
|
seasons := new([]*Season)
|
||||||
err := tx.NewSelect().
|
query := tx.NewSelect().
|
||||||
Model(seasons).
|
Model(seasons)
|
||||||
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
|
||||||
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
|
total, err := query.Count(ctx)
|
||||||
Limit(pageOpts.PerPage).
|
|
||||||
Scan(ctx)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
total, err := tx.NewSelect().
|
|
||||||
Model(seasons).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "query.Count")
|
||||||
|
}
|
||||||
|
query, pageOpts = setPageOpts(query, pageOpts, 1, 10, bun.OrderDesc, "start_date")
|
||||||
|
err = query.Scan(ctx)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return nil, errors.Wrap(err, "query.Scan")
|
||||||
}
|
}
|
||||||
sl := &SeasonList{
|
sl := &SeasonList{
|
||||||
Seasons: *seasons,
|
Seasons: *seasons,
|
||||||
@@ -82,7 +79,10 @@ func GetSeason(ctx context.Context, tx bun.Tx, shortname string) (*Season, error
|
|||||||
Where("short_name = ?", strings.ToUpper(shortname)).
|
Where("short_name = ?", strings.ToUpper(shortname)).
|
||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return season, nil
|
return season, nil
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hwsauth"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
@@ -50,6 +49,7 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
|
|||||||
|
|
||||||
_, err := tx.NewInsert().
|
_, err := tx.NewInsert().
|
||||||
Model(user).
|
Model(user).
|
||||||
|
Returning("id").
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.NewInsert")
|
return nil, errors.Wrap(err, "tx.NewInsert")
|
||||||
@@ -61,7 +61,6 @@ func CreateUser(ctx context.Context, tx bun.Tx, username string, discorduser *di
|
|||||||
// GetUserByID queries the database for a user matching the given ID
|
// GetUserByID queries the database for a user matching the given ID
|
||||||
// Returns nil, nil if no user is found
|
// Returns nil, nil if no user is found
|
||||||
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
||||||
fmt.Printf("user id requested: %v", id)
|
|
||||||
user := new(User)
|
user := new(User)
|
||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(user).
|
Model(user).
|
||||||
@@ -69,7 +68,7 @@ func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*User, error) {
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "sql: no rows in result set" {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -87,10 +86,10 @@ func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*User,
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "sql: no rows in result set" {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.Select")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@@ -105,7 +104,7 @@ func GetUserByDiscordID(ctx context.Context, tx bun.Tx, discordID string) (*User
|
|||||||
Limit(1).
|
Limit(1).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "sql: no rows in result set" {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "tx.NewSelect")
|
||||||
@@ -201,22 +200,17 @@ func (u *User) IsAdmin(ctx context.Context, tx bun.Tx) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
|
func GetUsers(ctx context.Context, tx bun.Tx, pageOpts *PageOpts) (*Users, error) {
|
||||||
pageOpts = setDefaultPageOpts(pageOpts, 1, 50, bun.OrderAsc, "id")
|
|
||||||
users := new([]*User)
|
users := new([]*User)
|
||||||
err := tx.NewSelect().
|
query := tx.NewSelect().
|
||||||
Model(users).
|
Model(users)
|
||||||
OrderBy(pageOpts.OrderBy, pageOpts.Order).
|
total, err := query.Count(ctx)
|
||||||
Limit(pageOpts.PerPage).
|
|
||||||
Offset(pageOpts.PerPage * (pageOpts.Page - 1)).
|
|
||||||
Scan(ctx)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
|
||||||
}
|
|
||||||
total, err := tx.NewSelect().
|
|
||||||
Model(users).
|
|
||||||
Count(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.NewSelect")
|
return nil, errors.Wrap(err, "query.Count")
|
||||||
|
}
|
||||||
|
query, pageOpts = setPageOpts(query, pageOpts, 1, 50, bun.OrderAsc, "id")
|
||||||
|
err = query.Scan(ctx)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
return nil, errors.Wrap(err, "query.Scan")
|
||||||
}
|
}
|
||||||
list := &Users{
|
list := &Users{
|
||||||
Users: *users,
|
Users: *users,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
"git.haelnorr.com/h/oslstats/internal/roles"
|
"git.haelnorr.com/h/oslstats/internal/roles"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -72,9 +73,12 @@ func HasRole(ctx context.Context, tx bun.Tx, userID int, roleName roles.Role) (b
|
|||||||
err := tx.NewSelect().
|
err := tx.NewSelect().
|
||||||
Model(user).
|
Model(user).
|
||||||
Relation("Roles").
|
Relation("Roles").
|
||||||
Where("u.id = ? ", userID).
|
Where("u.id = ?", userID).
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
return false, errors.Wrap(err, "tx.NewSelect")
|
return false, errors.Wrap(err, "tx.NewSelect")
|
||||||
}
|
}
|
||||||
for _, role := range user.Roles {
|
for _, role := range user.Roles {
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func Logout(
|
func Logout(
|
||||||
server *hws.Server,
|
s *hws.Server,
|
||||||
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
auth *hwsauth.Authenticator[*db.User, bun.Tx],
|
||||||
conn *bun.DB,
|
conn *bun.DB,
|
||||||
discordAPI *discord.APIClient,
|
discordAPI *discord.APIClient,
|
||||||
@@ -26,10 +26,10 @@ func Logout(
|
|||||||
|
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "conn.BeginTx"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
user := db.CurrentUser(r.Context())
|
user := db.CurrentUser(r.Context())
|
||||||
if user == nil {
|
if user == nil {
|
||||||
@@ -39,20 +39,26 @@ func Logout(
|
|||||||
}
|
}
|
||||||
token, err := user.DeleteDiscordTokens(ctx, tx)
|
token, err := user.DeleteDiscordTokens(ctx, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
throwInternalServiceError(s, w, r, "Database error", errors.Wrap(err, "user.DeleteDiscordTokens"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = discordAPI.RevokeToken(token.Convert())
|
if token != nil {
|
||||||
if err != nil {
|
err = discordAPI.RevokeToken(token.Convert())
|
||||||
throwInternalServiceError(server, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
if err != nil {
|
||||||
return
|
throwInternalServiceError(s, w, r, "Discord API error", errors.Wrap(err, "discordAPI.RevokeToken"))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = auth.Logout(tx, w, r)
|
err = auth.Logout(tx, w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
throwInternalServiceError(server, w, r, "Logout failed", err)
|
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "auth.Logout"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
throwInternalServiceError(s, w, r, "Logout failed", errors.Wrap(err, "tx.Commit"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
|
||||||
w.Header().Set("HX-Redirect", "/")
|
w.Header().Set("HX-Redirect", "/")
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
// Package store provides a session store for caching data
|
||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -22,6 +23,7 @@ type Store struct {
|
|||||||
sessions sync.Map // key: string, value: *RegistrationSession
|
sessions sync.Map // key: string, value: *RegistrationSession
|
||||||
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
||||||
cleanup *time.Ticker
|
cleanup *time.Ticker
|
||||||
|
tokenchecks sync.Map // key: int, value: *TokenCheck
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStore() *Store {
|
func NewStore() *Store {
|
||||||
@@ -42,6 +44,7 @@ func NewStore() *Store {
|
|||||||
func (s *Store) Delete(id string) {
|
func (s *Store) Delete(id string) {
|
||||||
s.sessions.Delete(id)
|
s.sessions.Delete(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) cleanupExpired() {
|
func (s *Store) cleanupExpired() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
@@ -62,10 +65,19 @@ func (s *Store) cleanupExpired() {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.tokenchecks.Range(func(key, value any) bool {
|
||||||
|
check := value.(*TokenCheck)
|
||||||
|
if now.After(check.ExpiresAt) {
|
||||||
|
s.tokenchecks.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateID() string {
|
func generateID() string {
|
||||||
b := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
rand.Read(b)
|
_, _ = rand.Read(b)
|
||||||
return base64.RawURLEncoding.EncodeToString(b)
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
39
internal/store/tokencheck.go
Normal file
39
internal/store/tokencheck.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/oslstats/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenCheck struct {
|
||||||
|
Token *db.DiscordToken
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) NewTokenCheck(user *db.User, token *db.DiscordToken) error {
|
||||||
|
if user == nil {
|
||||||
|
return errors.New("user cannot be nil")
|
||||||
|
}
|
||||||
|
if token == nil {
|
||||||
|
return errors.New("token cannot be nil")
|
||||||
|
}
|
||||||
|
s.tokenchecks.Store(user.ID, &TokenCheck{
|
||||||
|
Token: token,
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CheckToken(user *db.User) *db.DiscordToken {
|
||||||
|
res, ok := s.tokenchecks.Load(user.ID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
check := (res).(*TokenCheck)
|
||||||
|
check.ExpiresAt = time.Now().Add(5 * time.Minute)
|
||||||
|
s.tokenchecks.Delete(user.ID)
|
||||||
|
s.tokenchecks.Store(user.ID, check)
|
||||||
|
return check.Token
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user