Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1eedbc5220 | |||
| 6e03c98ae8 | |||
| a0cd269466 | |||
| 28b7ba34f0 | |||
| 4a21ba3821 | |||
| 1bcdf0e813 | |||
| 6dd80ee7b6 | |||
| 8f6b4b0026 | |||
| 03095448d6 | |||
| 1d9af44d0a | |||
| 5c1089e0ce | |||
| e3d2eb1af8 | |||
| 9e12f946b3 | |||
| 141b541e98 | |||
| 540782e2d5 | |||
| 1d5c662bf0 | |||
| b6e0a977c0 | |||
| 3db77eca71 | |||
| 8fcec675e6 | |||
| aa47802f46 | |||
| 05849d028d | |||
| f7f610d7ef | |||
| e2d66fc26d | |||
| 8fa20e05c0 | |||
| a3e9ffb012 | |||
| 838d6264c9 | |||
| e794024786 | |||
| 725038009a | |||
| d8d2307859 |
@@ -5,7 +5,7 @@ tmp_dir = "tmp"
|
|||||||
[build]
|
[build]
|
||||||
args_bin = []
|
args_bin = []
|
||||||
bin = "./tmp/main"
|
bin = "./tmp/main"
|
||||||
cmd = "go build -o ./tmp/main ."
|
cmd = "go build -o ./tmp/main ./cmd/projectreshoot"
|
||||||
delay = 1000
|
delay = 1000
|
||||||
exclude_dir = []
|
exclude_dir = []
|
||||||
exclude_file = []
|
exclude_file = []
|
||||||
|
|||||||
4
.github/workflows/deploy_production.yaml
vendored
4
.github/workflows/deploy_production.yaml
vendored
@@ -53,10 +53,10 @@ jobs:
|
|||||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||||
scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 prmigrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 .bin/migrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||||
|
|||||||
4
.github/workflows/deploy_staging.yaml
vendored
4
.github/workflows/deploy_staging.yaml
vendored
@@ -53,10 +53,10 @@ jobs:
|
|||||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||||
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 prmigrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/migrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,11 +1,9 @@
|
|||||||
.env
|
.env
|
||||||
query.sql
|
*.db*
|
||||||
*.db
|
|
||||||
.logs/
|
.logs/
|
||||||
server.log
|
server.log
|
||||||
|
bin/
|
||||||
tmp/
|
tmp/
|
||||||
prmigrate
|
|
||||||
projectreshoot
|
|
||||||
static/css/output.css
|
static/css/output.css
|
||||||
view/**/*_templ.go
|
internal/view/**/*_templ.go
|
||||||
view/**/*_templ.txt
|
internal/view/**/*_templ.txt
|
||||||
|
|||||||
28
Makefile
28
Makefile
@@ -5,33 +5,25 @@
|
|||||||
BINARY_NAME=projectreshoot
|
BINARY_NAME=projectreshoot
|
||||||
|
|
||||||
build:
|
build:
|
||||||
tailwindcss -i ./static/css/input.css -o ./static/css/output.css && \
|
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
|
||||||
go mod tidy && \
|
go mod tidy && \
|
||||||
templ generate && \
|
templ generate && \
|
||||||
go generate && \
|
go generate ./cmd/${BINARY_NAME} && \
|
||||||
go build -ldflags="-w -s" -o ${BINARY_NAME}${SUFFIX}
|
go build -ldflags="-w -s" -o ./bin/${BINARY_NAME}${SUFFIX} ./cmd/${BINARY_NAME}
|
||||||
|
|
||||||
|
run:
|
||||||
|
make build
|
||||||
|
./bin/${BINARY_NAME}${SUFFIX}
|
||||||
|
|
||||||
dev:
|
dev:
|
||||||
templ generate --watch &\
|
templ generate --watch &\
|
||||||
air &\
|
air &\
|
||||||
tailwindcss -i ./static/css/input.css -o ./static/css/output.css --watch
|
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
|
||||||
|
|
||||||
tester:
|
|
||||||
go mod tidy && \
|
|
||||||
go run . --port 3232 --tester --loglevel trace
|
|
||||||
|
|
||||||
test:
|
|
||||||
go mod tidy && \
|
|
||||||
templ generate && \
|
|
||||||
go generate && \
|
|
||||||
go test .
|
|
||||||
go test ./db
|
|
||||||
go test ./middleware
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
go clean
|
go clean
|
||||||
|
|
||||||
migrate:
|
migrate:
|
||||||
go mod tidy && \
|
go mod tidy && \
|
||||||
go generate && \
|
go generate ./cmd/migrate && \
|
||||||
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate
|
go build -ldflags="-w -s" -o ./bin/migrate${SUFFIX} ./cmd/migrate
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ var migrationsFS embed.FS
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) != 4 {
|
if len(os.Args) != 4 {
|
||||||
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
|
fmt.Println("Usage: migrate <file_path> up-to|down-to <version>")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
45
cmd/projectreshoot/auth.go
Normal file
45
cmd/projectreshoot/auth.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/handler"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupAuth(
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
db *bun.DB,
|
||||||
|
server *hws.Server,
|
||||||
|
ignoredPaths []string,
|
||||||
|
) (*hwsauth.Authenticator[*models.UserBun, bun.Tx], error) {
|
||||||
|
beginTx := func(ctx context.Context) (hwsauth.DBTransaction, error) {
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
return tx, err
|
||||||
|
}
|
||||||
|
auth, err := hwsauth.NewAuthenticator(
|
||||||
|
config.HWSAuth,
|
||||||
|
models.GetUserByID,
|
||||||
|
server,
|
||||||
|
beginTx,
|
||||||
|
logger,
|
||||||
|
handler.ErrorPage,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.IgnorePaths(ignoredPaths...)
|
||||||
|
|
||||||
|
contexts.CurrentUser = auth.CurrentModel
|
||||||
|
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
52
cmd/projectreshoot/db.go
Normal file
52
cmd/projectreshoot/db.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"github.com/uptrace/bun/dialect/pgdialect"
|
||||||
|
"github.com/uptrace/bun/driver/pgdriver"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupBun(ctx context.Context, cfg *config.DBConfig, resetDB bool) (db *bun.DB, close func() error, err error) {
|
||||||
|
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
|
||||||
|
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
||||||
|
db = bun.NewDB(sqldb, pgdialect.New())
|
||||||
|
close = sqldb.Close
|
||||||
|
|
||||||
|
err = loadModels(ctx, db, resetDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "loadModels")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, close, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModels(ctx context.Context, db *bun.DB, resetDB bool) error {
|
||||||
|
models := []any{
|
||||||
|
(*models.UserBun)(nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
_, err := db.NewCreateTable().
|
||||||
|
Model(model).
|
||||||
|
IfNotExists().
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.NewCreateTable")
|
||||||
|
}
|
||||||
|
if resetDB {
|
||||||
|
err = db.ResetModel(ctx, model)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "db.ResetModel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
24
cmd/projectreshoot/flags.go
Normal file
24
cmd/projectreshoot/flags.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFlags() map[string]string {
|
||||||
|
// Parse commandline args
|
||||||
|
resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models")
|
||||||
|
printEnv := flag.Bool("printenv", false, "Print all environment variables and their documentation")
|
||||||
|
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
|
||||||
|
envfile := flag.String("envfile", ".env", "Specify a .env file to use for the configuration")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Map the args for easy access
|
||||||
|
args := map[string]string{
|
||||||
|
"resetdb": strconv.FormatBool(*resetDB),
|
||||||
|
"printenv": strconv.FormatBool(*printEnv),
|
||||||
|
"genenv": *genEnv,
|
||||||
|
"envfile": *envfile,
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
67
cmd/projectreshoot/httpserver.go
Normal file
67
cmd/projectreshoot/httpserver.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/handler"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupHttpServer(
|
||||||
|
staticFS *fs.FS,
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
bun *bun.DB,
|
||||||
|
) (server *hws.Server, err error) {
|
||||||
|
if staticFS == nil {
|
||||||
|
return nil, errors.New("No filesystem provided")
|
||||||
|
}
|
||||||
|
fs := http.FS(*staticFS)
|
||||||
|
httpServer, err := hws.NewServer(config.HWS)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hws.NewServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
ignoredPaths := []string{
|
||||||
|
"/static/css/output.css",
|
||||||
|
"/static/favicon.ico",
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, err := setupAuth(config, logger, bun, httpServer, ignoredPaths)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "setupAuth")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = httpServer.AddErrorPage(handler.ErrorPage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.AddErrorPage")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = httpServer.AddLogger(logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.AddLogger")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = httpServer.LoggerIgnorePaths(ignoredPaths...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addMiddleware(httpServer, auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.AddMiddleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpServer, nil
|
||||||
|
}
|
||||||
47
cmd/projectreshoot/logger.go
Normal file
47
cmd/projectreshoot/logger.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Take in the desired logOutput and a console writer to use
|
||||||
|
func setupLogger(cfg *config.HLOGConfig, w *io.Writer) (*hlog.Logger, error) {
|
||||||
|
// Setup the logfile
|
||||||
|
var logfile *os.File = nil
|
||||||
|
if cfg.LogOutput == "both" || cfg.LogOutput == "file" {
|
||||||
|
logfile, err := hlog.NewLogFile(cfg.LogDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
defer logfile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the console writer
|
||||||
|
var consoleWriter io.Writer
|
||||||
|
if cfg.LogOutput == "both" || cfg.LogOutput == "console" {
|
||||||
|
if w != nil {
|
||||||
|
consoleWriter = *w
|
||||||
|
} else {
|
||||||
|
if cfg.LogOutput == "console" {
|
||||||
|
return nil, errors.New("Console logging specified as sole method but no writer provided")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the logger
|
||||||
|
logger, err := hlog.NewLogger(
|
||||||
|
cfg.LogLevel,
|
||||||
|
consoleWriter,
|
||||||
|
logfile,
|
||||||
|
cfg.LogDir,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
return logger, nil
|
||||||
|
}
|
||||||
45
cmd/projectreshoot/main.go
Normal file
45
cmd/projectreshoot/main.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
args := setupFlags()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Handle printenv flag
|
||||||
|
if args["printenv"] == "true" {
|
||||||
|
if err := config.PrintEnvVars(os.Stdout); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to print environment variables: %s\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle genenv flag
|
||||||
|
if args["genenv"] != "" {
|
||||||
|
if err := config.GenerateDotEnv(args["genenv"]); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to generate .env file: %s\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("Successfully generated .env file: %s\n", args["genenv"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := config.GetConfig(args["envfile"])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := run(ctx, os.Stdout, args, cfg); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
25
cmd/projectreshoot/middleware.go
Normal file
25
cmd/projectreshoot/middleware.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addMiddleware(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
) error {
|
||||||
|
|
||||||
|
err := server.AddMiddleware(
|
||||||
|
auth.Authenticate(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "server.AddMiddleware")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
126
cmd/projectreshoot/routes.go
Normal file
126
cmd/projectreshoot/routes.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/handler"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addRoutes(
|
||||||
|
server *hws.Server,
|
||||||
|
staticFS *http.FileSystem,
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
db *bun.DB,
|
||||||
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
) error {
|
||||||
|
// Create the routes
|
||||||
|
routes := []hws.Route{
|
||||||
|
{
|
||||||
|
Path: "/static/",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: http.StripPrefix("/static/", handler.StaticFS(staticFS, logger)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.Root(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/about",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.HandlePage(page.About()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/login",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handler.LoginPage(config.HWSAuth.TrustedHost)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/login",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, db)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/register",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handler.RegisterPage(config.HWSAuth.TrustedHost)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/register",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LogoutReq(handler.RegisterRequest(server, auth, db)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/logout",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: handler.Logout(server, auth, db),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/reauthenticate",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, db)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/profile",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LoginReq(handler.ProfilePage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/account",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LoginReq(handler.AccountPage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/account-select-page",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.AccountSubpage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-username",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, db))),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-password",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, db))),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-bio",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.ChangeBio(server, auth, db)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/movies",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.MoviesPage(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/search-movies",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: handler.SearchMovies(config.TMDB, logger),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/movie/{movie_id}",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.Movie(server, config.TMDB),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the routes with the server
|
||||||
|
err := server.AddRoutes(routes...)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "server.AddRoutes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
75
cmd/projectreshoot/run.go
Normal file
75
cmd/projectreshoot/run.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/pkg/embedfs"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initializes and runs the server
|
||||||
|
func run(ctx context.Context, w io.Writer, args map[string]string, config *config.Config) error {
|
||||||
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
logger, err := setupLogger(config.HLOG, &w)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupLogger")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the database connection
|
||||||
|
logger.Debug().Msg("Config loaded and logger started")
|
||||||
|
logger.Debug().Msg("Connecting to database")
|
||||||
|
resetdb, err := strconv.ParseBool(args["resetdb"])
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "strconv.ParseBool")
|
||||||
|
}
|
||||||
|
bun, closedb, err := setupBun(ctx, config.DB, resetdb)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupDBConn")
|
||||||
|
}
|
||||||
|
defer closedb()
|
||||||
|
|
||||||
|
// Setup embedded files
|
||||||
|
logger.Debug().Msg("Getting embedded files")
|
||||||
|
staticFS, err := embedfs.GetEmbeddedFS()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "getStaticFiles")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug().Msg("Setting up HTTP server")
|
||||||
|
httpServer, err := setupHttpServer(&staticFS, config, logger, bun)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupHttpServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs the http server
|
||||||
|
logger.Debug().Msg("Starting up the HTTP server")
|
||||||
|
err = httpServer.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "httpServer.Start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles graceful shutdown
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Go(func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
shutdownCtx := context.Background()
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err := httpServer.Shutdown(shutdownCtx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error().Err(err).Msg("Graceful shutdown failed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
logger.Info().Msg("Shutting down")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
106
config/config.go
106
config/config.go
@@ -1,106 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/logging"
|
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Host string // Host to listen on
|
|
||||||
Port string // Port to listen on
|
|
||||||
TrustedHost string // Domain/Hostname to accept as trusted
|
|
||||||
SSL bool // Flag for SSL Mode
|
|
||||||
GZIP bool // Flag for GZIP compression on requests
|
|
||||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
|
||||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
|
||||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
|
||||||
DBName string // Filename of the db - hardcoded and doubles as DB version
|
|
||||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
|
||||||
SecretKey string // Secret key for signing tokens
|
|
||||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
|
||||||
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
|
||||||
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
|
|
||||||
LogLevel zerolog.Level // Log level for global logging. Defaults to info
|
|
||||||
LogOutput string // "file", "console", or "both". Defaults to console
|
|
||||||
LogDir string // Path to create log files
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the application configuration and get a pointer to the Config object
|
|
||||||
func GetConfig(args map[string]string) (*Config, error) {
|
|
||||||
godotenv.Load(".env")
|
|
||||||
var (
|
|
||||||
host string
|
|
||||||
port string
|
|
||||||
logLevel zerolog.Level
|
|
||||||
logOutput string
|
|
||||||
valid bool
|
|
||||||
)
|
|
||||||
|
|
||||||
if args["host"] != "" {
|
|
||||||
host = args["host"]
|
|
||||||
} else {
|
|
||||||
host = GetEnvDefault("HOST", "127.0.0.1")
|
|
||||||
}
|
|
||||||
if args["port"] != "" {
|
|
||||||
port = args["port"]
|
|
||||||
} else {
|
|
||||||
port = GetEnvDefault("PORT", "3010")
|
|
||||||
}
|
|
||||||
if args["loglevel"] != "" {
|
|
||||||
logLevel = logging.GetLogLevel(args["loglevel"])
|
|
||||||
} else {
|
|
||||||
logLevel = logging.GetLogLevel(GetEnvDefault("LOG_LEVEL", "info"))
|
|
||||||
}
|
|
||||||
if args["logoutput"] != "" {
|
|
||||||
opts := map[string]string{
|
|
||||||
"both": "both",
|
|
||||||
"file": "file",
|
|
||||||
"console": "console",
|
|
||||||
}
|
|
||||||
logOutput, valid = opts[args["logoutput"]]
|
|
||||||
if !valid {
|
|
||||||
logOutput = "console"
|
|
||||||
fmt.Println(
|
|
||||||
"Log output type was not parsed correctly. Defaulting to console only",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logOutput = GetEnvDefault("LOG_OUTPUT", "console")
|
|
||||||
}
|
|
||||||
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
|
||||||
logOutput = "console"
|
|
||||||
}
|
|
||||||
|
|
||||||
config := &Config{
|
|
||||||
Host: host,
|
|
||||||
Port: port,
|
|
||||||
TrustedHost: GetEnvDefault("TRUSTED_HOST", "127.0.0.1"),
|
|
||||||
SSL: GetEnvBool("SSL_MODE", false),
|
|
||||||
GZIP: GetEnvBool("GZIP", false),
|
|
||||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
|
||||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
|
||||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
|
||||||
DBName: "00001",
|
|
||||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
|
||||||
SecretKey: os.Getenv("SECRET_KEY"),
|
|
||||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
|
||||||
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
|
||||||
TokenFreshTime: GetEnvInt64("TOKEN_FRESH_TIME", 5),
|
|
||||||
LogLevel: logLevel,
|
|
||||||
LogOutput: logOutput,
|
|
||||||
LogDir: GetEnvDefault("LOG_DIR", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.SecretKey == "" && args["dbver"] != "true" {
|
|
||||||
return nil, errors.New("Envar not set: SECRET_KEY")
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get an environment variable, specifying a default value if its not set
|
|
||||||
func GetEnvDefault(key string, defaultValue string) string {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as a time.Duration, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly
|
|
||||||
func GetEnvDur(key string, defaultValue time.Duration) time.Duration {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return time.Duration(defaultValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.Atoi(val)
|
|
||||||
if err != nil {
|
|
||||||
return time.Duration(defaultValue)
|
|
||||||
}
|
|
||||||
return time.Duration(intVal)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as an int, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into an int
|
|
||||||
func GetEnvInt(key string, defaultValue int) int {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.Atoi(val)
|
|
||||||
if err != nil {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as an int64, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into an int64
|
|
||||||
func GetEnvInt64(key string, defaultValue int64) int64 {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.ParseInt(val, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as a boolean, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into a bool
|
|
||||||
func GetEnvBool(key string, defaultValue bool) bool {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
truthy := map[string]bool{
|
|
||||||
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
|
|
||||||
"enable": true, "enabled": true, "active": true, "affirmative": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
falsy := map[string]bool{
|
|
||||||
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
|
|
||||||
"disable": false, "disabled": false, "inactive": false, "negative": false,
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized := strings.TrimSpace(strings.ToLower(val))
|
|
||||||
|
|
||||||
if val, ok := truthy[normalized]; ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
if val, ok := falsy[normalized]; ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package contexts
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Set the start time of the request
|
|
||||||
func SetStart(ctx context.Context, time time.Time) context.Context {
|
|
||||||
return context.WithValue(ctx, contextKeyRequestTime, time)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the start time of the request
|
|
||||||
func GetStartTime(ctx context.Context) (time.Time, error) {
|
|
||||||
start, ok := ctx.Value(contextKeyRequestTime).(time.Time)
|
|
||||||
if !ok {
|
|
||||||
return time.Time{}, errors.New("Failed to get start time of request")
|
|
||||||
}
|
|
||||||
return start, nil
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package contexts
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthenticatedUser struct {
|
|
||||||
*db.User
|
|
||||||
Fresh int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return a new context with the user added in
|
|
||||||
func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context {
|
|
||||||
return context.WithValue(ctx, contextKeyAuthorizedUser, u)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve a user from the given context. Returns nil if not set
|
|
||||||
func GetUser(ctx context.Context) *AuthenticatedUser {
|
|
||||||
user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tell the browser to delete the cookie matching the name provided
|
|
||||||
// Path must match the original set cookie for it to delete
|
|
||||||
func DeleteCookie(w http.ResponseWriter, name string, path string) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: "",
|
|
||||||
Path: path,
|
|
||||||
Expires: time.Unix(0, 0), // Expire in the past
|
|
||||||
MaxAge: -1, // Immediately expire
|
|
||||||
HttpOnly: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a cookie with the given name, path and value. maxAge directly relates
|
|
||||||
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
|
|
||||||
func SetCookie(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
name string,
|
|
||||||
path string,
|
|
||||||
value string,
|
|
||||||
maxAge int,
|
|
||||||
) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: path,
|
|
||||||
HttpOnly: true,
|
|
||||||
MaxAge: maxAge,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
|
|
||||||
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
|
|
||||||
pageFromCookie, err := r.Cookie("pagefrom")
|
|
||||||
if err != nil {
|
|
||||||
return "/"
|
|
||||||
}
|
|
||||||
pageFrom := pageFromCookie.Value
|
|
||||||
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
|
|
||||||
return pageFrom
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the referer of the request, and if it matches the trustedHost, set
|
|
||||||
// the "pagefrom" cookie as the Path of the referer
|
|
||||||
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
|
|
||||||
referer := r.Referer()
|
|
||||||
parsedURL, err := url.Parse(referer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var pageFrom string
|
|
||||||
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
|
|
||||||
pageFrom = "/"
|
|
||||||
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
pageFrom = parsedURL.Path
|
|
||||||
}
|
|
||||||
SetCookie(w, "pagefrom", "/", pageFrom, 0)
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the value of the access and refresh tokens
|
|
||||||
func GetTokenStrings(
|
|
||||||
r *http.Request,
|
|
||||||
) (acc string, ref string) {
|
|
||||||
accCookie, accErr := r.Cookie("access")
|
|
||||||
refCookie, refErr := r.Cookie("refresh")
|
|
||||||
var (
|
|
||||||
accStr string = ""
|
|
||||||
refStr string = ""
|
|
||||||
)
|
|
||||||
if accErr == nil {
|
|
||||||
accStr = accCookie.Value
|
|
||||||
}
|
|
||||||
if refErr == nil {
|
|
||||||
refStr = refCookie.Value
|
|
||||||
}
|
|
||||||
return accStr, refStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a token with the provided details
|
|
||||||
func setToken(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
config *config.Config,
|
|
||||||
token string,
|
|
||||||
scope string,
|
|
||||||
exp int64,
|
|
||||||
rememberme bool,
|
|
||||||
) {
|
|
||||||
tokenCookie := &http.Cookie{
|
|
||||||
Name: scope,
|
|
||||||
Value: token,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
Secure: config.SSL,
|
|
||||||
}
|
|
||||||
if rememberme {
|
|
||||||
tokenCookie.Expires = time.Unix(exp, 0)
|
|
||||||
}
|
|
||||||
http.SetCookie(w, tokenCookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate new tokens for the user and set them as cookies
|
|
||||||
func SetTokenCookies(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
fresh bool,
|
|
||||||
rememberMe bool,
|
|
||||||
) error {
|
|
||||||
at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
|
||||||
}
|
|
||||||
rt, rtexp, err := jwt.GenerateRefreshToken(config, user, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.GenerateRefreshToken")
|
|
||||||
}
|
|
||||||
// Don't set the cookies until we know no errors occured
|
|
||||||
setToken(w, config, at, "access", atexp, rememberMe)
|
|
||||||
setToken(w, config, rt, "refresh", rtexp, rememberMe)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Returns a database connection handle for the DB
|
|
||||||
func ConnectToDatabase(
|
|
||||||
dbName string,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
) (*SafeConn, error) {
|
|
||||||
file := fmt.Sprintf("file:%s.db", dbName)
|
|
||||||
db, err := sql.Open("sqlite", file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "sql.Open")
|
|
||||||
}
|
|
||||||
version, err := strconv.Atoi(dbName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "strconv.Atoi")
|
|
||||||
}
|
|
||||||
err = checkDBVersion(db, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkDBVersion")
|
|
||||||
}
|
|
||||||
conn := MakeSafe(db, logger)
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the database version
|
|
||||||
func checkDBVersion(db *sql.DB, expectVer int) error {
|
|
||||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
|
||||||
ORDER BY version_id DESC LIMIT 1`
|
|
||||||
rows, err := db.Query(query)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "checkDBVersion")
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if rows.Next() {
|
|
||||||
var version int
|
|
||||||
err = rows.Scan(&version)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "rows.Scan")
|
|
||||||
}
|
|
||||||
if version != expectVer {
|
|
||||||
return errors.New("Version mismatch")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return errors.New("No version found")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
129
db/safeconn.go
129
db/safeconn.go
@@ -1,129 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SafeConn struct {
|
|
||||||
db *sql.DB
|
|
||||||
readLockCount uint32
|
|
||||||
globalLockStatus uint32
|
|
||||||
globalLockRequested uint32
|
|
||||||
logger *zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the provided db handle safe and attach a logger to it
|
|
||||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
|
||||||
return &SafeConn{db: db, logger: logger}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempts to acquire a global lock on the database connection
|
|
||||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
|
||||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.globalLockStatus = 1
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Releases a global lock on the database connection
|
|
||||||
func (conn *SafeConn) releaseGlobalLock() {
|
|
||||||
conn.globalLockStatus = 0
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
|
||||||
// at the same time
|
|
||||||
func (conn *SafeConn) acquireReadLock() bool {
|
|
||||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.readLockCount += 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release a read lock. Decrements read lock count by 1
|
|
||||||
func (conn *SafeConn) releaseReadLock() {
|
|
||||||
conn.readLockCount -= 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Starts a new transaction based on the current context. Will cancel if
|
|
||||||
// the context is closed/cancelled/done
|
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
|
||||||
lockAcquired := make(chan struct{})
|
|
||||||
lockCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-lockCtx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if conn.acquireReadLock() {
|
|
||||||
close(lockAcquired)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-lockAcquired:
|
|
||||||
tx, err := conn.db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
conn.releaseReadLock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &SafeTX{tx: tx, sc: conn}, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
cancel()
|
|
||||||
return nil, errors.New("Transaction time out due to database lock")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a global lock, preventing all transactions
|
|
||||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
|
||||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
|
||||||
conn.globalLockRequested = 1
|
|
||||||
defer func() { conn.globalLockRequested = 0 }()
|
|
||||||
timeout := time.After(timeoutAfter)
|
|
||||||
attempt := 0
|
|
||||||
for {
|
|
||||||
if conn.acquireGlobalLock() {
|
|
||||||
conn.logger.Info().Msg("Global database lock acquired")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
|
||||||
return
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
attempt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release the global lock
|
|
||||||
func (conn *SafeConn) Resume() {
|
|
||||||
conn.releaseGlobalLock()
|
|
||||||
conn.logger.Info().Msg("Global database lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the database connection
|
|
||||||
func (conn *SafeConn) Close() error {
|
|
||||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
|
||||||
conn.acquireGlobalLock()
|
|
||||||
defer conn.releaseGlobalLock()
|
|
||||||
conn.logger.Debug().Msg("Closing database connection")
|
|
||||||
return conn.db.Close()
|
|
||||||
}
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSafeConn(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
sconn.Resume()
|
|
||||||
})
|
|
||||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn.Pause(250 * time.Millisecond)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
sconn.Resume()
|
|
||||||
tx, err = sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func TestSafeTX(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Commit releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Commit()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Rollback()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
|
|
||||||
tx1, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx2, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx3, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx1.Commit()
|
|
||||||
tx2.Commit()
|
|
||||||
tx3.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
defer sconn.releaseGlobalLock()
|
|
||||||
_, err := sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
})
|
|
||||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
tx, err := sconn.Begin(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
sconn.releaseGlobalLock()
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
61
db/safetx.go
61
db/safetx.go
@@ -1,61 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Extends sql.Tx for use with SafeConn
|
|
||||||
type SafeTX struct {
|
|
||||||
tx *sql.Tx
|
|
||||||
sc *SafeConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the database inside the transaction
|
|
||||||
func (stx *SafeTX) Query(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (*sql.Rows, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot query without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec a statement on the database inside the transaction
|
|
||||||
func (stx *SafeTX) Exec(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (sql.Result, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot exec without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit the current transaction and release the read lock
|
|
||||||
func (stx *SafeTX) Commit() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot commit without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Commit()
|
|
||||||
stx.tx = nil
|
|
||||||
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort the current transaction, releasing the read lock
|
|
||||||
func (stx *SafeTX) Rollback() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot rollback without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Rollback()
|
|
||||||
stx.tx = nil
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
60
db/user.go
60
db/user.go
@@ -1,60 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int // Integer ID (index primary key)
|
|
||||||
Username string // Username (unique)
|
|
||||||
Password_hash string // Bcrypt password hash
|
|
||||||
Created_at int64 // Epoch timestamp when the user was added to the database
|
|
||||||
Bio string // Short byline set by the user
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses bcrypt to set the users Password_hash from the given password
|
|
||||||
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
|
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
|
||||||
}
|
|
||||||
user.Password_hash = string(hashedPassword)
|
|
||||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
|
||||||
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses bcrypt to check if the given password matches the users Password_hash
|
|
||||||
func (user *User) CheckPassword(password string) error {
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password))
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the user's username
|
|
||||||
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
|
|
||||||
query := `UPDATE users SET username = ? WHERE id = ?`
|
|
||||||
_, err := tx.Exec(ctx, query, newUsername, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the user's bio
|
|
||||||
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
|
|
||||||
query := `UPDATE users SET bio = ? WHERE id = ?`
|
|
||||||
_, err := tx.Exec(ctx, query, newBio, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -64,7 +64,7 @@ failed_cleanup() {
|
|||||||
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
||||||
|
|
||||||
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
||||||
${MIGRATION_BIN}/prmigrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
|
${MIGRATION_BIN}/migrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
echo "Migration failed"
|
echo "Migration failed"
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
43
go.mod
43
go.mod
@@ -1,36 +1,53 @@
|
|||||||
module projectreshoot
|
module projectreshoot
|
||||||
|
|
||||||
go 1.24.0
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/a-h/templ v0.3.833
|
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
git.haelnorr.com/h/golib/env v0.9.1
|
||||||
github.com/google/uuid v1.6.0
|
git.haelnorr.com/h/golib/hlog v0.9.1
|
||||||
|
git.haelnorr.com/h/golib/hws v0.2.0
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.3.1
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0
|
||||||
|
github.com/a-h/templ v0.3.977
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/pressly/goose/v3 v3.24.1
|
github.com/pressly/goose/v3 v3.24.1
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/uptrace/bun v1.2.16
|
||||||
github.com/stretchr/testify v1.10.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/crypto v0.33.0
|
|
||||||
modernc.org/sqlite v1.35.0
|
modernc.org/sqlite v1.35.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
git.haelnorr.com/h/golib/jwt v0.10.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/rs/zerolog v1.34.0 // indirect
|
||||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||||
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
||||||
|
github.com/uptrace/bun/driver/pgdriver v1.2.16
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
|
go.opentelemetry.io/otel v1.38.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||||
golang.org/x/sync v0.11.0 // indirect
|
golang.org/x/sync v0.16.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.40.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
k8s.io/apimachinery v0.35.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
k8s.io/klog/v2 v2.130.1 // indirect
|
||||||
|
k8s.io/utils v0.0.0-20260108192941-914a6e750570 // indirect
|
||||||
|
mellium.im/sasl v0.3.2 // indirect
|
||||||
modernc.org/libc v1.61.13 // indirect
|
modernc.org/libc v1.61.13 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.8.2 // indirect
|
modernc.org/memory v1.8.2 // indirect
|
||||||
|
|||||||
92
go.sum
92
go.sum
@@ -1,26 +1,41 @@
|
|||||||
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
|
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
|
||||||
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
|
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
|
||||||
|
git.haelnorr.com/h/golib/env v0.9.1 h1:2Vsj+mJKnO5f1Md1GO5v9ggLN5zWa0baCewcSHTjoNY=
|
||||||
|
git.haelnorr.com/h/golib/env v0.9.1/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||||
|
git.haelnorr.com/h/golib/hlog v0.9.1 h1:9VmE/IQTfD8LAEyTbUCZLy/+8PbcHA1Kob/WQHRHKzc=
|
||||||
|
git.haelnorr.com/h/golib/hlog v0.9.1/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.2.0 h1:MR2Tu2qPaW+/oK8aXFJLRFaYZIHgKiex3t3zE41cu1U=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.2.0/go.mod h1:6ZlRKnt8YMpv5XcMXmyBGmD1/euvBo3d1azEvHJjOLo=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.3.1 h1:+vVkVj/5DTPXSp7em2DqF3QuovhHKSCRTRFbwRQ7g8E=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.3.1/go.mod h1:WHHMy1EVQWrHtyJx+gQQkB+5otJ4E6ZyEtKBjqZqKhQ=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.10.0 h1:8cI8mSnb8X+EmJtrBO/5UZwuBMtib0IE9dv85gkm94E=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.10.0/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0 h1:OQ6M2TB8FHm8fJD7/ebfWm63Duzfp0kmFX9genEig34=
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0/go.mod h1:mGKYa3o3z0IsQ5EO3MPmnL2Bwl2sSMsUHXVgaIGR7Z0=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
||||||
|
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|
||||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
@@ -38,37 +53,60 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
|
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
|
||||||
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
||||||
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||||
|
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||||
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
|
github.com/uptrace/bun v1.2.16 h1:QlObi6ZIK5Ao7kAALnh91HWYNZUBbVwye52fmlQM9kc=
|
||||||
|
github.com/uptrace/bun v1.2.16/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
|
||||||
|
github.com/uptrace/bun/driver/pgdriver v1.2.16 h1:b1kpXKUxtTSGYow5Vlsb+dKV3z0R7aSAJNfMfKp61ZU=
|
||||||
|
github.com/uptrace/bun/driver/pgdriver v1.2.16/go.mod h1:H6lUZ9CBfp1X5Vq62YGSV7q96/v94ja9AYFjKvdoTk0=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
|
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||||
|
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8=
|
||||||
|
k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns=
|
||||||
|
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
|
||||||
|
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
|
||||||
|
k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY=
|
||||||
|
k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
||||||
|
mellium.im/sasl v0.3.2 h1:PT6Xp7ccn9XaXAnJ03FcEjmAn7kK1x7aoXV6F+Vmrl0=
|
||||||
|
mellium.im/sasl v0.3.2/go.mod h1:NKXDi1zkr+BlMHLQjY3ofYuU4KSPFxknb8mfEu6SveY=
|
||||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||||
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
|
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func revokeAccess(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
atStr string,
|
|
||||||
) error {
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "Token is expired") ||
|
|
||||||
strings.Contains(err.Error(), "Token has been revoked") {
|
|
||||||
return nil // Token is expired, dont need to revoke it
|
|
||||||
}
|
|
||||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, aT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func revokeRefresh(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
rtStr string,
|
|
||||||
) error {
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "Token is expired") ||
|
|
||||||
strings.Contains(err.Error(), "Token has been revoked") {
|
|
||||||
return nil // Token is expired, dont need to revoke it
|
|
||||||
}
|
|
||||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve and revoke the user's tokens
|
|
||||||
func revokeTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
// get the tokens from the cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
// revoke the refresh token first as the access token expires quicker
|
|
||||||
// only matters if there is an error revoking the tokens
|
|
||||||
err := revokeRefresh(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeRefresh")
|
|
||||||
}
|
|
||||||
err = revokeAccess(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeAccess")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle a logout request
|
|
||||||
func Logout(
|
|
||||||
config *config.Config,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Error occured on user logout")
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = revokeTokens(config, ctx, tx, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
|
||||||
w.Header().Set("HX-Redirect", "/login")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
"projectreshoot/view/component/form"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the tokens from the request
|
|
||||||
func getTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
|
||||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
|
||||||
// get the existing tokens from the cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
|
||||||
}
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
return aT, rT, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Revoke the given token pair
|
|
||||||
func revokeTokenPair(
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
aT *jwt.AccessToken,
|
|
||||||
rT *jwt.RefreshToken,
|
|
||||||
) error {
|
|
||||||
err := jwt.RevokeToken(ctx, tx, aT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Issue new tokens for the user, invalidating the old ones
|
|
||||||
func refreshTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
aT, rT, err := getTokens(config, ctx, tx, r)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "getTokens")
|
|
||||||
}
|
|
||||||
rememberMe := map[string]bool{
|
|
||||||
"session": false,
|
|
||||||
"exp": true,
|
|
||||||
}[aT.TTL]
|
|
||||||
// issue new tokens for the user
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cookies.SetTokenCookies")
|
|
||||||
}
|
|
||||||
err = revokeTokenPair(ctx, tx, aT, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeTokenPair")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the provided password
|
|
||||||
func validatePassword(
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
r.ParseForm()
|
|
||||||
password := r.FormValue("password")
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
err := user.CheckPassword(password)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "user.CheckPassword")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle request to reauthenticate (i.e. make token fresh again)
|
|
||||||
func Reauthenticate(
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Failed to refresh user tokens")
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = validatePassword(r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
w.WriteHeader(445)
|
|
||||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = refreshTokens(config, ctx, tx, w, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wrapper for default FileSystem
|
|
||||||
type justFilesFilesystem struct {
|
|
||||||
fs http.FileSystem
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrapper for default File
|
|
||||||
type neuteredReaddirFile struct {
|
|
||||||
http.File
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modifies the behavior of FileSystem.Open to return the neutered version of File
|
|
||||||
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
|
|
||||||
f, err := fs.fs.Open(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the requested path is a directory
|
|
||||||
// and explicitly return an error to trigger a 404
|
|
||||||
fileInfo, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if fileInfo.IsDir() {
|
|
||||||
return nil, os.ErrNotExist
|
|
||||||
}
|
|
||||||
|
|
||||||
return neuteredReaddirFile{f}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Overrides the Readdir method of File to always return nil
|
|
||||||
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handles requests for static files, without allowing access to the
|
|
||||||
// directory viewer and returning 404 if an exact file is not found
|
|
||||||
func StaticFS(staticFS *http.FileSystem) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
nfs := justFilesFilesystem{*staticFS}
|
|
||||||
fs := http.FileServer(nfs)
|
|
||||||
fs.ServeHTTP(w, r)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func removeme(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
handler func(
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
),
|
|
||||||
onfail func(err error),
|
|
||||||
) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
onfail(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(ctx, tx, w, r)
|
|
||||||
}
|
|
||||||
37
internal/config/auth.go
Normal file
37
internal/config/auth.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HWSAUTHConfig struct {
|
||||||
|
SSL bool // ENV HWSAUTH_SSL: Flag for SSL Mode (default: false)
|
||||||
|
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address to accept as trusted SSL host (required if SSL is true)
|
||||||
|
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing tokens (required)
|
||||||
|
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5)
|
||||||
|
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440)
|
||||||
|
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Time for tokens to stay fresh in minutes (default: 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHWSAuth() (*HWSAUTHConfig, error) {
|
||||||
|
ssl := env.Bool("HWSAUTH_SSL", false)
|
||||||
|
trustedHost := env.String("HWS_TRUSTED_HOST", "")
|
||||||
|
if ssl && trustedHost == "" {
|
||||||
|
return nil, errors.New("SSL is enabled and no HWS_TRUSTED_HOST set")
|
||||||
|
}
|
||||||
|
cfg := &HWSAUTHConfig{
|
||||||
|
SSL: ssl,
|
||||||
|
TrustedHost: trustedHost,
|
||||||
|
SecretKey: env.String("HWSAUTH_SECRET_KEY", ""),
|
||||||
|
AccessTokenExpiry: env.Int64("HWSAUTH_ACCESS_TOKEN_EXPIRY", 5),
|
||||||
|
RefreshTokenExpiry: env.Int64("HWSAUTH_REFRESH_TOKEN_EXPIRY", 1440),
|
||||||
|
TokenFreshTime: env.Int64("HWSAUTH_TOKEN_FRESH_TIME", 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.SecretKey == "" {
|
||||||
|
return nil, errors.New("Envar not set: HWSAUTH_SECRET_KEY")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
56
internal/config/config.go
Normal file
56
internal/config/config.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"github.com/joho/godotenv"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
DB *DBConfig
|
||||||
|
HWS *hws.Config
|
||||||
|
HWSAuth *hwsauth.Config
|
||||||
|
TMDB *TMDBConfig
|
||||||
|
HLOG *HLOGConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the application configuration and get a pointer to the Config object
|
||||||
|
func GetConfig(envfile string) (*Config, error) {
|
||||||
|
godotenv.Load(envfile)
|
||||||
|
|
||||||
|
db, err := setupDB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "setupDB")
|
||||||
|
}
|
||||||
|
|
||||||
|
hws, err := hws.ConfigFromEnv()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hws.ConfigFromEnv")
|
||||||
|
}
|
||||||
|
|
||||||
|
hwsAuth, err := hwsauth.ConfigFromEnv()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hwsauth.ConfigFromEnv")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmdb, err := setupTMDB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "setupTMDB")
|
||||||
|
}
|
||||||
|
|
||||||
|
hlog, err := setupHLOG()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "setupHLOG")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
DB: db,
|
||||||
|
HWS: hws,
|
||||||
|
HWSAuth: hwsAuth,
|
||||||
|
TMDB: tmdb,
|
||||||
|
HLOG: hlog,
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
55
internal/config/db.go
Normal file
55
internal/config/db.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DBConfig struct {
|
||||||
|
User string // ENV DB_USER: Database user for authentication (required)
|
||||||
|
Password string // ENV DB_PASSWORD: Database password for authentication (required)
|
||||||
|
Host string // ENV DB_HOST: Database host address (required)
|
||||||
|
Port uint16 // ENV DB_PORT: Database port (default: 5432)
|
||||||
|
DB string // ENV DB_NAME: Database name to connect to (required)
|
||||||
|
SSL string // ENV DB_SSL: SSL mode for connection (default: disable)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDB() (*DBConfig, error) {
|
||||||
|
cfg := &DBConfig{
|
||||||
|
User: env.String("DB_USER", ""),
|
||||||
|
Password: env.String("DB_PASSWORD", ""),
|
||||||
|
Host: env.String("DB_HOST", ""),
|
||||||
|
Port: env.UInt16("DB_PORT", 5432),
|
||||||
|
DB: env.String("DB_NAME", ""),
|
||||||
|
SSL: env.String("DB_SSL", "disable"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate SSL mode
|
||||||
|
validSSLModes := map[string]bool{
|
||||||
|
"disable": true,
|
||||||
|
"require": true,
|
||||||
|
"verify-ca": true,
|
||||||
|
"verify-full": true,
|
||||||
|
"allow": true,
|
||||||
|
"prefer": true,
|
||||||
|
}
|
||||||
|
if !validSSLModes[cfg.SSL] {
|
||||||
|
return nil, errors.Errorf("Invalid DB_SSL value: %s. Must be one of: disable, allow, prefer, require, verify-ca, verify-full", cfg.SSL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check required fields
|
||||||
|
if cfg.User == "" {
|
||||||
|
return nil, errors.New("Envar not set: DB_USER")
|
||||||
|
}
|
||||||
|
if cfg.Password == "" {
|
||||||
|
return nil, errors.New("Envar not set: DB_PASSWORD")
|
||||||
|
}
|
||||||
|
if cfg.Host == "" {
|
||||||
|
return nil, errors.New("Envar not set: DB_HOST")
|
||||||
|
}
|
||||||
|
if cfg.DB == "" {
|
||||||
|
return nil, errors.New("Envar not set: DB_NAME")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
114
internal/config/envdoc.go
Normal file
114
internal/config/envdoc.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnvVar represents an environment variable with its documentation
|
||||||
|
type EnvVar struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Default string
|
||||||
|
HasDefault bool
|
||||||
|
Required bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractEnvVars parses a struct's field comments to extract environment variable documentation
|
||||||
|
func extractEnvVars(structType reflect.Type, fieldIndex int) *EnvVar {
|
||||||
|
field := structType.Field(fieldIndex)
|
||||||
|
tag := field.Tag.Get("comment")
|
||||||
|
if tag == "" {
|
||||||
|
// Try to get the comment from the struct field's tag or use reflection
|
||||||
|
// For now, we'll parse it manually from the comment string
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
comment := tag
|
||||||
|
if !strings.HasPrefix(comment, "ENV ") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove "ENV " prefix
|
||||||
|
comment = strings.TrimPrefix(comment, "ENV ")
|
||||||
|
|
||||||
|
// Extract name and description
|
||||||
|
parts := strings.SplitN(comment, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(parts[0])
|
||||||
|
desc := strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
// Check for default value in description
|
||||||
|
defaultRegex := regexp.MustCompile(`\(default:\s*([^)]+)\)`)
|
||||||
|
matches := defaultRegex.FindStringSubmatch(desc)
|
||||||
|
|
||||||
|
envVar := &EnvVar{
|
||||||
|
Name: name,
|
||||||
|
Description: desc,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) > 1 {
|
||||||
|
envVar.Default = matches[1]
|
||||||
|
envVar.HasDefault = true
|
||||||
|
// Remove the default notation from description
|
||||||
|
envVar.Description = strings.TrimSpace(defaultRegex.ReplaceAllString(desc, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
return envVar
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllEnvVars returns a list of all environment variables used in the config
|
||||||
|
func GetAllEnvVars() []EnvVar {
|
||||||
|
var envVars []EnvVar
|
||||||
|
|
||||||
|
// Manually define all env vars based on the config structs
|
||||||
|
// This is more reliable than reflection for extracting comments
|
||||||
|
|
||||||
|
// DBConfig
|
||||||
|
envVars = append(envVars, []EnvVar{
|
||||||
|
{Name: "DB_USER", Description: "Database user for authentication", HasDefault: false, Required: true},
|
||||||
|
{Name: "DB_PASSWORD", Description: "Database password for authentication", HasDefault: false, Required: true},
|
||||||
|
{Name: "DB_HOST", Description: "Database host address", HasDefault: false, Required: true},
|
||||||
|
{Name: "DB_PORT", Description: "Database port", Default: "5432", HasDefault: true, Required: false},
|
||||||
|
{Name: "DB_NAME", Description: "Database name to connect to", HasDefault: false, Required: true},
|
||||||
|
{Name: "DB_SSL", Description: "SSL mode for connection", Default: "disable", HasDefault: true, Required: false},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
// HWSConfig
|
||||||
|
envVars = append(envVars, []EnvVar{
|
||||||
|
{Name: "HWS_HOST", Description: "Host to listen on", Default: "127.0.0.1", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_PORT", Description: "Port to listen on", Default: "3000", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_TRUSTED_HOST", Description: "Domain/Hostname to accept as trusted", Default: "same as Host", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_SSL", Description: "Flag for SSL Mode", Default: "false", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_GZIP", Description: "Flag for GZIP compression on requests", Default: "false", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_READ_HEADER_TIMEOUT", Description: "Timeout for reading request headers in seconds", Default: "2", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_WRITE_TIMEOUT", Description: "Timeout for writing requests in seconds", Default: "10", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWS_IDLE_TIMEOUT", Description: "Timeout for idle connections in seconds", Default: "120", HasDefault: true, Required: false},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
// HWSAUTHConfig
|
||||||
|
envVars = append(envVars, []EnvVar{
|
||||||
|
{Name: "HWSAUTH_SECRET_KEY", Description: "Secret key for signing tokens", HasDefault: false, Required: true},
|
||||||
|
{Name: "HWSAUTH_ACCESS_TOKEN_EXPIRY", Description: "Access token expiry in minutes", Default: "5", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWSAUTH_REFRESH_TOKEN_EXPIRY", Description: "Refresh token expiry in minutes", Default: "1440", HasDefault: true, Required: false},
|
||||||
|
{Name: "HWSAUTH_TOKEN_FRESH_TIME", Description: "Time for tokens to stay fresh in minutes", Default: "5", HasDefault: true, Required: false},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
// TMDBConfig
|
||||||
|
envVars = append(envVars, []EnvVar{
|
||||||
|
{Name: "TMDB_TOKEN", Description: "API token for TMDB", HasDefault: false, Required: true},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
// HLOGConfig
|
||||||
|
envVars = append(envVars, []EnvVar{
|
||||||
|
{Name: "LOG_LEVEL", Description: "Log level for global logging", Default: "info", HasDefault: true, Required: false},
|
||||||
|
{Name: "LOG_OUTPUT", Description: "Output method for the logger (file, console, or both)", Default: "console", HasDefault: true, Required: false},
|
||||||
|
{Name: "LOG_DIR", Description: "Path to create log files", HasDefault: false, Required: false},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
return envVars
|
||||||
|
}
|
||||||
95
internal/config/envgen.go
Normal file
95
internal/config/envgen.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateDotEnv creates a new .env file with all environment variables and their defaults
|
||||||
|
func GenerateDotEnv(filename string) error {
|
||||||
|
envVars := GetAllEnvVars()
|
||||||
|
|
||||||
|
file, err := os.Create(filename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Write header
|
||||||
|
fmt.Fprintln(file, "# Environment Configuration")
|
||||||
|
fmt.Fprintln(file, "# Generated by Project Reshoot")
|
||||||
|
fmt.Fprintln(file, "#")
|
||||||
|
fmt.Fprintln(file, "# Variables marked as (required) must be set")
|
||||||
|
fmt.Fprintln(file, "# Variables with defaults can be left commented out to use the default value")
|
||||||
|
fmt.Fprintln(file)
|
||||||
|
|
||||||
|
// Group by prefix
|
||||||
|
groups := map[string][]EnvVar{
|
||||||
|
"DB_": {},
|
||||||
|
"HWS_": {},
|
||||||
|
"HWSAUTH_": {},
|
||||||
|
"TMDB_": {},
|
||||||
|
"LOG_": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range envVars {
|
||||||
|
assigned := false
|
||||||
|
for prefix := range groups {
|
||||||
|
if strings.HasPrefix(ev.Name, prefix) {
|
||||||
|
groups[prefix] = append(groups[prefix], ev)
|
||||||
|
assigned = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !assigned {
|
||||||
|
// Handle ungrouped vars
|
||||||
|
if _, ok := groups["OTHER"]; !ok {
|
||||||
|
groups["OTHER"] = []EnvVar{}
|
||||||
|
}
|
||||||
|
groups["OTHER"] = append(groups["OTHER"], ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print each group
|
||||||
|
groupOrder := []string{"DB_", "HWS_", "HWSAUTH_", "TMDB_", "LOG_", "OTHER"}
|
||||||
|
groupTitles := map[string]string{
|
||||||
|
"DB_": "Database Configuration",
|
||||||
|
"HWS_": "HTTP Web Server Configuration",
|
||||||
|
"HWSAUTH_": "Authentication Configuration",
|
||||||
|
"TMDB_": "TMDB API Configuration",
|
||||||
|
"LOG_": "Logging Configuration",
|
||||||
|
"OTHER": "Other Configuration",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range groupOrder {
|
||||||
|
vars, ok := groups[prefix]
|
||||||
|
if !ok || len(vars) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(file, "# %s\n", groupTitles[prefix])
|
||||||
|
fmt.Fprintln(file, strings.Repeat("#", len(groupTitles[prefix])+2))
|
||||||
|
|
||||||
|
for _, ev := range vars {
|
||||||
|
// Write description as comment
|
||||||
|
if ev.Required {
|
||||||
|
fmt.Fprintf(file, "# %s (required)\n", ev.Description)
|
||||||
|
// Leave required variables uncommented but empty
|
||||||
|
fmt.Fprintf(file, "%s=\n", ev.Name)
|
||||||
|
} else if ev.HasDefault {
|
||||||
|
fmt.Fprintf(file, "# %s\n", ev.Description)
|
||||||
|
// Comment out variables with defaults
|
||||||
|
fmt.Fprintf(file, "# %s=%s\n", ev.Name, ev.Default)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(file, "# %s\n", ev.Description)
|
||||||
|
// Optional variables without defaults are commented out
|
||||||
|
fmt.Fprintf(file, "# %s=\n", ev.Name)
|
||||||
|
}
|
||||||
|
fmt.Fprintln(file)
|
||||||
|
}
|
||||||
|
fmt.Fprintln(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
87
internal/config/envprint.go
Normal file
87
internal/config/envprint.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrintEnvVars writes all environment variables and their documentation to the provided writer
|
||||||
|
func PrintEnvVars(w io.Writer) error {
|
||||||
|
envVars := GetAllEnvVars()
|
||||||
|
|
||||||
|
// Find the longest name for alignment
|
||||||
|
maxNameLen := 0
|
||||||
|
for _, ev := range envVars {
|
||||||
|
if len(ev.Name) > maxNameLen {
|
||||||
|
maxNameLen = len(ev.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print header
|
||||||
|
fmt.Fprintln(w, "Environment Variables")
|
||||||
|
fmt.Fprintln(w, strings.Repeat("=", 80))
|
||||||
|
fmt.Fprintln(w)
|
||||||
|
|
||||||
|
// Group by prefix
|
||||||
|
groups := map[string][]EnvVar{
|
||||||
|
"DB_": {},
|
||||||
|
"HWS_": {},
|
||||||
|
"HWSAUTH_": {},
|
||||||
|
"TMDB_": {},
|
||||||
|
"LOG_": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range envVars {
|
||||||
|
assigned := false
|
||||||
|
for prefix := range groups {
|
||||||
|
if strings.HasPrefix(ev.Name, prefix) {
|
||||||
|
groups[prefix] = append(groups[prefix], ev)
|
||||||
|
assigned = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !assigned {
|
||||||
|
// Handle ungrouped vars
|
||||||
|
if _, ok := groups["OTHER"]; !ok {
|
||||||
|
groups["OTHER"] = []EnvVar{}
|
||||||
|
}
|
||||||
|
groups["OTHER"] = append(groups["OTHER"], ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print each group
|
||||||
|
groupOrder := []string{"DB_", "HWS_", "HWSAUTH_", "TMDB_", "LOG_", "OTHER"}
|
||||||
|
groupTitles := map[string]string{
|
||||||
|
"DB_": "Database Configuration",
|
||||||
|
"HWS_": "HTTP Web Server Configuration",
|
||||||
|
"HWSAUTH_": "Authentication Configuration",
|
||||||
|
"TMDB_": "TMDB API Configuration",
|
||||||
|
"LOG_": "Logging Configuration",
|
||||||
|
"OTHER": "Other Configuration",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range groupOrder {
|
||||||
|
vars, ok := groups[prefix]
|
||||||
|
if !ok || len(vars) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\n", groupTitles[prefix])
|
||||||
|
fmt.Fprintln(w, strings.Repeat("-", len(groupTitles[prefix])))
|
||||||
|
|
||||||
|
for _, ev := range vars {
|
||||||
|
padding := strings.Repeat(" ", maxNameLen-len(ev.Name))
|
||||||
|
if ev.Required {
|
||||||
|
fmt.Fprintf(w, " %s%s : %s (required)\n", ev.Name, padding, ev.Description)
|
||||||
|
} else if ev.HasDefault {
|
||||||
|
fmt.Fprintf(w, " %s%s : %s (default: %s)\n", ev.Name, padding, ev.Description, ev.Default)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(w, " %s%s : %s\n", ev.Name, padding, ev.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintln(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
29
internal/config/httpserver.go
Normal file
29
internal/config/httpserver.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HWSConfig struct {
|
||||||
|
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
|
||||||
|
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
|
||||||
|
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false)
|
||||||
|
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
|
||||||
|
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
|
||||||
|
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHWS() (*HWSConfig, error) {
|
||||||
|
cfg := &HWSConfig{
|
||||||
|
Host: env.String("HWS_HOST", "127.0.0.1"),
|
||||||
|
Port: env.UInt64("HWS_PORT", 3000),
|
||||||
|
GZIP: env.Bool("HWS_GZIP", false),
|
||||||
|
ReadHeaderTimeout: time.Duration(env.Int("HWS_READ_HEADER_TIMEOUT", 2)) * time.Second,
|
||||||
|
WriteTimeout: time.Duration(env.Int("HWS_WRITE_TIMEOUT", 10)) * time.Second,
|
||||||
|
IdleTimeout: time.Duration(env.Int("HWS_IDLE_TIMEOUT", 120)) * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
36
internal/config/logger.go
Normal file
36
internal/config/logger.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HLOGConfig struct {
|
||||||
|
// ENV LOG_LEVEL: Log level for global logging. (default: info)
|
||||||
|
LogLevel hlog.Level
|
||||||
|
|
||||||
|
// ENV LOG_OUTPUT: Output method for the logger. (default: console)
|
||||||
|
// Valid options: "file", "console", "both"
|
||||||
|
LogOutput string
|
||||||
|
|
||||||
|
// ENV LOG_DIR: Path to create log files
|
||||||
|
LogDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHLOG() (*HLOGConfig, error) {
|
||||||
|
logLevel, err := hlog.LogLevel(env.String("LOG_LEVEL", "info"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog.LogLevel")
|
||||||
|
}
|
||||||
|
logOutput := env.String("LOG_OUTPUT", "console")
|
||||||
|
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
||||||
|
return nil, errors.Errorf("Invalid LOG_OUTPUT: %s", logOutput)
|
||||||
|
}
|
||||||
|
cfg := &HLOGConfig{
|
||||||
|
LogLevel: logLevel,
|
||||||
|
LogOutput: logOutput,
|
||||||
|
LogDir: env.String("LOG_DIR", ""),
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
28
internal/config/tmdb.go
Normal file
28
internal/config/tmdb.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TMDBConfig struct {
|
||||||
|
Token string // ENV TMDB_TOKEN: API token for TMDB (required)
|
||||||
|
Config *tmdb.Config // Config data for interfacing with TMDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTMDB() (*TMDBConfig, error) {
|
||||||
|
token := env.String("TMDB_TOKEN", "")
|
||||||
|
if token == "" {
|
||||||
|
return nil, errors.New("No TMDB API Token provided")
|
||||||
|
}
|
||||||
|
tmdbcfg, err := tmdb.GetConfig(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdb.GetConfig")
|
||||||
|
}
|
||||||
|
cfg := &TMDBConfig{
|
||||||
|
Token: token,
|
||||||
|
Config: tmdbcfg,
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
@@ -5,14 +5,17 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/contexts"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/internal/view/component/account"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/view/page"
|
||||||
"projectreshoot/view/component/account"
|
|
||||||
"projectreshoot/view/page"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Renders the account page on the 'General' subpage
|
// Renders the account page on the 'General' subpage
|
||||||
@@ -43,8 +46,9 @@ func AccountSubpage() http.Handler {
|
|||||||
|
|
||||||
// Handles a request to change the users username
|
// Handles a request to change the users username
|
||||||
func ChangeUsername(
|
func ChangeUsername(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
db *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -52,19 +56,31 @@ func ChangeUsername(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating username")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Message: "Error updating username",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
newUsername := r.FormValue("username")
|
newUsername := r.FormValue("username")
|
||||||
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
|
unique, err := models.IsUsernameUnique(ctx, tx, newUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating username")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Error updating username",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !unique {
|
if !unique {
|
||||||
@@ -73,12 +89,18 @@ func ChangeUsername(
|
|||||||
Render(r.Context(), w)
|
Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.ChangeUsername(ctx, tx, newUsername)
|
err = user.ChangeUsername(ctx, tx, newUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating username")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Error updating username",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@@ -89,8 +111,9 @@ func ChangeUsername(
|
|||||||
|
|
||||||
// Handles a request to change the users bio
|
// Handles a request to change the users bio
|
||||||
func ChangeBio(
|
func ChangeBio(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
db *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -98,10 +121,16 @@ func ChangeBio(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating bio")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Message: "Error updating bio",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -113,12 +142,18 @@ func ChangeBio(
|
|||||||
Render(r.Context(), w)
|
Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.ChangeBio(ctx, tx, newBio)
|
err = user.ChangeBio(ctx, tx, newBio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating bio")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Error updating bio",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@@ -127,8 +162,6 @@ func ChangeBio(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
func validateChangePassword(
|
func validateChangePassword(
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -145,8 +178,9 @@ func validateChangePassword(
|
|||||||
|
|
||||||
// Handles a request to change the users password
|
// Handles a request to change the users password
|
||||||
func ChangePassword(
|
func ChangePassword(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
db *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -154,24 +188,36 @@ func ChangePassword(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating password")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Message: "Error updating password",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
newPass, err := validateChangePassword(ctx, tx, r)
|
newPass, err := validateChangePassword(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.SetPassword(ctx, tx, newPass)
|
err = user.SetPassword(ctx, tx, newPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating password")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Error updating password",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@@ -2,15 +2,17 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ErrorPage(
|
func ErrorPage(
|
||||||
errorCode int,
|
errorCode int,
|
||||||
w http.ResponseWriter,
|
) (hws.ErrorPage, error) {
|
||||||
r *http.Request,
|
messages := map[int]string{
|
||||||
) {
|
400: "The request you made was malformed or unexpected.",
|
||||||
message := map[int]string{
|
|
||||||
401: "You need to login to view this page.",
|
401: "You need to login to view this page.",
|
||||||
403: "You do not have permission to view this page.",
|
403: "You do not have permission to view this page.",
|
||||||
404: "The page or resource you have requested does not exist.",
|
404: "The page or resource you have requested does not exist.",
|
||||||
@@ -18,7 +20,9 @@ func ErrorPage(
|
|||||||
continues to happen contact an administrator.`,
|
continues to happen contact an administrator.`,
|
||||||
503: "The server is currently down for maintenance and should be back soon. =)",
|
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
msg, exists := messages[errorCode]
|
||||||
page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
|
if !exists {
|
||||||
Render(r.Context(), w)
|
return nil, errors.New("No valid message for the given code")
|
||||||
|
}
|
||||||
|
return page.Error(errorCode, http.StatusText(errorCode), msg), nil
|
||||||
}
|
}
|
||||||
@@ -3,7 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handles responses to the / path. Also serves a 404 Page for paths that
|
// Handles responses to the / path. Also serves a 404 Page for paths that
|
||||||
@@ -12,7 +12,14 @@ func Root() http.Handler {
|
|||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path != "/" {
|
if r.URL.Path != "/" {
|
||||||
ErrorPage(http.StatusNotFound, w, r)
|
page, err := ErrorPage(http.StatusNotFound)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: add logger for this
|
||||||
|
}
|
||||||
|
err = page.Render(r.Context(), w)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: add logger for this
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
page.Index().Render(r.Context(), w)
|
page.Index().Render(r.Context(), w)
|
||||||
@@ -3,34 +3,44 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/internal/view/component/form"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/view/page"
|
||||||
"projectreshoot/view/component/form"
|
|
||||||
"projectreshoot/view/page"
|
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validates the username matches a user in the database and the password
|
// Validates the username matches a user in the database and the password
|
||||||
// is correct. Returns the corresponding user
|
// is correct. Returns the corresponding user
|
||||||
func validateLogin(
|
func validateLogin(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
tx *db.SafeTX,
|
tx bun.Tx,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*db.User, error) {
|
) (*models.UserBun, error) {
|
||||||
formUsername := r.FormValue("username")
|
formUsername := r.FormValue("username")
|
||||||
formPassword := r.FormValue("password")
|
formPassword := r.FormValue("password")
|
||||||
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
|
user, err := models.GetUserByUsername(ctx, tx, formUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = user.CheckPassword(formPassword)
|
if user == nil {
|
||||||
|
return nil, errors.New("Username or password incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = user.CheckPassword(ctx, tx, formPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "Username or password incorrect") {
|
||||||
|
return nil, errors.Wrap(err, "user.CheckPassword")
|
||||||
|
}
|
||||||
return nil, errors.New("Username or password incorrect")
|
return nil, errors.New("Username or password incorrect")
|
||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
@@ -50,9 +60,9 @@ func checkRememberMe(r *http.Request) bool {
|
|||||||
// and on fail will return the login form again, passing the error to the
|
// and on fail will return the login form again, passing the error to the
|
||||||
// template for user feedback
|
// template for user feedback
|
||||||
func LoginRequest(
|
func LoginRequest(
|
||||||
config *config.Config,
|
server *hws.Server,
|
||||||
logger *zerolog.Logger,
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
conn *db.SafeConn,
|
db *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -60,10 +70,16 @@ func LoginRequest(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Message: "Login failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -71,20 +87,32 @@ func LoginRequest(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
if err.Error() != "Username or password incorrect" {
|
if err.Error() != "Username or password incorrect" {
|
||||||
logger.Warn().Caller().Err(err).Msg("Login request failed")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Login failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
form.LoginForm(err.Error()).Render(r.Context(), w)
|
form.LoginForm("Username or password incorrect").Render(r.Context(), w)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := checkRememberMe(r)
|
rememberMe := checkRememberMe(r)
|
||||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
err = auth.Login(w, r, user, rememberMe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Login failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
55
internal/handler/logout.go
Normal file
55
internal/handler/logout.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handle a logout request
|
||||||
|
func Logout(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
db *bun.DB,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Logout failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
err = auth.Logout(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Logout failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
w.Header().Set("HX-Redirect", "/login")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
61
internal/handler/movie.go
Normal file
61
internal/handler/movie.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Movie(
|
||||||
|
server *hws.Server,
|
||||||
|
cfg *config.TMDBConfig,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := r.PathValue("movie_id")
|
||||||
|
movie_id, err := strconv.ParseInt(id, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Message: "Movie ID provided is not valid",
|
||||||
|
Error: err,
|
||||||
|
Level: hws.ErrorDEBUG,
|
||||||
|
RenderErrorPage: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
movie, err := tmdb.GetMovie(int32(movie_id), cfg.Token)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "An error occured when trying to retrieve the requested movie",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
credits, err := tmdb.GetCredits(int32(movie_id), cfg.Token)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "An error occured when trying to retrieve the credits for the requested movie",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
page.Movie(movie, credits, &cfg.Config.Image).Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
41
internal/handler/movie_search.go
Normal file
41
internal/handler/movie_search.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/view/component/search"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SearchMovies(
|
||||||
|
cfg *config.TMDBConfig,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseForm()
|
||||||
|
query := r.FormValue("search")
|
||||||
|
if query == "" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
movies, err := tmdb.SearchMovies(cfg.Token, query, false, 1)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
search.MovieResults(movies, &cfg.Config.Image).Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MoviesPage() http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
page.Movies().Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ProfilePage() http.Handler {
|
func ProfilePage() http.Handler {
|
||||||
82
internal/handler/reauthenticatate.go
Normal file
82
internal/handler/reauthenticatate.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Validate the provided password
|
||||||
|
func validatePassword(
|
||||||
|
ctx context.Context,
|
||||||
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
tx bun.Tx,
|
||||||
|
r *http.Request,
|
||||||
|
) error {
|
||||||
|
r.ParseForm()
|
||||||
|
password := r.FormValue("password")
|
||||||
|
user := auth.CurrentModel(r.Context())
|
||||||
|
err := user.CheckPassword(ctx, tx, password)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "user.CheckPassword")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle request to reauthenticate (i.e. make token fresh again)
|
||||||
|
func Reauthenticate(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
|
db *bun.DB,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Failed to start transcation",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
err = validatePassword(ctx, auth, tx, r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(445)
|
||||||
|
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = auth.RefreshAuthTokens(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Failed to refresh user tokens",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -5,27 +5,29 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/internal/view/component/form"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/view/page"
|
||||||
"projectreshoot/view/component/form"
|
|
||||||
"projectreshoot/view/page"
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateRegistration(
|
func validateRegistration(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
tx *db.SafeTX,
|
tx bun.Tx,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*db.User, error) {
|
) (*models.UserBun, error) {
|
||||||
formUsername := r.FormValue("username")
|
formUsername := r.FormValue("username")
|
||||||
formPassword := r.FormValue("password")
|
formPassword := r.FormValue("password")
|
||||||
formConfirmPassword := r.FormValue("confirm-password")
|
formConfirmPassword := r.FormValue("confirm-password")
|
||||||
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
|
unique, err := models.IsUsernameUnique(ctx, tx, formUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
|
return nil, errors.Wrap(err, "models.CheckUsernameUnique")
|
||||||
}
|
}
|
||||||
if !unique {
|
if !unique {
|
||||||
return nil, errors.New("Username is taken")
|
return nil, errors.New("Username is taken")
|
||||||
@@ -36,18 +38,18 @@ func validateRegistration(
|
|||||||
if len(formPassword) > 72 {
|
if len(formPassword) > 72 {
|
||||||
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
||||||
}
|
}
|
||||||
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
|
user, err := models.CreateUser(ctx, tx, formUsername, formPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.CreateNewUser")
|
return nil, errors.Wrap(err, "models.CreateNewUser")
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterRequest(
|
func RegisterRequest(
|
||||||
config *config.Config,
|
server *hws.Server,
|
||||||
logger *zerolog.Logger,
|
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||||
conn *db.SafeConn,
|
db *bun.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -55,10 +57,16 @@ func RegisterRequest(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Message: "Failed to start transaction",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -68,8 +76,14 @@ func RegisterRequest(
|
|||||||
if err.Error() != "Username is taken" &&
|
if err.Error() != "Username is taken" &&
|
||||||
err.Error() != "Passwords do not match" &&
|
err.Error() != "Passwords do not match" &&
|
||||||
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
||||||
logger.Warn().Caller().Err(err).Msg("Registration request failed")
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Registration failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
form.RegisterForm(err.Error()).Render(r.Context(), w)
|
form.RegisterForm(err.Error()).Render(r.Context(), w)
|
||||||
}
|
}
|
||||||
@@ -77,11 +91,17 @@ func RegisterRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := checkRememberMe(r)
|
rememberMe := checkRememberMe(r)
|
||||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
err = auth.Login(w, r, user, rememberMe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
err := server.ThrowError(w, r, hws.HWSError{
|
||||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Message: "Login failed",
|
||||||
|
Error: err,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowFatal(w, err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
24
internal/handler/static.go
Normal file
24
internal/handler/static.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handles requests for static files, without allowing access to the
|
||||||
|
// directory viewer and returning 404 if an exact file is not found
|
||||||
|
func StaticFS(staticFS *http.FileSystem, logger *hlog.Logger) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fs, err := hws.SafeFileServer(staticFS)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
logger.Error().Err(err).Msg("Failed to load file system")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fs.ServeHTTP(w, r)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
73
internal/models/user.go
Normal file
73
internal/models/user.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
id int // Integer ID (index primary key)
|
||||||
|
Username string // Username (unique)
|
||||||
|
Created_at int64 // Epoch timestamp when the user was added to the database
|
||||||
|
Bio string // Short byline set by the user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u User) ID() int {
|
||||||
|
return u.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to set the users Password_hash from the given password
|
||||||
|
func (user *User) SetPassword(
|
||||||
|
tx *sql.Tx,
|
||||||
|
password string,
|
||||||
|
) error {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||||
|
}
|
||||||
|
newPassword := string(hashedPassword)
|
||||||
|
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||||
|
_, err = tx.Exec(query, newPassword, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to check if the given password matches the users Password_hash
|
||||||
|
func (user *User) CheckPassword(tx *sql.Tx, password string) error {
|
||||||
|
query := `SELECT password_hash FROM users WHERE id = ? LIMIT 1`
|
||||||
|
row := tx.QueryRow(query, user.id)
|
||||||
|
var hashedPassword string
|
||||||
|
err := row.Scan(&hashedPassword)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "row.Scan")
|
||||||
|
}
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Username or password incorrect")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's username
|
||||||
|
func (user *User) ChangeUsername(tx *sql.Tx, newUsername string) error {
|
||||||
|
query := `UPDATE users SET username = ? WHERE id = ?`
|
||||||
|
_, err := tx.Exec(query, newUsername, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's bio
|
||||||
|
func (user *User) ChangeBio(tx *sql.Tx, newBio string) error {
|
||||||
|
query := `UPDATE users SET bio = ? WHERE id = ?`
|
||||||
|
_, err := tx.Exec(query, newBio, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
163
internal/models/user_bun.go
Normal file
163
internal/models/user_bun.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserBun struct {
|
||||||
|
bun.BaseModel `bun:"table:users,alias:u"`
|
||||||
|
|
||||||
|
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
|
||||||
|
Username string `bun:"username,unique"` // Username (unique)
|
||||||
|
PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON)
|
||||||
|
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
||||||
|
Bio string `bun:"bio"` // Short byline set by the user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *UserBun) GetID() int {
|
||||||
|
return user.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to set the users password_hash from the given password
|
||||||
|
func (user *UserBun) SetPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||||
|
}
|
||||||
|
newPassword := string(hashedPassword)
|
||||||
|
|
||||||
|
_, err = tx.NewUpdate().
|
||||||
|
Model(user).
|
||||||
|
Set("password_hash = ?", newPassword).
|
||||||
|
Where("id = ?", user.ID).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Update")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to check if the given password matches the users password_hash
|
||||||
|
func (user *UserBun) CheckPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||||
|
var hashedPassword string
|
||||||
|
err := tx.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Column("password_hash").
|
||||||
|
Where("id = ?", user.ID).
|
||||||
|
Limit(1).
|
||||||
|
Scan(ctx, &hashedPassword)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Select")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Username or password incorrect")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's username
|
||||||
|
func (user *UserBun) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error {
|
||||||
|
_, err := tx.NewUpdate().
|
||||||
|
Model(user).
|
||||||
|
Set("username = ?", newUsername).
|
||||||
|
Where("id = ?", user.ID).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Update")
|
||||||
|
}
|
||||||
|
user.Username = newUsername
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's bio
|
||||||
|
func (user *UserBun) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error {
|
||||||
|
_, err := tx.NewUpdate().
|
||||||
|
Model(user).
|
||||||
|
Set("bio = ?", newBio).
|
||||||
|
Where("id = ?", user.ID).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Update")
|
||||||
|
}
|
||||||
|
user.Bio = newBio
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user with the given username and password
|
||||||
|
func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*UserBun, error) {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &UserBun{
|
||||||
|
Username: username,
|
||||||
|
PasswordHash: string(hashedPassword),
|
||||||
|
CreatedAt: 0, // You may want to set this to time.Now().Unix()
|
||||||
|
Bio: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.NewInsert().
|
||||||
|
Model(user).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tx.Insert")
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID queries the database for a user matching the given ID
|
||||||
|
// Returns nil, nil if no user is found
|
||||||
|
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*UserBun, error) {
|
||||||
|
user := new(UserBun)
|
||||||
|
err := tx.NewSelect().
|
||||||
|
Model(user).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Limit(1).
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "sql: no rows in result set" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, "tx.Select")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByUsername queries the database for a user matching the given username
|
||||||
|
// Returns nil, nil if no user is found
|
||||||
|
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*UserBun, error) {
|
||||||
|
user := new(UserBun)
|
||||||
|
err := tx.NewSelect().
|
||||||
|
Model(user).
|
||||||
|
Where("username = ?", username).
|
||||||
|
Limit(1).
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "sql: no rows in result set" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(err, "tx.Select")
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUsernameUnique checks if the given username is unique (not already taken)
|
||||||
|
// Returns true if the username is available, false if it's taken
|
||||||
|
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
|
||||||
|
count, err := tx.NewSelect().
|
||||||
|
Model((*UserBun)(nil)).
|
||||||
|
Where("username = ?", username).
|
||||||
|
Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tx.Count")
|
||||||
|
}
|
||||||
|
return count == 0, nil
|
||||||
|
}
|
||||||
@@ -1,30 +1,29 @@
|
|||||||
package db
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Creates a new user in the database and returns a pointer
|
// Creates a new user in the database and returns a pointer
|
||||||
func CreateNewUser(
|
func CreateNewUser(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tx *SafeTX,
|
|
||||||
username string,
|
username string,
|
||||||
password string,
|
password string,
|
||||||
) (*User, error) {
|
) (*User, error) {
|
||||||
query := `INSERT INTO users (username) VALUES (?)`
|
query := `INSERT INTO users (username) VALUES (?)`
|
||||||
_, err := tx.Exec(ctx, query, username)
|
_, err := tx.Exec(query, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.Exec")
|
return nil, errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
user, err := GetUserFromUsername(ctx, tx, username)
|
user, err := GetUserFromUsername(tx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "GetUserFromUsername")
|
return nil, errors.Wrap(err, "GetUserFromUsername")
|
||||||
}
|
}
|
||||||
err = user.SetPassword(ctx, tx, password)
|
err = user.SetPassword(tx, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "user.SetPassword")
|
return nil, errors.Wrap(err, "user.SetPassword")
|
||||||
}
|
}
|
||||||
@@ -33,23 +32,23 @@ func CreateNewUser(
|
|||||||
|
|
||||||
// Fetches data from the users table using "WHERE column = 'value'"
|
// Fetches data from the users table using "WHERE column = 'value'"
|
||||||
func fetchUserData(
|
func fetchUserData(
|
||||||
ctx context.Context,
|
tx interface {
|
||||||
tx *SafeTX,
|
Query(query string, args ...any) (*sql.Rows, error)
|
||||||
|
},
|
||||||
column string,
|
column string,
|
||||||
value interface{},
|
value any,
|
||||||
) (*sql.Rows, error) {
|
) (*sql.Rows, error) {
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
`SELECT
|
`SELECT
|
||||||
id,
|
id,
|
||||||
username,
|
username,
|
||||||
password_hash,
|
|
||||||
created_at,
|
created_at,
|
||||||
bio
|
bio
|
||||||
FROM users
|
FROM users
|
||||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||||
column,
|
column,
|
||||||
)
|
)
|
||||||
rows, err := tx.Query(ctx, query, value)
|
rows, err := tx.Query(query, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.Query")
|
return nil, errors.Wrap(err, "tx.Query")
|
||||||
}
|
}
|
||||||
@@ -63,9 +62,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
|||||||
return errors.New("User not found")
|
return errors.New("User not found")
|
||||||
}
|
}
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&user.ID,
|
&user.id,
|
||||||
&user.Username,
|
&user.Username,
|
||||||
&user.Password_hash,
|
|
||||||
&user.Created_at,
|
&user.Created_at,
|
||||||
&user.Bio,
|
&user.Bio,
|
||||||
)
|
)
|
||||||
@@ -77,8 +75,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
|||||||
|
|
||||||
// Queries the database for a user matching the given username.
|
// Queries the database for a user matching the given username.
|
||||||
// Query is case insensitive
|
// Query is case insensitive
|
||||||
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
|
func GetUserFromUsername(tx *sql.Tx, username string) (*User, error) {
|
||||||
rows, err := fetchUserData(ctx, tx, "username", username)
|
rows, err := fetchUserData(tx, "username", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
@@ -92,8 +90,8 @@ func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*Use
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Queries the database for a user matching the given ID.
|
// Queries the database for a user matching the given ID.
|
||||||
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
func GetUserFromID(tx hwsauth.DBTransaction, id int) (*User, error) {
|
||||||
rows, err := fetchUserData(ctx, tx, "id", id)
|
rows, err := fetchUserData(tx, "id", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
@@ -107,9 +105,9 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Checks if the given username is unique. Returns true if not taken
|
// Checks if the given username is unique. Returns true if not taken
|
||||||
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
|
func CheckUsernameUnique(tx *sql.Tx, username string) (bool, error) {
|
||||||
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
||||||
rows, err := tx.Query(ctx, query, username)
|
rows, err := tx.Query(query, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "tx.Query")
|
return false, errors.Wrap(err, "tx.Query")
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
package account
|
package account
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
templ ChangeBio(err string, bio string) {
|
templ ChangeBio(err string, bio string) {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{
|
{{
|
||||||
user := contexts.GetUser(ctx)
|
if bio == "" {
|
||||||
if bio == "" {
|
bio = user.Bio
|
||||||
bio = user.Bio
|
}
|
||||||
}
|
|
||||||
}}
|
}}
|
||||||
<form
|
<form
|
||||||
hx-post="/change-bio"
|
hx-post="/change-bio"
|
||||||
@@ -16,40 +16,39 @@ templ ChangeBio(err string, bio string) {
|
|||||||
x-data={ templ.JSFuncCall("bioComponent", bio, user.Bio, err).CallInline }
|
x-data={ templ.JSFuncCall("bioComponent", bio, user.Bio, err).CallInline }
|
||||||
>
|
>
|
||||||
<script>
|
<script>
|
||||||
function bioComponent(newBio, oldBio, err) {
|
function bioComponent(newBio, oldBio, err) {
|
||||||
return {
|
return {
|
||||||
bio: newBio,
|
bio: newBio,
|
||||||
initialBio: oldBio,
|
initialBio: oldBio,
|
||||||
err: err,
|
err: err,
|
||||||
bioLenText: '',
|
bioLenText: "",
|
||||||
updateTextArea() {
|
updateTextArea() {
|
||||||
this.$nextTick(() => {
|
this.$nextTick(() => {
|
||||||
if (this.$refs.bio) {
|
if (this.$refs.bio) {
|
||||||
this.$refs.bio.style.height = 'auto';
|
this.$refs.bio.style.height = "auto";
|
||||||
this.$refs.bio.style.height = `
|
this.$refs.bio.style.height = `
|
||||||
${this.$refs.bio.scrollHeight+20}px`;
|
${this.$refs.bio.scrollHeight + 20}px`;
|
||||||
};
|
}
|
||||||
this.bioLenText = `${this.bio.length}/128`;
|
this.bioLenText = `${this.bio.length}/128`;
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
resetBio() {
|
resetBio() {
|
||||||
this.bio = this.initialBio;
|
this.bio = this.initialBio;
|
||||||
this.err = "",
|
((this.err = ""), this.updateTextArea());
|
||||||
this.updateTextArea();
|
},
|
||||||
},
|
init() {
|
||||||
init() {
|
this.$nextTick(() => {
|
||||||
this.$nextTick(() => {
|
// this timeout makes sure the textarea resizes on
|
||||||
// this timeout makes sure the textarea resizes on
|
// page render correctly. seems 20ms is the sweet
|
||||||
// page render correctly. seems 20ms is the sweet
|
// spot between a noticable delay and not working
|
||||||
// spot between a noticable delay and not working
|
setTimeout(() => {
|
||||||
setTimeout(() => {
|
this.updateTextArea();
|
||||||
this.updateTextArea();
|
}, 20);
|
||||||
}, 20);
|
});
|
||||||
});
|
},
|
||||||
}
|
};
|
||||||
};
|
}
|
||||||
}
|
</script>
|
||||||
</script>
|
|
||||||
<div
|
<div
|
||||||
class="flex flex-col"
|
class="flex flex-col"
|
||||||
>
|
>
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
package account
|
package account
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
templ ChangeUsername(err string, username string) {
|
templ ChangeUsername(err string, username string) {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{
|
{{
|
||||||
user := contexts.GetUser(ctx)
|
if username == "" {
|
||||||
if username == "" {
|
username = user.Username
|
||||||
username = user.Username
|
}
|
||||||
}
|
|
||||||
}}
|
}}
|
||||||
<form
|
<form
|
||||||
hx-post="/change-username"
|
hx-post="/change-username"
|
||||||
@@ -18,18 +18,18 @@ templ ChangeUsername(err string, username string) {
|
|||||||
).CallInline }
|
).CallInline }
|
||||||
>
|
>
|
||||||
<script>
|
<script>
|
||||||
function usernameComponent(newUsername, oldUsername, err) {
|
function usernameComponent(newUsername, oldUsername, err) {
|
||||||
return {
|
return {
|
||||||
username: newUsername,
|
username: newUsername,
|
||||||
initialUsername: oldUsername,
|
initialUsername: oldUsername,
|
||||||
err: err,
|
err: err,
|
||||||
resetUsername() {
|
resetUsername() {
|
||||||
this.username = this.initialUsername;
|
this.username = this.initialUsername;
|
||||||
this.err = "";
|
this.err = "";
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
<div
|
<div
|
||||||
class="flex flex-col sm:flex-row"
|
class="flex flex-col sm:flex-row"
|
||||||
>
|
>
|
||||||
@@ -3,12 +3,12 @@ package account
|
|||||||
templ AccountContainer(subpage string) {
|
templ AccountContainer(subpage string) {
|
||||||
<div
|
<div
|
||||||
id="account-container"
|
id="account-container"
|
||||||
class="flex max-w-200 min-h-100 mx-auto bg-mantle mt-10 rounded-xl"
|
class="flex max-w-200 min-h-100 mx-5 md:mx-auto bg-mantle mt-5 rounded-xl"
|
||||||
x-data="{big:window.innerWidth >=768, open:false}"
|
x-data="{big:window.innerWidth >=768, open:false}"
|
||||||
@resize.window="big = window.innerWidth >= 768"
|
@resize.window="big = window.innerWidth >= 768"
|
||||||
>
|
>
|
||||||
@SelectMenu(subpage)
|
@SelectMenu(subpage)
|
||||||
<div class="mt-5 w-full md:ml-[200px] ml-[40px] transition-all duration-300">
|
<div class="mt-5 w-full md:ml-[200px] ml-10 transition-all duration-300">
|
||||||
<div
|
<div
|
||||||
class="pl-5 text-2xl text-subtext1 border-b
|
class="pl-5 text-2xl text-subtext1 border-b
|
||||||
border-overlay0 w-[90%] mx-auto"
|
border-overlay0 w-[90%] mx-auto"
|
||||||
@@ -75,8 +75,13 @@ templ Footer() {
|
|||||||
</div>
|
</div>
|
||||||
<div class="lg:flex lg:items-end lg:justify-between">
|
<div class="lg:flex lg:items-end lg:justify-between">
|
||||||
<div>
|
<div>
|
||||||
<p class="mt-4 text-center text-sm text-subtext0">
|
<p class="mt-4 text-center text-sm text-overlay0">
|
||||||
by Haelnorr
|
by Haelnorr |
|
||||||
|
<a href="#">Film data</a> from
|
||||||
|
<a
|
||||||
|
href="https://www.themoviedb.org/"
|
||||||
|
class="underline hover:text-subtext0 transition"
|
||||||
|
>TMDB</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
@@ -44,7 +44,7 @@ templ RegisterForm(registerError string) {
|
|||||||
<div class="relative">
|
<div class="relative">
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
idnutanix="username"
|
id="username"
|
||||||
name="username"
|
name="username"
|
||||||
class="py-3 px-4 block w-full rounded-lg text-sm
|
class="py-3 px-4 block w-full rounded-lg text-sm
|
||||||
focus:border-blue focus:ring-blue bg-base
|
focus:border-blue focus:ring-blue bg-base
|
||||||
@@ -21,7 +21,7 @@ templ Navbar() {
|
|||||||
<div x-data="{ open: false }">
|
<div x-data="{ open: false }">
|
||||||
<header class="bg-crust">
|
<header class="bg-crust">
|
||||||
<div
|
<div
|
||||||
class="mx-auto flex h-16 max-w-screen-xl items-center gap-8
|
class="mx-auto flex h-16 max-w-7xl items-center gap-8
|
||||||
px-4 sm:px-6 lg:px-8"
|
px-4 sm:px-6 lg:px-8"
|
||||||
>
|
>
|
||||||
<a class="block" href="/">
|
<a class="block" href="/">
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package nav
|
package nav
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
type ProfileItem struct {
|
type ProfileItem struct {
|
||||||
name string // Label to display
|
name string // Label to display
|
||||||
@@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem {
|
|||||||
|
|
||||||
// Returns the right portion of the navbar
|
// Returns the right portion of the navbar
|
||||||
templ navRight() {
|
templ navRight() {
|
||||||
{{ user := contexts.GetUser(ctx) }}
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{ items := getProfileItems() }}
|
{{ items := getProfileItems() }}
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
<div class="sm:flex sm:gap-2">
|
<div class="sm:flex sm:gap-2">
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package nav
|
package nav
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
// Returns the mobile version of the navbar thats only visible when activated
|
// Returns the mobile version of the navbar thats only visible when activated
|
||||||
templ sideNav(navItems []NavItem) {
|
templ sideNav(navItems []NavItem) {
|
||||||
{{ user := contexts.GetUser(ctx) }}
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
<div
|
<div
|
||||||
x-show="open"
|
x-show="open"
|
||||||
x-transition
|
x-transition
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
package popup
|
package popup
|
||||||
|
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
templ ConfirmPasswordModal() {
|
templ ConfirmPasswordModal() {
|
||||||
<div
|
<div
|
||||||
44
internal/view/component/search/movies_results.templ
Normal file
44
internal/view/component/search/movies_results.templ
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package search
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
import "git.haelnorr.com/h/golib/tmdb"
|
||||||
|
|
||||||
|
templ MovieResults(movies *tmdb.ResultMovies, image *tmdb.Image) {
|
||||||
|
for _, movie := range movies.Results {
|
||||||
|
<a
|
||||||
|
href={ templ.SafeURL(fmt.Sprintf("/movie/%v", movie.ID)) }
|
||||||
|
class="bg-surface0 p-4 rounded-lg shadow-lg flex
|
||||||
|
items-start space-x-4 cursor-pointer transition-all
|
||||||
|
hover:outline hover:outline-green hover:shadow-xl hover:scale-105"
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
src={ movie.GetPoster(image, "w92") }
|
||||||
|
alt="Movie Poster"
|
||||||
|
class="rounded-lg object-cover"
|
||||||
|
width="96"
|
||||||
|
height="144"
|
||||||
|
onerror="this.onerror=null; setFallbackColor(this);"
|
||||||
|
/>
|
||||||
|
<script>
|
||||||
|
function setFallbackColor(img) {
|
||||||
|
const baseColor = getComputedStyle(document.documentElement)
|
||||||
|
.getPropertyValue("--base")
|
||||||
|
.trim();
|
||||||
|
img.src = `data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='96' height='144'%3E%3Crect width='100%' height='100%' fill='${baseColor}'/%3E%3C/svg%3E`;
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
<div>
|
||||||
|
<h3 class="text-xl font-semibold">{ movie.Title } { movie.ReleaseYear() }</h3>
|
||||||
|
<p class="text-subtext0">
|
||||||
|
Released:
|
||||||
|
<span class="font-medium">{ movie.ReleaseDate }</span>
|
||||||
|
</p>
|
||||||
|
<p class="text-subtext0">
|
||||||
|
Original Title:
|
||||||
|
<span class="font-medium">{ movie.OriginalTitle }</span>
|
||||||
|
</p>
|
||||||
|
<p class="text-subtext0">{ movie.Overview }</p>
|
||||||
|
</div>
|
||||||
|
</a>
|
||||||
|
}
|
||||||
|
}
|
||||||
109
internal/view/layout/global.templ
Normal file
109
internal/view/layout/global.templ
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package layout
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/component/nav"
|
||||||
|
import "projectreshoot/internal/view/component/footer"
|
||||||
|
import "projectreshoot/internal/view/component/popup"
|
||||||
|
|
||||||
|
// Global page layout. Includes HTML document settings, header tags
|
||||||
|
// navbar and footer
|
||||||
|
templ Global(title string) {
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html
|
||||||
|
lang="en"
|
||||||
|
x-data="{
|
||||||
|
theme: localStorage.getItem('theme')
|
||||||
|
|| 'system'}"
|
||||||
|
x-init="$watch('theme', (val) => localStorage.setItem('theme', val))"
|
||||||
|
x-bind:class="{'dark': theme === 'dark' || (theme === 'system' &&
|
||||||
|
window.matchMedia('(prefers-color-scheme: dark)').matches)}"
|
||||||
|
>
|
||||||
|
<head>
|
||||||
|
<script>
|
||||||
|
(function () {
|
||||||
|
let theme = localStorage.getItem("theme") || "system";
|
||||||
|
if (theme === "system") {
|
||||||
|
theme = window.matchMedia("(prefers-color-scheme: dark)").matches
|
||||||
|
? "dark"
|
||||||
|
: "light";
|
||||||
|
}
|
||||||
|
if (theme === "dark") {
|
||||||
|
document.documentElement.classList.add("dark");
|
||||||
|
} else {
|
||||||
|
document.documentElement.classList.remove("dark");
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
</script>
|
||||||
|
<meta charset="UTF-8"/>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||||
|
<title>{ title }</title>
|
||||||
|
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/>
|
||||||
|
<link href="/static/css/output.css" rel="stylesheet"/>
|
||||||
|
<script src="https://unpkg.com/htmx.org@2.0.4" integrity="sha384-HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+" crossorigin="anonymous"></script>
|
||||||
|
<script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/persist@3.x.x/dist/cdn.min.js"></script>
|
||||||
|
<script src="https://unpkg.com/alpinejs" defer></script>
|
||||||
|
<script>
|
||||||
|
// uncomment this line to enable logging of htmx events
|
||||||
|
// htmx.logAll();
|
||||||
|
</script>
|
||||||
|
<script>
|
||||||
|
const bodyData = {
|
||||||
|
showError500: false,
|
||||||
|
showError503: false,
|
||||||
|
showConfirmPasswordModal: false,
|
||||||
|
handleHtmxBeforeOnLoad(event) {
|
||||||
|
const requestPath = event.detail.pathInfo.requestPath;
|
||||||
|
if (requestPath === "/reauthenticate") {
|
||||||
|
// handle password incorrect on refresh attempt
|
||||||
|
if (event.detail.xhr.status === 445) {
|
||||||
|
event.detail.shouldSwap = true;
|
||||||
|
event.detail.isError = false;
|
||||||
|
} else if (event.detail.xhr.status === 200) {
|
||||||
|
this.showConfirmPasswordModal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// handle errors from the server on HTMX requests
|
||||||
|
handleHtmxError(event) {
|
||||||
|
const errorCode = event.detail.errorInfo.error;
|
||||||
|
|
||||||
|
// internal server error
|
||||||
|
if (errorCode.includes("Code 500")) {
|
||||||
|
this.showError500 = true;
|
||||||
|
setTimeout(() => (this.showError500 = false), 6000);
|
||||||
|
}
|
||||||
|
// service not available error
|
||||||
|
if (errorCode.includes("Code 503")) {
|
||||||
|
this.showError503 = true;
|
||||||
|
setTimeout(() => (this.showError503 = false), 6000);
|
||||||
|
}
|
||||||
|
|
||||||
|
// user is authorized but needs to refresh their login
|
||||||
|
if (errorCode.includes("Code 444")) {
|
||||||
|
this.showConfirmPasswordModal = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body
|
||||||
|
class="bg-base text-text ubuntu-mono-regular overflow-x-hidden"
|
||||||
|
x-data="bodyData"
|
||||||
|
x-on:htmx:error="handleHtmxError($event)"
|
||||||
|
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
|
||||||
|
>
|
||||||
|
@popup.Error500Popup()
|
||||||
|
@popup.Error503Popup()
|
||||||
|
@popup.ConfirmPasswordModal()
|
||||||
|
<div
|
||||||
|
id="main-content"
|
||||||
|
class="flex flex-col h-screen justify-between"
|
||||||
|
>
|
||||||
|
@nav.Navbar()
|
||||||
|
<div id="page-content" class="mb-auto md:px-5 md:pt-5">
|
||||||
|
{ children... }
|
||||||
|
</div>
|
||||||
|
@footer.Footer()
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
}
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
// Returns the about page content
|
// Returns the about page content
|
||||||
templ About() {
|
templ About() {
|
||||||
@layout.Global() {
|
@layout.Global("About") {
|
||||||
<div class="text-center max-w-150 m-auto">
|
<div class="text-center max-w-150 m-auto">
|
||||||
<div class="text-4xl mt-8">About</div>
|
<div class="text-4xl mt-8">About</div>
|
||||||
<div class="text-xl font-bold mt-4">What is Project Reshoot?</div>
|
<div class="text-xl font-bold mt-4">What is Project Reshoot?</div>
|
||||||
10
internal/view/page/account.templ
Normal file
10
internal/view/page/account.templ
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
import "projectreshoot/internal/view/component/account"
|
||||||
|
|
||||||
|
templ Account(subpage string) {
|
||||||
|
@layout.Global("Account - " + subpage) {
|
||||||
|
@account.AccountContainer(subpage)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "strconv"
|
import "strconv"
|
||||||
|
|
||||||
// Page template for Error pages. Error code should be a HTTP status code as
|
// Page template for Error pages. Error code should be a HTTP status code as
|
||||||
// a string, and err should be the corresponding response title.
|
// a string, and err should be the corresponding response title.
|
||||||
// Message is a custom error message displayed below the code and error.
|
// Message is a custom error message displayed below the code and error.
|
||||||
templ Error(code int, err string, message string) {
|
templ Error(code int, err string, message string) {
|
||||||
@layout.Global() {
|
@layout.Global(err) {
|
||||||
<div
|
<div
|
||||||
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
||||||
place-content-center bg-base px-4"
|
place-content-center bg-base px-4"
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
// Page content for the index page
|
// Page content for the index page
|
||||||
templ Index() {
|
templ Index() {
|
||||||
@layout.Global() {
|
@layout.Global("Project Reshoot") {
|
||||||
<div class="text-center mt-24">
|
<div class="text-center mt-24">
|
||||||
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
|
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
|
||||||
<div>A better way to discover and rate films</div>
|
<div>A better way to discover and rate films</div>
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
// Returns the login page
|
// Returns the login page
|
||||||
templ Login() {
|
templ Login() {
|
||||||
@layout.Global() {
|
@layout.Global("Login") {
|
||||||
<div class="max-w-100 mx-auto px-2">
|
<div class="max-w-100 mx-auto px-2">
|
||||||
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
||||||
<div class="p-4 sm:p-7">
|
<div class="p-4 sm:p-7">
|
||||||
90
internal/view/page/movie.templ
Normal file
90
internal/view/page/movie.templ
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "git.haelnorr.com/h/golib/tmdb"
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
|
templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) {
|
||||||
|
@layout.Global(movie.Title) {
|
||||||
|
<div class="md:bg-surface0 md:p-2 md:rounded-lg transition-all">
|
||||||
|
<div id="billedcrew" class="hidden">
|
||||||
|
for _, billedcrew := range credits.BilledCrew() {
|
||||||
|
<span class="flex flex-col text-left w-[130px] md:w-[180px]">
|
||||||
|
<span class="font-bold">{ billedcrew.Name }</span>
|
||||||
|
<span class="text-subtext1">{ billedcrew.FRoles() }</span>
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<div class="flex items-start">
|
||||||
|
<div class="w-[154px] md:w-[300px] flex-col">
|
||||||
|
<img
|
||||||
|
class="object-cover aspect-2/3 w-[154px] md:w-[300px]
|
||||||
|
transition-all md:rounded-md shadow-black shadow-2xl"
|
||||||
|
src={ movie.GetPoster(image, "w300") }
|
||||||
|
alt="Poster"
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
id="billedcrew-sm"
|
||||||
|
class="text-sm md:text-lg text-subtext1 flex gap-6
|
||||||
|
mt-5 flex-wrap justify-around flex-col px-5 md:hidden"
|
||||||
|
></div>
|
||||||
|
<script>
|
||||||
|
function moveBilledCrew() {
|
||||||
|
const billedCrewMd = document.getElementById('billedcrew-md');
|
||||||
|
const billedCrewSm = document.getElementById('billedcrew-sm');
|
||||||
|
const billedCrew = document.getElementById('billedcrew');
|
||||||
|
|
||||||
|
if (window.innerWidth < 768) {
|
||||||
|
billedCrewSm.innerHTML = billedCrew.innerHTML;
|
||||||
|
billedCrewMd.innerHTML = "";
|
||||||
|
} else {
|
||||||
|
billedCrewMd.innerHTML = billedCrew.innerHTML;
|
||||||
|
billedCrewSm.innerHTML = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('load', moveBilledCrew);
|
||||||
|
|
||||||
|
const resizeObs = new ResizeObserver(() => {
|
||||||
|
moveBilledCrew();
|
||||||
|
});
|
||||||
|
resizeObs.observe(document.body);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
<div class="flex flex-col flex-1 text-center px-4">
|
||||||
|
<span class="text-xl md:text-3xl font-semibold">
|
||||||
|
{ movie.Title }
|
||||||
|
</span>
|
||||||
|
<span class="text-sm md:text-lg text-subtext1">
|
||||||
|
{ movie.FGenres() }
|
||||||
|
• { movie.FRuntime() }
|
||||||
|
• { movie.ReleaseYear() }
|
||||||
|
</span>
|
||||||
|
<div class="flex justify-center gap-2 mt-2">
|
||||||
|
<div
|
||||||
|
class="w-20 h-20 md:w-30 md:h-30 bg-overlay2
|
||||||
|
transition-all rounded-sm"
|
||||||
|
></div>
|
||||||
|
<div
|
||||||
|
class="w-20 h-20 md:w-30 md:h-30 bg-overlay2
|
||||||
|
transition-all rounded-sm"
|
||||||
|
></div>
|
||||||
|
</div>
|
||||||
|
<div class="flex flex-col mt-4">
|
||||||
|
<span class="text-sm md:text-lg text-overlay2 italic">
|
||||||
|
{ movie.Tagline }
|
||||||
|
</span>
|
||||||
|
<div
|
||||||
|
id="billedcrew-md"
|
||||||
|
class="hidden text-sm md:text-lg text-subtext1 md:flex gap-6
|
||||||
|
mt-5 flex-wrap justify-around"
|
||||||
|
></div>
|
||||||
|
<span class="text-lg mt-5 font-semibold">Overview</span>
|
||||||
|
<span class="text-sm md:text-lg text-subtext1">
|
||||||
|
{ movie.Overview }
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
31
internal/view/page/movie_search.templ
Normal file
31
internal/view/page/movie_search.templ
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
|
templ Movies() {
|
||||||
|
@layout.Global("Search movies") {
|
||||||
|
<div class="max-w-4xl mx-auto md:mt-0 mt-2 px-2 md:px-0">
|
||||||
|
<form hx-post="/search-movies" hx-target="#search-movies-results">
|
||||||
|
<div
|
||||||
|
class="max-w-100 flex items-center space-x-2 mb-2"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
id="search"
|
||||||
|
name="search"
|
||||||
|
type="text"
|
||||||
|
placeholder="Search movies..."
|
||||||
|
class="grow p-2 border rounded-lg
|
||||||
|
bg-mantle border-surface2 shadow-sm
|
||||||
|
focus:outline-none focus:ring-2 focus:ring-blue"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
class="py-2 px-4 bg-green text-mantle rounded-lg transition
|
||||||
|
hover:cursor-pointer hover:bg-green/75"
|
||||||
|
>Search</button>
|
||||||
|
</div>
|
||||||
|
<div id="search-movies-results" class="space-y-4 pt-4"></div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
13
internal/view/page/profile.templ
Normal file
13
internal/view/page/profile.templ
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
|
templ Profile() {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
|
@layout.Global("Profile - " + user.Username) {
|
||||||
|
<div class="">
|
||||||
|
Hello, { user.Username }
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
// Returns the login page
|
// Returns the login page
|
||||||
templ Register() {
|
templ Register() {
|
||||||
@layout.Global() {
|
@layout.Global("Register") {
|
||||||
<div class="max-w-100 mx-auto px-2">
|
<div class="max-w-100 mx-auto px-2">
|
||||||
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
||||||
<div class="p-4 sm:p-7">
|
<div class="p-4 sm:p-7">
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Generates an access token for the provided user
|
|
||||||
func GenerateAccessToken(
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
fresh bool,
|
|
||||||
rememberMe bool,
|
|
||||||
) (tokenStr string, exp int64, err error) {
|
|
||||||
issuedAt := time.Now().Unix()
|
|
||||||
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
|
|
||||||
var freshExpiresAt int64
|
|
||||||
if fresh {
|
|
||||||
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
|
|
||||||
} else {
|
|
||||||
freshExpiresAt = issuedAt
|
|
||||||
}
|
|
||||||
var ttl string
|
|
||||||
if rememberMe {
|
|
||||||
ttl = "exp"
|
|
||||||
} else {
|
|
||||||
ttl = "session"
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
|
||||||
jwt.MapClaims{
|
|
||||||
"iss": config.TrustedHost,
|
|
||||||
"scope": "access",
|
|
||||||
"ttl": ttl,
|
|
||||||
"jti": uuid.New(),
|
|
||||||
"iat": issuedAt,
|
|
||||||
"exp": expiresAt,
|
|
||||||
"fresh": freshExpiresAt,
|
|
||||||
"sub": user.ID,
|
|
||||||
})
|
|
||||||
|
|
||||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
|
||||||
}
|
|
||||||
return signedToken, expiresAt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generates a refresh token for the provided user
|
|
||||||
func GenerateRefreshToken(
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
rememberMe bool,
|
|
||||||
) (tokenStr string, exp int64, err error) {
|
|
||||||
issuedAt := time.Now().Unix()
|
|
||||||
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
|
|
||||||
var ttl string
|
|
||||||
if rememberMe {
|
|
||||||
ttl = "exp"
|
|
||||||
} else {
|
|
||||||
ttl = "session"
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
|
||||||
jwt.MapClaims{
|
|
||||||
"iss": config.TrustedHost,
|
|
||||||
"scope": "refresh",
|
|
||||||
"ttl": ttl,
|
|
||||||
"jti": uuid.New(),
|
|
||||||
"iat": issuedAt,
|
|
||||||
"exp": expiresAt,
|
|
||||||
"sub": user.ID,
|
|
||||||
})
|
|
||||||
|
|
||||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
|
||||||
}
|
|
||||||
return signedToken, expiresAt, nil
|
|
||||||
}
|
|
||||||
268
jwt/parse.go
268
jwt/parse.go
@@ -1,268 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Parse an access token and return a struct with all the claims. Does validation on
|
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
|
||||||
// has the correct scope.
|
|
||||||
func ParseAccessToken(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
tokenString string,
|
|
||||||
) (*AccessToken, error) {
|
|
||||||
if tokenString == "" {
|
|
||||||
return nil, errors.New("Access token string not provided")
|
|
||||||
}
|
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "parseToken")
|
|
||||||
}
|
|
||||||
expiry, err := checkTokenExpired(claims["exp"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
|
||||||
}
|
|
||||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
|
||||||
}
|
|
||||||
ttl, err := getTokenTTL(claims["ttl"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenTTL")
|
|
||||||
}
|
|
||||||
scope, err := getTokenScope(claims["scope"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenScope")
|
|
||||||
}
|
|
||||||
if scope != "access" {
|
|
||||||
return nil, errors.New("Token is not an Access token")
|
|
||||||
}
|
|
||||||
issuedAt, err := getIssuedTime(claims["iat"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getIssuedTime")
|
|
||||||
}
|
|
||||||
subject, err := getTokenSubject(claims["sub"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenSubject")
|
|
||||||
}
|
|
||||||
fresh, err := getFreshTime(claims["fresh"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getFreshTime")
|
|
||||||
}
|
|
||||||
jti, err := getTokenJTI(claims["jti"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenJTI")
|
|
||||||
}
|
|
||||||
|
|
||||||
token := &AccessToken{
|
|
||||||
ISS: issuer,
|
|
||||||
TTL: ttl,
|
|
||||||
EXP: expiry,
|
|
||||||
IAT: issuedAt,
|
|
||||||
SUB: subject,
|
|
||||||
Fresh: fresh,
|
|
||||||
JTI: jti,
|
|
||||||
Scope: scope,
|
|
||||||
}
|
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
|
||||||
}
|
|
||||||
if !valid {
|
|
||||||
return nil, errors.New("Token has been revoked")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
|
||||||
// has the correct scope.
|
|
||||||
func ParseRefreshToken(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
tokenString string,
|
|
||||||
) (*RefreshToken, error) {
|
|
||||||
if tokenString == "" {
|
|
||||||
return nil, errors.New("Refresh token string not provided")
|
|
||||||
}
|
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "parseToken")
|
|
||||||
}
|
|
||||||
expiry, err := checkTokenExpired(claims["exp"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
|
||||||
}
|
|
||||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
|
||||||
}
|
|
||||||
ttl, err := getTokenTTL(claims["ttl"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenTTL")
|
|
||||||
}
|
|
||||||
scope, err := getTokenScope(claims["scope"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenScope")
|
|
||||||
}
|
|
||||||
if scope != "refresh" {
|
|
||||||
return nil, errors.New("Token is not an Refresh token")
|
|
||||||
}
|
|
||||||
issuedAt, err := getIssuedTime(claims["iat"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getIssuedTime")
|
|
||||||
}
|
|
||||||
subject, err := getTokenSubject(claims["sub"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenSubject")
|
|
||||||
}
|
|
||||||
jti, err := getTokenJTI(claims["jti"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenJTI")
|
|
||||||
}
|
|
||||||
|
|
||||||
token := &RefreshToken{
|
|
||||||
ISS: issuer,
|
|
||||||
TTL: ttl,
|
|
||||||
EXP: expiry,
|
|
||||||
IAT: issuedAt,
|
|
||||||
SUB: subject,
|
|
||||||
JTI: jti,
|
|
||||||
Scope: scope,
|
|
||||||
}
|
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
|
||||||
}
|
|
||||||
if !valid {
|
|
||||||
return nil, errors.New("Token has been revoked")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse a token, validating its signing sigature and returning the claims
|
|
||||||
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(secretKey), nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.Parse")
|
|
||||||
}
|
|
||||||
// Token decoded, parse the claims
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("Failed to parse claims")
|
|
||||||
}
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token is expired. Returns the expiry if not expired
|
|
||||||
func checkTokenExpired(expiry interface{}) (int64, error) {
|
|
||||||
// Coerce the expiry to a float64 to avoid scientific notation
|
|
||||||
expFloat, ok := expiry.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'exp' claim")
|
|
||||||
}
|
|
||||||
// Convert to the int64 time we expect :)
|
|
||||||
expiryTime := int64(expFloat)
|
|
||||||
|
|
||||||
// Check if its expired
|
|
||||||
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
|
||||||
if isExpired {
|
|
||||||
return 0, errors.New("Token has expired")
|
|
||||||
}
|
|
||||||
return expiryTime, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token has a valid issuer. Returns the issuer if valid
|
|
||||||
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
|
||||||
issuerVal, ok := issuer.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'iss' claim")
|
|
||||||
}
|
|
||||||
if issuer != trustedHost {
|
|
||||||
return "", errors.New("Issuer does not matched trusted host")
|
|
||||||
}
|
|
||||||
return issuerVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the scope matches the expected scope. Returns scope if true
|
|
||||||
func getTokenScope(scope interface{}) (string, error) {
|
|
||||||
scopeStr, ok := scope.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'scope' claim")
|
|
||||||
}
|
|
||||||
return scopeStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the TTL of the token, either "session" or "exp"
|
|
||||||
func getTokenTTL(ttl interface{}) (string, error) {
|
|
||||||
ttlStr, ok := ttl.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'ttl' claim")
|
|
||||||
}
|
|
||||||
if ttlStr != "exp" && ttlStr != "session" {
|
|
||||||
return "", errors.New("TTL value is not recognised")
|
|
||||||
}
|
|
||||||
return ttlStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the time the token was issued at
|
|
||||||
func getIssuedTime(issued interface{}) (int64, error) {
|
|
||||||
// Same float64 -> int64 trick as expiry
|
|
||||||
issuedFloat, ok := issued.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'iat' claim")
|
|
||||||
}
|
|
||||||
issuedAt := int64(issuedFloat)
|
|
||||||
return issuedAt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the freshness expiry timestamp
|
|
||||||
func getFreshTime(fresh interface{}) (int64, error) {
|
|
||||||
freshUntil, ok := fresh.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'fresh' claim")
|
|
||||||
}
|
|
||||||
return int64(freshUntil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the subject of the token
|
|
||||||
func getTokenSubject(sub interface{}) (int, error) {
|
|
||||||
subject, ok := sub.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'sub' claim")
|
|
||||||
}
|
|
||||||
return int(subject), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the JTI of the token
|
|
||||||
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
|
||||||
jtiStr, ok := jti.(string)
|
|
||||||
if !ok {
|
|
||||||
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
|
||||||
}
|
|
||||||
jtiUUID, err := uuid.Parse(jtiStr)
|
|
||||||
if err != nil {
|
|
||||||
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
|
||||||
}
|
|
||||||
return jtiUUID, nil
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Revoke a token by adding it to the database
|
|
||||||
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
|
|
||||||
jti := t.GetJTI()
|
|
||||||
exp := t.GetEXP()
|
|
||||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
|
||||||
_, err := tx.Exec(ctx, query, jti, exp)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token has been revoked. Returns true if not revoked.
|
|
||||||
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
|
|
||||||
jti := t.GetJTI()
|
|
||||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
|
||||||
rows, err := tx.Query(ctx, query, jti)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.Query")
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
revoked := rows.Next()
|
|
||||||
return !revoked, nil
|
|
||||||
}
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Token interface {
|
|
||||||
GetJTI() uuid.UUID
|
|
||||||
GetEXP() int64
|
|
||||||
GetScope() string
|
|
||||||
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Access token
|
|
||||||
type AccessToken struct {
|
|
||||||
ISS string // Issuer, generally TrustedHost
|
|
||||||
IAT int64 // Time issued at
|
|
||||||
EXP int64 // Time expiring at
|
|
||||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
|
||||||
SUB int // Subject (user) ID
|
|
||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
|
||||||
Fresh int64 // Time freshness expiring at
|
|
||||||
Scope string // Should be "access"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh token
|
|
||||||
type RefreshToken struct {
|
|
||||||
ISS string // Issuer, generally TrustedHost
|
|
||||||
IAT int64 // Time issued at
|
|
||||||
EXP int64 // Time expiring at
|
|
||||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
|
||||||
SUB int // Subject (user) ID
|
|
||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
|
||||||
Scope string // Should be "refresh"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
|
||||||
user, err := db.GetUserFromID(ctx, tx, a.SUB)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
|
||||||
user, err := db.GetUserFromID(ctx, tx, r.SUB)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a AccessToken) GetJTI() uuid.UUID {
|
|
||||||
return a.JTI
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetJTI() uuid.UUID {
|
|
||||||
return r.JTI
|
|
||||||
}
|
|
||||||
func (a AccessToken) GetEXP() int64 {
|
|
||||||
return a.EXP
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetEXP() int64 {
|
|
||||||
return r.EXP
|
|
||||||
}
|
|
||||||
func (a AccessToken) GetScope() string {
|
|
||||||
return a.Scope
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetScope() string {
|
|
||||||
return r.Scope
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/pkgerrors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Takes a log level as string and converts it to a zerolog.Level interface.
|
|
||||||
// If the string is not a valid input it will return zerolog.InfoLevel
|
|
||||||
func GetLogLevel(level string) zerolog.Level {
|
|
||||||
levels := map[string]zerolog.Level{
|
|
||||||
"trace": zerolog.TraceLevel,
|
|
||||||
"debug": zerolog.DebugLevel,
|
|
||||||
"info": zerolog.InfoLevel,
|
|
||||||
"warn": zerolog.WarnLevel,
|
|
||||||
"error": zerolog.ErrorLevel,
|
|
||||||
"fatal": zerolog.FatalLevel,
|
|
||||||
"panic": zerolog.PanicLevel,
|
|
||||||
}
|
|
||||||
logLevel, valid := levels[level]
|
|
||||||
if !valid {
|
|
||||||
return zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
return logLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a pointer to a new log file with the specified path.
|
|
||||||
// Remember to call file.Close() when finished writing to the log file
|
|
||||||
func GetLogFile(path string) (*os.File, error) {
|
|
||||||
logPath := filepath.Join(path, "server.log")
|
|
||||||
file, err := os.OpenFile(
|
|
||||||
logPath,
|
|
||||||
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
|
|
||||||
0663,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "os.OpenFile")
|
|
||||||
}
|
|
||||||
return file, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get a pointer to a new zerolog.Logger with the specified level and output
|
|
||||||
// Can provide a file, writer or both. Must provide at least one of the two
|
|
||||||
func GetLogger(
|
|
||||||
logLevel zerolog.Level,
|
|
||||||
w io.Writer,
|
|
||||||
logFile *os.File,
|
|
||||||
logDir string,
|
|
||||||
) (*zerolog.Logger, error) {
|
|
||||||
if w == nil && logFile == nil {
|
|
||||||
return nil, errors.New("No Writer provided for log output.")
|
|
||||||
}
|
|
||||||
|
|
||||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
|
||||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
|
||||||
|
|
||||||
var consoleWriter zerolog.ConsoleWriter
|
|
||||||
if w != nil {
|
|
||||||
consoleWriter = zerolog.ConsoleWriter{Out: w}
|
|
||||||
}
|
|
||||||
|
|
||||||
var output io.Writer
|
|
||||||
if logFile != nil {
|
|
||||||
if w != nil {
|
|
||||||
output = zerolog.MultiLevelWriter(logFile, consoleWriter)
|
|
||||||
} else {
|
|
||||||
output = logFile
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
output = consoleWriter
|
|
||||||
}
|
|
||||||
logger := zerolog.New(output).
|
|
||||||
With().
|
|
||||||
Timestamp().
|
|
||||||
Logger().
|
|
||||||
Level(logLevel)
|
|
||||||
|
|
||||||
return &logger, nil
|
|
||||||
}
|
|
||||||
233
main.go
233
main.go
@@ -1,233 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"embed"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/fs"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/logging"
|
|
||||||
"projectreshoot/server"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed static/*
|
|
||||||
var embeddedStatic embed.FS
|
|
||||||
|
|
||||||
// Gets the static files
|
|
||||||
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
|
|
||||||
if _, err := os.Stat("static"); err == nil {
|
|
||||||
// Use actual filesystem in development
|
|
||||||
logger.Debug().Msg("Using filesystem for static files")
|
|
||||||
return http.Dir("static"), nil
|
|
||||||
} else {
|
|
||||||
// Use embedded filesystem in production
|
|
||||||
logger.Debug().Msg("Using embedded static files")
|
|
||||||
subFS, err := fs.Sub(embeddedStatic, "static")
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "fs.Sub")
|
|
||||||
}
|
|
||||||
return http.FS(subFS), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var maint uint32 // atomic: 1 if in maintenance mode
|
|
||||||
|
|
||||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
|
||||||
func handleMaintSignals(
|
|
||||||
conn *db.SafeConn,
|
|
||||||
srv *http.Server,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
) {
|
|
||||||
logger.Debug().Msg("Starting signal listener")
|
|
||||||
ch := make(chan os.Signal, 1)
|
|
||||||
srv.RegisterOnShutdown(func() {
|
|
||||||
logger.Debug().Msg("Shutting down signal listener")
|
|
||||||
close(ch)
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for sig := range ch {
|
|
||||||
switch sig {
|
|
||||||
case syscall.SIGUSR1:
|
|
||||||
if atomic.LoadUint32(&maint) != 1 {
|
|
||||||
atomic.StoreUint32(&maint, 1)
|
|
||||||
logger.Info().Msg("Signal received: Starting maintenance")
|
|
||||||
logger.Info().Msg("Attempting to acquire database lock")
|
|
||||||
conn.Pause(config.DBLockTimeout * time.Second)
|
|
||||||
}
|
|
||||||
case syscall.SIGUSR2:
|
|
||||||
if atomic.LoadUint32(&maint) != 0 {
|
|
||||||
logger.Info().Msg("Signal received: Maintenance over")
|
|
||||||
logger.Info().Msg("Releasing database lock")
|
|
||||||
conn.Resume()
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initializes and runs the server
|
|
||||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
config, err := config.GetConfig(args)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "server.GetConfig")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the version of the database required
|
|
||||||
if args["dbver"] == "true" {
|
|
||||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var logfile *os.File = nil
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
|
||||||
logfile, err = logging.GetLogFile(config.LogDir)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "logging.GetLogFile")
|
|
||||||
}
|
|
||||||
defer logfile.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
var consoleWriter io.Writer
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "console" {
|
|
||||||
consoleWriter = w
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := logging.GetLogger(
|
|
||||||
config.LogLevel,
|
|
||||||
consoleWriter,
|
|
||||||
logfile,
|
|
||||||
config.LogDir,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "logging.GetLogger")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug().Msg("Config loaded and logger started")
|
|
||||||
logger.Debug().Msg("Connecting to database")
|
|
||||||
var conn *db.SafeConn
|
|
||||||
if args["test"] == "true" {
|
|
||||||
logger.Debug().Msg("Server in test mode, using test database")
|
|
||||||
ver, err := strconv.ParseInt(config.DBName, 10, 0)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "strconv.ParseInt")
|
|
||||||
}
|
|
||||||
testconn, err := tests.SetupTestDB(ver)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tests.SetupTestDB")
|
|
||||||
}
|
|
||||||
conn = db.MakeSafe(testconn, logger)
|
|
||||||
} else {
|
|
||||||
conn, err = db.ConnectToDatabase(config.DBName, logger)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
logger.Debug().Msg("Getting static files")
|
|
||||||
staticFS, err := getStaticFiles(logger)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "getStaticFiles")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug().Msg("Setting up HTTP server")
|
|
||||||
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
|
||||||
Handler: srv,
|
|
||||||
ReadHeaderTimeout: config.ReadHeaderTimeout * time.Second,
|
|
||||||
WriteTimeout: config.WriteTimeout * time.Second,
|
|
||||||
IdleTimeout: config.IdleTimeout * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs function for testing in dev if --test flag true
|
|
||||||
if args["tester"] == "true" {
|
|
||||||
logger.Debug().Msg("Running tester function")
|
|
||||||
test(config, logger, conn, httpServer)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setups a channel to listen for os.Signal
|
|
||||||
handleMaintSignals(conn, httpServer, logger, config)
|
|
||||||
|
|
||||||
// Runs the http server
|
|
||||||
logger.Debug().Msg("Starting up the HTTP server")
|
|
||||||
go func() {
|
|
||||||
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
logger.Error().Err(err).Msg("Error listening and serving")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Handles graceful shutdown
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
<-ctx.Done()
|
|
||||||
shutdownCtx := context.Background()
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
||||||
logger.Error().Err(err).Msg("Error shutting down server")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
wg.Wait()
|
|
||||||
logger.Info().Msg("Shutting down")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start of runtime. Parse commandline arguments & flags, Initializes context
|
|
||||||
// and starts the server
|
|
||||||
func main() {
|
|
||||||
// Parse commandline args
|
|
||||||
host := flag.String("host", "", "Override host to listen on")
|
|
||||||
port := flag.String("port", "", "Override port to listen on")
|
|
||||||
test := flag.Bool("test", false, "Run server in test mode")
|
|
||||||
tester := flag.Bool("tester", false, "Run tester function instead of main program")
|
|
||||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
|
||||||
loglevel := flag.String("loglevel", "", "Set log level")
|
|
||||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
// Map the args for easy access
|
|
||||||
args := map[string]string{
|
|
||||||
"host": *host,
|
|
||||||
"port": *port,
|
|
||||||
"test": strconv.FormatBool(*test),
|
|
||||||
"tester": strconv.FormatBool(*tester),
|
|
||||||
"dbver": strconv.FormatBool(*dbver),
|
|
||||||
"loglevel": *loglevel,
|
|
||||||
"logoutput": *logoutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
ctx := context.Background()
|
|
||||||
if err := run(ctx, os.Stdout, args); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
144
main_test.go
144
main_test.go
@@ -1,144 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_main(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
t.Cleanup(cancel)
|
|
||||||
args := map[string]string{"test": "true"}
|
|
||||||
var stdout bytes.Buffer
|
|
||||||
os.Setenv("SECRET_KEY", ".")
|
|
||||||
os.Setenv("HOST", "127.0.0.1")
|
|
||||||
os.Setenv("PORT", "3232")
|
|
||||||
runSrvErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
if err := run(ctx, &stdout, args); err != nil {
|
|
||||||
runSrvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
|
|
||||||
if err != nil {
|
|
||||||
runSrvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
runSrvErr <- nil
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case err := <-runSrvErr:
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error starting test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Log("Test server started")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
|
|
||||||
done := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
expected := "Global database lock acquired"
|
|
||||||
for {
|
|
||||||
if strings.Contains(stdout.String(), expected) {
|
|
||||||
done <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proc, err := os.FindProcess(os.Getpid())
|
|
||||||
require.NoError(t, err)
|
|
||||||
proc.Signal(syscall.SIGUSR1)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Log("found")
|
|
||||||
case <-time.After(250 * time.Millisecond):
|
|
||||||
t.Errorf("Not found")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
|
|
||||||
done := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
expected := "Global database lock released"
|
|
||||||
for {
|
|
||||||
if strings.Contains(stdout.String(), expected) {
|
|
||||||
done <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proc, err := os.FindProcess(os.Getpid())
|
|
||||||
require.NoError(t, err)
|
|
||||||
proc.Signal(syscall.SIGUSR2)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Log("found")
|
|
||||||
case <-time.After(250 * time.Millisecond):
|
|
||||||
t.Errorf("Not found")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForReady(
|
|
||||||
ctx context.Context,
|
|
||||||
timeout time.Duration,
|
|
||||||
endpoint string,
|
|
||||||
) error {
|
|
||||||
client := http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequestWithContext(
|
|
||||||
ctx,
|
|
||||||
http.MethodGet,
|
|
||||||
endpoint,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Error making request: %s\n", err.Error())
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
fmt.Println("Endpoint is ready!")
|
|
||||||
resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
if time.Since(startTime) >= timeout {
|
|
||||||
return fmt.Errorf("timeout reached while waiting for endpoint")
|
|
||||||
}
|
|
||||||
// wait a little while between checks
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Attempt to use a valid refresh token to generate a new token pair
|
|
||||||
func refreshAuthTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
req *http.Request,
|
|
||||||
ref *jwt.RefreshToken,
|
|
||||||
) (*db.User, error) {
|
|
||||||
user, err := ref.GetUser(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "ref.GetUser")
|
|
||||||
}
|
|
||||||
|
|
||||||
rememberMe := map[string]bool{
|
|
||||||
"session": false,
|
|
||||||
"exp": true,
|
|
||||||
}[ref.TTL]
|
|
||||||
|
|
||||||
// Set fresh to true because new tokens coming from refresh request
|
|
||||||
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
|
||||||
}
|
|
||||||
// New tokens sent, revoke the used refresh token
|
|
||||||
err = jwt.RevokeToken(ctx, tx, ref)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
// Return the authorized user
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the cookies for token strings and attempt to authenticate them
|
|
||||||
func getAuthenticatedUser(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
) (*contexts.AuthenticatedUser, error) {
|
|
||||||
// Get token strings from cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
// Attempt to parse the access token
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
// Access token invalid, attempt to parse refresh token
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
// Refresh token valid, attempt to get a new token pair
|
|
||||||
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "refreshAuthTokens")
|
|
||||||
}
|
|
||||||
// New token pair sent, return the authorized user
|
|
||||||
authUser := contexts.AuthenticatedUser{
|
|
||||||
User: user,
|
|
||||||
Fresh: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
return &authUser, nil
|
|
||||||
}
|
|
||||||
// Access token valid
|
|
||||||
user, err := aT.GetUser(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "aT.GetUser")
|
|
||||||
}
|
|
||||||
authUser := contexts.AuthenticatedUser{
|
|
||||||
User: user,
|
|
||||||
Fresh: aT.Fresh,
|
|
||||||
}
|
|
||||||
return &authUser, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to authenticate the user and add their account details
|
|
||||||
// to the request context
|
|
||||||
func Authentication(
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
next http.Handler,
|
|
||||||
maint *uint32,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
|
||||||
r.URL.Path == "/static/favicon.ico" {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if atomic.LoadUint32(maint) == 1 {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// Failed to start transaction, skip auth
|
|
||||||
logger.Warn().Err(err).
|
|
||||||
Msg("Skipping Auth - unable to start a transaction")
|
|
||||||
handler.ErrorPage(http.StatusServiceUnavailable, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
// User auth failed, delete the cookies to avoid repeat requests
|
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
|
||||||
logger.Debug().
|
|
||||||
Str("remote_addr", r.RemoteAddr).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to authenticate user")
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
uctx := contexts.SetUser(r.Context(), user)
|
|
||||||
newReq := r.WithContext(uctx)
|
|
||||||
next.ServeHTTP(w, newReq)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthenticationMiddleware(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := db.MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user == nil {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
w.Write([]byte(strconv.Itoa(0)))
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte(strconv.Itoa(user.ID)))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
var maint uint32
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
// Add the middleware and create the server
|
|
||||||
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
|
|
||||||
require.NoError(t, err)
|
|
||||||
server := httptest.NewServer(authHandler)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
tokens := getTokens()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id int
|
|
||||||
accessToken string
|
|
||||||
refreshToken string
|
|
||||||
expectedCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid Access Token (Fresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessFresh"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid Access Token (Unfresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessUnfresh"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid Refresh Token (Triggers Refresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshValid"],
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Both tokens expired",
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Access token revoked",
|
|
||||||
accessToken: tokens["accessRevoked"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Refresh token revoked",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: tokens["refreshRevoked"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Tokens",
|
|
||||||
accessToken: tokens["invalid"],
|
|
||||||
refreshToken: tokens["invalid"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No Tokens",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
client := &http.Client{}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
|
|
||||||
|
|
||||||
// Add cookies if provided
|
|
||||||
if tt.accessToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
|
|
||||||
}
|
|
||||||
if tt.refreshToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectedCode, resp.StatusCode)
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, strconv.Itoa(tt.id), string(body))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the tokens to test with
|
|
||||||
func getTokens() map[string]string {
|
|
||||||
tokens := map[string]string{
|
|
||||||
"accessFresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzIyMTAsImZyZXNoIjo0ODk1NjcyMjEwLCJpYXQiOjE3Mzk2NzIyMTAsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImE4Njk2YWM4LTg3OWMtNDdkNC1iZWM2LTRlY2Y4MTRiZThiZiIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.6nAquDY0JBLPdaJ9q_sMpKj1ISG4Vt2U05J57aoPue8",
|
|
||||||
"accessUnfresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJmcmVzaCI6MTczOTY3NTY3MSwiaWF0IjoxNzM5Njc1NjcxLCJpc3MiOiIxMjcuMC4wLjEiLCJqdGkiOiJjOGNhZmFjNy0yODkzLTQzNzMtOTI4ZS03MGUwODJkYmM2MGIiLCJzY29wZSI6ImFjY2VzcyIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.plWQVFwHlhXUYI5utS7ny1JfXjJSFrigkq-PnTHD5VY",
|
|
||||||
"accessExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImZyZXNoIjoxNzM5NjcyMjQ4LCJpYXQiOjE3Mzk2NzIyNDgsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjgxYzA1YzBjLTJhOGItNGQ2MC04Yzc4LWY2ZTQxODYxZDFmNCIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.iI1f17kKTuFDEMEYltJRIwRYgYQ-_nF9Wsn0KR6x77Q",
|
|
||||||
"refreshValid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImlhdCI6MTczOTY3MTkyMiwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTUxMTY3ZWEtNDA3OS00ZTczLTkzZDQtNTgwZDMzODRjZDU4Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.tvtqQ8Z4WrYWHHb0MaEPdsU2FT2KLRE1zHOv3ipoFyc",
|
|
||||||
"refreshExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImlhdCI6MTczOTY3MjI0OCwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTg5YTc5MTYtZGEzYi00YmJhLWI3ZDMtOWI1N2ViNjRhMmU0Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.rH_fytC7Duxo598xacu820pQKF9ELbG8674h_bK_c4I",
|
|
||||||
"accessRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImZyZXNoIjoxNzM5NjcxOTIyLCJpYXQiOjE3Mzk2NzE5MjIsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjBhNmIzMzhlLTkzMGEtNDNmZS04ZjcwLTFhNmRhZWQyNTZmYSIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.mZLuCp9amcm2_CqYvbHPlk86nfiuy_Or8TlntUCw4Qs",
|
|
||||||
"refreshRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJpYXQiOjE3Mzk2NzU2NzEsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImI3ZmE1MWRjLTg1MzItNDJlMS04NzU2LTVkMjViZmIyMDAzYSIsInNjb3BlIjoicmVmcmVzaCIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.5Q9yDZN5FubfCWHclUUZEkJPOUHcOEpVpgcUK-ameHo",
|
|
||||||
"invalid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0ODUxNDA5ODQsImlhdCI6MTQ4NTEzNzM4NCwiaXNzIjoiYWNtZS5jb20iLCJzdWIiOiIyOWFjMGMxOC0wYjRhLTQyY2YtODJmYy0wM2Q1NzAzMThhMWQiLCJhcHBsaWNhdGlvbklkIjoiNzkxMDM3MzQtOTdhYi00ZDFhLWFmMzctZTAwNmQwNWQyOTUyIiwicm9sZXMiOltdfQ.Mp0Pcwsz5VECK11Kf2ZZNF_SMKu5CgBeLN9ZOP04kZo",
|
|
||||||
}
|
|
||||||
return tokens
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Gzip(next http.Handler, useGzip bool) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") ||
|
|
||||||
!useGzip {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Encoding", "gzip")
|
|
||||||
gz := gzip.NewWriter(w)
|
|
||||||
defer gz.Close()
|
|
||||||
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
|
||||||
next.ServeHTTP(gzw, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type gzipResponseWriter struct {
|
|
||||||
io.Writer
|
|
||||||
http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w gzipResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
return w.Writer.Write(b)
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wraps the http.ResponseWriter, adding a statusCode field
|
|
||||||
type wrappedWriter struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
statusCode int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extends WriteHeader to the ResponseWriter to add the status code
|
|
||||||
func (w *wrappedWriter) WriteHeader(statusCode int) {
|
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
w.statusCode = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
// Middleware to add logs to console with details of the request
|
|
||||||
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
|
||||||
r.URL.Path == "/static/favicon.ico" {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
start, err := contexts.GetStartTime(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
handler.ErrorPage(http.StatusInternalServerError, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wrapped := &wrappedWriter{
|
|
||||||
ResponseWriter: w,
|
|
||||||
statusCode: http.StatusOK,
|
|
||||||
}
|
|
||||||
next.ServeHTTP(wrapped, r)
|
|
||||||
logger.Info().
|
|
||||||
Int("status", wrapped.statusCode).
|
|
||||||
Str("method", r.Method).
|
|
||||||
Str("resource", r.URL.Path).
|
|
||||||
Dur("time_elapsed", time.Since(start)).
|
|
||||||
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
|
|
||||||
Msg("Served")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Checks if the user is set in the context and shows 401 page if not logged in
|
|
||||||
func LoginReq(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user == nil {
|
|
||||||
handler.ErrorPage(http.StatusUnauthorized, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the user is set in the context and redirects them to profile if
|
|
||||||
// they are logged in
|
|
||||||
func LogoutReq(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user != nil {
|
|
||||||
http.Redirect(w, r, "/profile", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user