Compare commits
26 Commits
master
...
28b7ba34f0
| Author | SHA1 | Date | |
|---|---|---|---|
| 28b7ba34f0 | |||
| 4a21ba3821 | |||
| 1bcdf0e813 | |||
| 6dd80ee7b6 | |||
| 8f6b4b0026 | |||
| 03095448d6 | |||
| 1d9af44d0a | |||
| 5c1089e0ce | |||
| e3d2eb1af8 | |||
| 9e12f946b3 | |||
| 141b541e98 | |||
| 540782e2d5 | |||
| 1d5c662bf0 | |||
| b6e0a977c0 | |||
| 3db77eca71 | |||
| 8fcec675e6 | |||
| aa47802f46 | |||
| 05849d028d | |||
| f7f610d7ef | |||
| e2d66fc26d | |||
| 8fa20e05c0 | |||
| a3e9ffb012 | |||
| 838d6264c9 | |||
| e794024786 | |||
| 725038009a | |||
| d8d2307859 |
@@ -5,7 +5,7 @@ tmp_dir = "tmp"
|
|||||||
[build]
|
[build]
|
||||||
args_bin = []
|
args_bin = []
|
||||||
bin = "./tmp/main"
|
bin = "./tmp/main"
|
||||||
cmd = "go build -o ./tmp/main ."
|
cmd = "go build -o ./tmp/main ./cmd/projectreshoot"
|
||||||
delay = 1000
|
delay = 1000
|
||||||
exclude_dir = []
|
exclude_dir = []
|
||||||
exclude_file = []
|
exclude_file = []
|
||||||
|
|||||||
4
.github/workflows/deploy_production.yaml
vendored
4
.github/workflows/deploy_production.yaml
vendored
@@ -53,10 +53,10 @@ jobs:
|
|||||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||||
scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 prmigrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 .bin/migrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||||
|
|||||||
4
.github/workflows/deploy_staging.yaml
vendored
4
.github/workflows/deploy_staging.yaml
vendored
@@ -53,10 +53,10 @@ jobs:
|
|||||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||||
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||||
|
|
||||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 prmigrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./bin/migrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||||
|
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,11 +1,9 @@
|
|||||||
.env
|
.env
|
||||||
query.sql
|
*.db*
|
||||||
*.db
|
|
||||||
.logs/
|
.logs/
|
||||||
server.log
|
server.log
|
||||||
|
bin/
|
||||||
tmp/
|
tmp/
|
||||||
prmigrate
|
|
||||||
projectreshoot
|
|
||||||
static/css/output.css
|
static/css/output.css
|
||||||
view/**/*_templ.go
|
internal/view/**/*_templ.go
|
||||||
view/**/*_templ.txt
|
internal/view/**/*_templ.txt
|
||||||
|
|||||||
28
Makefile
28
Makefile
@@ -5,33 +5,25 @@
|
|||||||
BINARY_NAME=projectreshoot
|
BINARY_NAME=projectreshoot
|
||||||
|
|
||||||
build:
|
build:
|
||||||
tailwindcss -i ./static/css/input.css -o ./static/css/output.css && \
|
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
|
||||||
go mod tidy && \
|
go mod tidy && \
|
||||||
templ generate && \
|
templ generate && \
|
||||||
go generate && \
|
go generate ./cmd/${BINARY_NAME} && \
|
||||||
go build -ldflags="-w -s" -o ${BINARY_NAME}${SUFFIX}
|
go build -ldflags="-w -s" -o ./bin/${BINARY_NAME}${SUFFIX} ./cmd/${BINARY_NAME}
|
||||||
|
|
||||||
|
run:
|
||||||
|
make build
|
||||||
|
./bin/${BINARY_NAME}${SUFFIX}
|
||||||
|
|
||||||
dev:
|
dev:
|
||||||
templ generate --watch &\
|
templ generate --watch &\
|
||||||
air &\
|
air &\
|
||||||
tailwindcss -i ./static/css/input.css -o ./static/css/output.css --watch
|
tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
|
||||||
|
|
||||||
tester:
|
|
||||||
go mod tidy && \
|
|
||||||
go run . --port 3232 --tester --loglevel trace
|
|
||||||
|
|
||||||
test:
|
|
||||||
go mod tidy && \
|
|
||||||
templ generate && \
|
|
||||||
go generate && \
|
|
||||||
go test .
|
|
||||||
go test ./db
|
|
||||||
go test ./middleware
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
go clean
|
go clean
|
||||||
|
|
||||||
migrate:
|
migrate:
|
||||||
go mod tidy && \
|
go mod tidy && \
|
||||||
go generate && \
|
go generate ./cmd/migrate && \
|
||||||
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate
|
go build -ldflags="-w -s" -o ./bin/migrate${SUFFIX} ./cmd/migrate
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ var migrationsFS embed.FS
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) != 4 {
|
if len(os.Args) != 4 {
|
||||||
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
|
fmt.Println("Usage: migrate <file_path> up-to|down-to <version>")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
52
cmd/projectreshoot/auth.go
Normal file
52
cmd/projectreshoot/auth.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/handler"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupAuth(
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
conn *sql.DB,
|
||||||
|
server *hws.Server,
|
||||||
|
ignoredPaths []string,
|
||||||
|
) (*hwsauth.Authenticator[*models.User], error) {
|
||||||
|
auth, err := hwsauth.NewAuthenticator(
|
||||||
|
models.GetUserFromID,
|
||||||
|
server,
|
||||||
|
conn,
|
||||||
|
logger,
|
||||||
|
handler.NewErrorPage,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.SSL = config.SSL
|
||||||
|
auth.AccessTokenExpiry = config.AccessTokenExpiry
|
||||||
|
auth.RefreshTokenExpiry = config.RefreshTokenExpiry
|
||||||
|
auth.TokenFreshTime = config.TokenFreshTime
|
||||||
|
auth.TrustedHost = config.TrustedHost
|
||||||
|
auth.SecretKey = config.SecretKey
|
||||||
|
auth.LandingPage = "/profile"
|
||||||
|
|
||||||
|
auth.IgnorePaths(ignoredPaths...)
|
||||||
|
|
||||||
|
err = auth.Initialise()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "auth.Initialise")
|
||||||
|
}
|
||||||
|
|
||||||
|
contexts.CurrentUser = auth.CurrentModel
|
||||||
|
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package db
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@@ -6,40 +6,35 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Returns a database connection handle for the DB
|
func setupDBConn(dbName string) (*sql.DB, error) {
|
||||||
func ConnectToDatabase(
|
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
|
||||||
dbName string,
|
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
|
||||||
logger *zerolog.Logger,
|
conn, err := sql.Open("sqlite3", file)
|
||||||
) (*SafeConn, error) {
|
|
||||||
file := fmt.Sprintf("file:%s.db", dbName)
|
|
||||||
db, err := sql.Open("sqlite", file)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "sql.Open")
|
return nil, errors.Wrap(err, "sql.Open")
|
||||||
}
|
}
|
||||||
version, err := strconv.Atoi(dbName)
|
err = checkDBVersion(conn, dbName)
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "strconv.Atoi")
|
|
||||||
}
|
|
||||||
err = checkDBVersion(db, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "checkDBVersion")
|
return nil, errors.Wrap(err, "checkDBVersion")
|
||||||
}
|
}
|
||||||
conn := MakeSafe(db, logger)
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the database version
|
// Check the database version
|
||||||
func checkDBVersion(db *sql.DB, expectVer int) error {
|
func checkDBVersion(db *sql.DB, dbName string) error {
|
||||||
|
expectVer, err := strconv.Atoi(dbName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "strconv.Atoi")
|
||||||
|
}
|
||||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||||
ORDER BY version_id DESC LIMIT 1`
|
ORDER BY version_id DESC LIMIT 1`
|
||||||
rows, err := db.Query(query)
|
rows, err := db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "checkDBVersion")
|
return errors.Wrap(err, "db.Query")
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
28
cmd/projectreshoot/flags.go
Normal file
28
cmd/projectreshoot/flags.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFlags() map[string]string {
|
||||||
|
// Parse commandline args
|
||||||
|
host := flag.String("host", "", "Override host to listen on")
|
||||||
|
port := flag.String("port", "", "Override port to listen on")
|
||||||
|
test := flag.Bool("test", false, "Run server in test mode")
|
||||||
|
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||||
|
loglevel := flag.String("loglevel", "", "Set log level")
|
||||||
|
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Map the args for easy access
|
||||||
|
args := map[string]string{
|
||||||
|
"host": *host,
|
||||||
|
"port": *port,
|
||||||
|
"test": strconv.FormatBool(*test),
|
||||||
|
"dbver": strconv.FormatBool(*dbver),
|
||||||
|
"loglevel": *loglevel,
|
||||||
|
"logoutput": *logoutput,
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
69
cmd/projectreshoot/httpserver.go
Normal file
69
cmd/projectreshoot/httpserver.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupHttpServer(
|
||||||
|
staticFS *fs.FS,
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
conn *sql.DB,
|
||||||
|
tokenGen *jwt.TokenGenerator,
|
||||||
|
) (server *hws.Server, err error) {
|
||||||
|
if staticFS == nil {
|
||||||
|
return nil, errors.New("No filesystem provided")
|
||||||
|
}
|
||||||
|
fs := http.FS(*staticFS)
|
||||||
|
httpServer, err := hws.NewServer(
|
||||||
|
config.Host,
|
||||||
|
config.Port,
|
||||||
|
config.ReadHeaderTimeout,
|
||||||
|
config.WriteTimeout,
|
||||||
|
config.IdleTimeout,
|
||||||
|
config.GZIP,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hws.NewServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
ignoredPaths := []string{
|
||||||
|
"/static/css/output.css",
|
||||||
|
"/static/favicon.ico",
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, err := setupAuth(config, logger, conn, httpServer, ignoredPaths)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "setupAuth")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = httpServer.AddLogger(logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.AddLogger")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = httpServer.LoggerIgnorePaths(ignoredPaths...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoutes(httpServer, &fs, config, logger, conn, tokenGen, auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "addRoutes")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addMiddleware(httpServer, auth)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "httpServer.AddMiddleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
return httpServer, nil
|
||||||
|
}
|
||||||
46
cmd/projectreshoot/logger.go
Normal file
46
cmd/projectreshoot/logger.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Take in the desired logOutput and a console writer to use
|
||||||
|
func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) {
|
||||||
|
// Setup the logfile
|
||||||
|
var logfile *os.File = nil
|
||||||
|
if logOutput == "both" || logOutput == "file" {
|
||||||
|
logfile, err := hlog.NewLogFile(logDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
defer logfile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the console writer
|
||||||
|
var consoleWriter io.Writer
|
||||||
|
if logOutput == "both" || logOutput == "console" {
|
||||||
|
if w != nil {
|
||||||
|
consoleWriter = *w
|
||||||
|
} else {
|
||||||
|
if logOutput == "console" {
|
||||||
|
return nil, errors.New("Console logging specified as sole method but no writer provided")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the logger
|
||||||
|
logger, err := hlog.NewLogger(
|
||||||
|
logLevel,
|
||||||
|
consoleWriter,
|
||||||
|
logfile,
|
||||||
|
logDirectory,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
return logger, nil
|
||||||
|
}
|
||||||
26
cmd/projectreshoot/main.go
Normal file
26
cmd/projectreshoot/main.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
args := setupFlags()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
config, err := config.GetConfig(args)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := run(ctx, os.Stdout, args, config); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
23
cmd/projectreshoot/middleware.go
Normal file
23
cmd/projectreshoot/middleware.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addMiddleware(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
) error {
|
||||||
|
|
||||||
|
err := server.AddMiddleware(
|
||||||
|
auth.Authenticate(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "server.AddMiddleware")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
127
cmd/projectreshoot/routes.go
Normal file
127
cmd/projectreshoot/routes.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/handler"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addRoutes(
|
||||||
|
server *hws.Server,
|
||||||
|
staticFS *http.FileSystem,
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
conn *sql.DB,
|
||||||
|
tokenGen *jwt.TokenGenerator,
|
||||||
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
) error {
|
||||||
|
// Create the routes
|
||||||
|
routes := []hws.Route{
|
||||||
|
{
|
||||||
|
Path: "/static/",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: http.StripPrefix("/static/", handler.StaticFS(staticFS, logger)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.Root(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/about",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.HandlePage(page.About()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/login",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handler.LoginPage(config.TrustedHost)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/login",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, conn)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/register",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LogoutReq(handler.RegisterPage(config.TrustedHost)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/register",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LogoutReq(handler.RegisterRequest(config, logger, conn, tokenGen)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/logout",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: handler.Logout(server, auth, conn),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/reauthenticate",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, conn)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/profile",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LoginReq(handler.ProfilePage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/account",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: auth.LoginReq(handler.AccountPage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/account-select-page",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.AccountSubpage()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-username",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, conn))),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-password",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, conn))),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/change-bio",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: auth.LoginReq(handler.ChangeBio(server, auth, conn)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/movies",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.MoviesPage(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/search-movies",
|
||||||
|
Method: hws.MethodPOST,
|
||||||
|
Handler: handler.SearchMovies(config, logger),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/movie/{movie_id}",
|
||||||
|
Method: hws.MethodGET,
|
||||||
|
Handler: handler.Movie(config, logger),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the routes with the server
|
||||||
|
err := server.AddRoutes(routes...)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "server.AddRoutes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
87
cmd/projectreshoot/run.go
Normal file
87
cmd/projectreshoot/run.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/pkg/embedfs"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initializes and runs the server
|
||||||
|
func run(ctx context.Context, w io.Writer, args map[string]string, config *config.Config) error {
|
||||||
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Return the version of the database required
|
||||||
|
if args["dbver"] == "true" {
|
||||||
|
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupLogger")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the database connection
|
||||||
|
logger.Debug().Msg("Config loaded and logger started")
|
||||||
|
logger.Debug().Msg("Connecting to database")
|
||||||
|
conn, err := setupDBConn(config.DBName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupDBConn")
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Setup embedded files
|
||||||
|
logger.Debug().Msg("Getting embedded files")
|
||||||
|
staticFS, err := embedfs.GetEmbeddedFS()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "getStaticFiles")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup TokenGenerator
|
||||||
|
logger.Debug().Msg("Creating TokenGenerator")
|
||||||
|
tokenGen, err := jwt.CreateGenerator(
|
||||||
|
config.AccessTokenExpiry,
|
||||||
|
config.RefreshTokenExpiry,
|
||||||
|
config.TokenFreshTime,
|
||||||
|
config.TrustedHost,
|
||||||
|
config.SecretKey,
|
||||||
|
conn,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.Debug().Msg("Setting up HTTP server")
|
||||||
|
httpServer, err := setupHttpServer(&staticFS, config, logger, conn, tokenGen)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "setupHttpServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs the http server
|
||||||
|
logger.Debug().Msg("Starting up the HTTP server")
|
||||||
|
err = httpServer.Start()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "httpServer.Start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles graceful shutdown
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Go(func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
shutdownCtx := context.Background()
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
httpServer.Shutdown(shutdownCtx)
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
logger.Info().Msg("Shutting down")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get an environment variable, specifying a default value if its not set
|
|
||||||
func GetEnvDefault(key string, defaultValue string) string {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as a time.Duration, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly
|
|
||||||
func GetEnvDur(key string, defaultValue time.Duration) time.Duration {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return time.Duration(defaultValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.Atoi(val)
|
|
||||||
if err != nil {
|
|
||||||
return time.Duration(defaultValue)
|
|
||||||
}
|
|
||||||
return time.Duration(intVal)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as an int, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into an int
|
|
||||||
func GetEnvInt(key string, defaultValue int) int {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.Atoi(val)
|
|
||||||
if err != nil {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as an int64, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into an int64
|
|
||||||
func GetEnvInt64(key string, defaultValue int64) int64 {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.ParseInt(val, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get an environment variable as a boolean, specifying a default value if its
|
|
||||||
// not set or can't be parsed properly into a bool
|
|
||||||
func GetEnvBool(key string, defaultValue bool) bool {
|
|
||||||
val, exists := os.LookupEnv(key)
|
|
||||||
if !exists {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
truthy := map[string]bool{
|
|
||||||
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
|
|
||||||
"enable": true, "enabled": true, "active": true, "affirmative": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
falsy := map[string]bool{
|
|
||||||
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
|
|
||||||
"disable": false, "disabled": false, "inactive": false, "negative": false,
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized := strings.TrimSpace(strings.ToLower(val))
|
|
||||||
|
|
||||||
if val, ok := truthy[normalized]; ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
if val, ok := falsy[normalized]; ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
package contexts
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthenticatedUser struct {
|
|
||||||
*db.User
|
|
||||||
Fresh int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return a new context with the user added in
|
|
||||||
func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context {
|
|
||||||
return context.WithValue(ctx, contextKeyAuthorizedUser, u)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve a user from the given context. Returns nil if not set
|
|
||||||
func GetUser(ctx context.Context) *AuthenticatedUser {
|
|
||||||
user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tell the browser to delete the cookie matching the name provided
|
|
||||||
// Path must match the original set cookie for it to delete
|
|
||||||
func DeleteCookie(w http.ResponseWriter, name string, path string) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: "",
|
|
||||||
Path: path,
|
|
||||||
Expires: time.Unix(0, 0), // Expire in the past
|
|
||||||
MaxAge: -1, // Immediately expire
|
|
||||||
HttpOnly: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a cookie with the given name, path and value. maxAge directly relates
|
|
||||||
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
|
|
||||||
func SetCookie(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
name string,
|
|
||||||
path string,
|
|
||||||
value string,
|
|
||||||
maxAge int,
|
|
||||||
) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: path,
|
|
||||||
HttpOnly: true,
|
|
||||||
MaxAge: maxAge,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
|
|
||||||
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
|
|
||||||
pageFromCookie, err := r.Cookie("pagefrom")
|
|
||||||
if err != nil {
|
|
||||||
return "/"
|
|
||||||
}
|
|
||||||
pageFrom := pageFromCookie.Value
|
|
||||||
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
|
|
||||||
return pageFrom
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the referer of the request, and if it matches the trustedHost, set
|
|
||||||
// the "pagefrom" cookie as the Path of the referer
|
|
||||||
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
|
|
||||||
referer := r.Referer()
|
|
||||||
parsedURL, err := url.Parse(referer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var pageFrom string
|
|
||||||
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
|
|
||||||
pageFrom = "/"
|
|
||||||
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
pageFrom = parsedURL.Path
|
|
||||||
}
|
|
||||||
SetCookie(w, "pagefrom", "/", pageFrom, 0)
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package cookies
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the value of the access and refresh tokens
|
|
||||||
func GetTokenStrings(
|
|
||||||
r *http.Request,
|
|
||||||
) (acc string, ref string) {
|
|
||||||
accCookie, accErr := r.Cookie("access")
|
|
||||||
refCookie, refErr := r.Cookie("refresh")
|
|
||||||
var (
|
|
||||||
accStr string = ""
|
|
||||||
refStr string = ""
|
|
||||||
)
|
|
||||||
if accErr == nil {
|
|
||||||
accStr = accCookie.Value
|
|
||||||
}
|
|
||||||
if refErr == nil {
|
|
||||||
refStr = refCookie.Value
|
|
||||||
}
|
|
||||||
return accStr, refStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a token with the provided details
|
|
||||||
func setToken(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
config *config.Config,
|
|
||||||
token string,
|
|
||||||
scope string,
|
|
||||||
exp int64,
|
|
||||||
rememberme bool,
|
|
||||||
) {
|
|
||||||
tokenCookie := &http.Cookie{
|
|
||||||
Name: scope,
|
|
||||||
Value: token,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
Secure: config.SSL,
|
|
||||||
}
|
|
||||||
if rememberme {
|
|
||||||
tokenCookie.Expires = time.Unix(exp, 0)
|
|
||||||
}
|
|
||||||
http.SetCookie(w, tokenCookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate new tokens for the user and set them as cookies
|
|
||||||
func SetTokenCookies(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
fresh bool,
|
|
||||||
rememberMe bool,
|
|
||||||
) error {
|
|
||||||
at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
|
||||||
}
|
|
||||||
rt, rtexp, err := jwt.GenerateRefreshToken(config, user, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.GenerateRefreshToken")
|
|
||||||
}
|
|
||||||
// Don't set the cookies until we know no errors occured
|
|
||||||
setToken(w, config, at, "access", atexp, rememberMe)
|
|
||||||
setToken(w, config, rt, "refresh", rtexp, rememberMe)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
129
db/safeconn.go
129
db/safeconn.go
@@ -1,129 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SafeConn struct {
|
|
||||||
db *sql.DB
|
|
||||||
readLockCount uint32
|
|
||||||
globalLockStatus uint32
|
|
||||||
globalLockRequested uint32
|
|
||||||
logger *zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the provided db handle safe and attach a logger to it
|
|
||||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
|
||||||
return &SafeConn{db: db, logger: logger}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempts to acquire a global lock on the database connection
|
|
||||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
|
||||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.globalLockStatus = 1
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Releases a global lock on the database connection
|
|
||||||
func (conn *SafeConn) releaseGlobalLock() {
|
|
||||||
conn.globalLockStatus = 0
|
|
||||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
|
||||||
Msg("Global lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
|
||||||
// at the same time
|
|
||||||
func (conn *SafeConn) acquireReadLock() bool {
|
|
||||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
conn.readLockCount += 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock acquired")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release a read lock. Decrements read lock count by 1
|
|
||||||
func (conn *SafeConn) releaseReadLock() {
|
|
||||||
conn.readLockCount -= 1
|
|
||||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
|
||||||
Msg("Read lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Starts a new transaction based on the current context. Will cancel if
|
|
||||||
// the context is closed/cancelled/done
|
|
||||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
|
||||||
lockAcquired := make(chan struct{})
|
|
||||||
lockCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-lockCtx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if conn.acquireReadLock() {
|
|
||||||
close(lockAcquired)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-lockAcquired:
|
|
||||||
tx, err := conn.db.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
conn.releaseReadLock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &SafeTX{tx: tx, sc: conn}, nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
cancel()
|
|
||||||
return nil, errors.New("Transaction time out due to database lock")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a global lock, preventing all transactions
|
|
||||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
|
||||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
|
||||||
conn.globalLockRequested = 1
|
|
||||||
defer func() { conn.globalLockRequested = 0 }()
|
|
||||||
timeout := time.After(timeoutAfter)
|
|
||||||
attempt := 0
|
|
||||||
for {
|
|
||||||
if conn.acquireGlobalLock() {
|
|
||||||
conn.logger.Info().Msg("Global database lock acquired")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timeout:
|
|
||||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
|
||||||
return
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
attempt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release the global lock
|
|
||||||
func (conn *SafeConn) Resume() {
|
|
||||||
conn.releaseGlobalLock()
|
|
||||||
conn.logger.Info().Msg("Global database lock released")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the database connection
|
|
||||||
func (conn *SafeConn) Close() error {
|
|
||||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
|
||||||
conn.acquireGlobalLock()
|
|
||||||
defer conn.releaseGlobalLock()
|
|
||||||
conn.logger.Debug().Msg("Closing database connection")
|
|
||||||
return conn.db.Close()
|
|
||||||
}
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSafeConn(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
sconn.Resume()
|
|
||||||
})
|
|
||||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn.Pause(250 * time.Millisecond)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
var requested sync.WaitGroup
|
|
||||||
var engaged sync.WaitGroup
|
|
||||||
requested.Add(1)
|
|
||||||
engaged.Add(1)
|
|
||||||
go func() {
|
|
||||||
requested.Done()
|
|
||||||
sconn.Pause(5 * time.Second)
|
|
||||||
engaged.Done()
|
|
||||||
}()
|
|
||||||
requested.Wait()
|
|
||||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
|
||||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
engaged.Wait()
|
|
||||||
_, err = sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
sconn.Resume()
|
|
||||||
tx, err = sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func TestSafeTX(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
t.Run("Commit releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Commit()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
|
||||||
tx, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
|
||||||
tx.Rollback()
|
|
||||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
|
||||||
})
|
|
||||||
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
|
|
||||||
tx1, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx2, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx3, err := sconn.Begin(t.Context())
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx1.Commit()
|
|
||||||
tx2.Commit()
|
|
||||||
tx3.Commit()
|
|
||||||
})
|
|
||||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
defer sconn.releaseGlobalLock()
|
|
||||||
_, err := sconn.Begin(ctx)
|
|
||||||
require.Error(t, err)
|
|
||||||
})
|
|
||||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
sconn.acquireGlobalLock()
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
tx, err := sconn.Begin(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
tx.Commit()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
sconn.releaseGlobalLock()
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
61
db/safetx.go
61
db/safetx.go
@@ -1,61 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Extends sql.Tx for use with SafeConn
|
|
||||||
type SafeTX struct {
|
|
||||||
tx *sql.Tx
|
|
||||||
sc *SafeConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the database inside the transaction
|
|
||||||
func (stx *SafeTX) Query(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (*sql.Rows, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot query without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec a statement on the database inside the transaction
|
|
||||||
func (stx *SafeTX) Exec(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
args ...interface{},
|
|
||||||
) (sql.Result, error) {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return nil, errors.New("Cannot exec without a transaction")
|
|
||||||
}
|
|
||||||
return stx.tx.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit the current transaction and release the read lock
|
|
||||||
func (stx *SafeTX) Commit() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot commit without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Commit()
|
|
||||||
stx.tx = nil
|
|
||||||
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort the current transaction, releasing the read lock
|
|
||||||
func (stx *SafeTX) Rollback() error {
|
|
||||||
if stx.tx == nil {
|
|
||||||
return errors.New("Cannot rollback without a transaction")
|
|
||||||
}
|
|
||||||
err := stx.tx.Rollback()
|
|
||||||
stx.tx = nil
|
|
||||||
stx.sc.releaseReadLock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
60
db/user.go
60
db/user.go
@@ -1,60 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int // Integer ID (index primary key)
|
|
||||||
Username string // Username (unique)
|
|
||||||
Password_hash string // Bcrypt password hash
|
|
||||||
Created_at int64 // Epoch timestamp when the user was added to the database
|
|
||||||
Bio string // Short byline set by the user
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses bcrypt to set the users Password_hash from the given password
|
|
||||||
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
|
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
|
||||||
}
|
|
||||||
user.Password_hash = string(hashedPassword)
|
|
||||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
|
||||||
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses bcrypt to check if the given password matches the users Password_hash
|
|
||||||
func (user *User) CheckPassword(password string) error {
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password))
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the user's username
|
|
||||||
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
|
|
||||||
query := `UPDATE users SET username = ? WHERE id = ?`
|
|
||||||
_, err := tx.Exec(ctx, query, newUsername, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change the user's bio
|
|
||||||
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
|
|
||||||
query := `UPDATE users SET bio = ? WHERE id = ?`
|
|
||||||
_, err := tx.Exec(ctx, query, newBio, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -64,7 +64,7 @@ failed_cleanup() {
|
|||||||
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
||||||
|
|
||||||
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
||||||
${MIGRATION_BIN}/prmigrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
|
${MIGRATION_BIN}/migrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
echo "Migration failed"
|
echo "Migration failed"
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
27
go.mod
27
go.mod
@@ -1,36 +1,39 @@
|
|||||||
module projectreshoot
|
module projectreshoot
|
||||||
|
|
||||||
go 1.24.0
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/a-h/templ v0.3.833
|
git.haelnorr.com/h/golib/cookies v0.9.0
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
git.haelnorr.com/h/golib/env v0.9.0
|
||||||
github.com/google/uuid v1.6.0
|
git.haelnorr.com/h/golib/hlog v0.9.0
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.2.0
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0
|
||||||
|
github.com/a-h/templ v0.3.977
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.24
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/pressly/goose/v3 v3.24.1
|
github.com/pressly/goose/v3 v3.24.1
|
||||||
github.com/rs/zerolog v1.33.0
|
|
||||||
github.com/stretchr/testify v1.10.0
|
|
||||||
golang.org/x/crypto v0.33.0
|
golang.org/x/crypto v0.33.0
|
||||||
modernc.org/sqlite v1.35.0
|
modernc.org/sqlite v1.35.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/rs/zerolog v1.34.0 // indirect
|
||||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||||
golang.org/x/sync v0.11.0 // indirect
|
golang.org/x/sync v0.16.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.34.0 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
modernc.org/libc v1.61.13 // indirect
|
modernc.org/libc v1.61.13 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.8.2 // indirect
|
modernc.org/memory v1.8.2 // indirect
|
||||||
|
|||||||
56
go.sum
56
go.sum
@@ -1,5 +1,21 @@
|
|||||||
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
|
git.haelnorr.com/h/golib/cookies v0.9.0 h1:Vf+eX1prHkKuGrQon1BHY87yaPc1H+HJFRXDOV/AuWs=
|
||||||
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
|
git.haelnorr.com/h/golib/cookies v0.9.0/go.mod h1:y1385YExI9gLwckCVDCYVcsFXr6N7T3brJjnJD2QIuo=
|
||||||
|
git.haelnorr.com/h/golib/env v0.9.0 h1:Ahqr3PbHy7HdWEHUhylzIZy6Gg8mST5UdgKlU2RAhls=
|
||||||
|
git.haelnorr.com/h/golib/env v0.9.0/go.mod h1:glUQVdA1HMKX1avTDyTyuhcr36SSxZtlJxKDT5KTztg=
|
||||||
|
git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
|
||||||
|
git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0 h1:+0eNq1uGWrGfbS5AgHeGoGDjVfCWuaVu+1wBxgPqyOY=
|
||||||
|
git.haelnorr.com/h/golib/hws v0.1.0/go.mod h1:b2pbkMaebzmck9TxqGBGzTJPEcB5TWcEHwFknLE7dqM=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.2.0 h1:rLfTtxo0lBUMuWzEdoS1Y4i8/UiCzDZ5DS+6WC/C974=
|
||||||
|
git.haelnorr.com/h/golib/hwsauth v0.2.0/go.mod h1:d1oXUstDHqKwCXzcEMdHGC8yoT2S2gwpJkrEo8daCMs=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2 h1:l1Ow7DPGACAU54CnMP/NlZjdc4nRD1wr3xZ8a7taRvU=
|
||||||
|
git.haelnorr.com/h/golib/jwt v0.9.2/go.mod h1:fbuPrfucT9lL0faV5+Q5Gk9WFJxPlwzRPpbMQKYZok4=
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0 h1:OQ6M2TB8FHm8fJD7/ebfWm63Duzfp0kmFX9genEig34=
|
||||||
|
git.haelnorr.com/h/golib/tmdb v0.8.0/go.mod h1:mGKYa3o3z0IsQ5EO3MPmnL2Bwl2sSMsUHXVgaIGR7Z0=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
|
||||||
|
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
|
||||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
@@ -16,11 +32,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
|
||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|
||||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
@@ -28,6 +39,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
|||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||||
@@ -40,33 +53,30 @@ github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4
|
|||||||
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/view/page"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ErrorPage(
|
|
||||||
errorCode int,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
) {
|
|
||||||
message := map[int]string{
|
|
||||||
401: "You need to login to view this page.",
|
|
||||||
403: "You do not have permission to view this page.",
|
|
||||||
404: "The page or resource you have requested does not exist.",
|
|
||||||
500: `An error occured on the server. Please try again, and if this
|
|
||||||
continues to happen contact an administrator.`,
|
|
||||||
503: "The server is currently down for maintenance and should be back soon. =)",
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
|
|
||||||
Render(r.Context(), w)
|
|
||||||
}
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func revokeAccess(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
atStr string,
|
|
||||||
) error {
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "Token is expired") ||
|
|
||||||
strings.Contains(err.Error(), "Token has been revoked") {
|
|
||||||
return nil // Token is expired, dont need to revoke it
|
|
||||||
}
|
|
||||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, aT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func revokeRefresh(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
rtStr string,
|
|
||||||
) error {
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "Token is expired") ||
|
|
||||||
strings.Contains(err.Error(), "Token has been revoked") {
|
|
||||||
return nil // Token is expired, dont need to revoke it
|
|
||||||
}
|
|
||||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve and revoke the user's tokens
|
|
||||||
func revokeTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
// get the tokens from the cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
// revoke the refresh token first as the access token expires quicker
|
|
||||||
// only matters if there is an error revoking the tokens
|
|
||||||
err := revokeRefresh(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeRefresh")
|
|
||||||
}
|
|
||||||
err = revokeAccess(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeAccess")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle a logout request
|
|
||||||
func Logout(
|
|
||||||
config *config.Config,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Error occured on user logout")
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = revokeTokens(config, ctx, tx, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
|
||||||
w.Header().Set("HX-Redirect", "/login")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
"projectreshoot/view/component/form"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the tokens from the request
|
|
||||||
func getTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
|
||||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
|
||||||
// get the existing tokens from the cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
|
||||||
}
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
return aT, rT, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Revoke the given token pair
|
|
||||||
func revokeTokenPair(
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
aT *jwt.AccessToken,
|
|
||||||
rT *jwt.RefreshToken,
|
|
||||||
) error {
|
|
||||||
err := jwt.RevokeToken(ctx, tx, aT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
err = jwt.RevokeToken(ctx, tx, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Issue new tokens for the user, invalidating the old ones
|
|
||||||
func refreshTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
aT, rT, err := getTokens(config, ctx, tx, r)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "getTokens")
|
|
||||||
}
|
|
||||||
rememberMe := map[string]bool{
|
|
||||||
"session": false,
|
|
||||||
"exp": true,
|
|
||||||
}[aT.TTL]
|
|
||||||
// issue new tokens for the user
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cookies.SetTokenCookies")
|
|
||||||
}
|
|
||||||
err = revokeTokenPair(ctx, tx, aT, rT)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "revokeTokenPair")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the provided password
|
|
||||||
func validatePassword(
|
|
||||||
r *http.Request,
|
|
||||||
) error {
|
|
||||||
r.ParseForm()
|
|
||||||
password := r.FormValue("password")
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
err := user.CheckPassword(password)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "user.CheckPassword")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle request to reauthenticate (i.e. make token fresh again)
|
|
||||||
func Reauthenticate(
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Err(err).Msg("Failed to refresh user tokens")
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = validatePassword(r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
w.WriteHeader(445)
|
|
||||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = refreshTokens(config, ctx, tx, w, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wrapper for default FileSystem
|
|
||||||
type justFilesFilesystem struct {
|
|
||||||
fs http.FileSystem
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrapper for default File
|
|
||||||
type neuteredReaddirFile struct {
|
|
||||||
http.File
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modifies the behavior of FileSystem.Open to return the neutered version of File
|
|
||||||
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
|
|
||||||
f, err := fs.fs.Open(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the requested path is a directory
|
|
||||||
// and explicitly return an error to trigger a 404
|
|
||||||
fileInfo, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if fileInfo.IsDir() {
|
|
||||||
return nil, os.ErrNotExist
|
|
||||||
}
|
|
||||||
|
|
||||||
return neuteredReaddirFile{f}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Overrides the Readdir method of File to always return nil
|
|
||||||
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handles requests for static files, without allowing access to the
|
|
||||||
// directory viewer and returning 404 if an exact file is not found
|
|
||||||
func StaticFS(staticFS *http.FileSystem) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
nfs := justFilesFilesystem{*staticFS}
|
|
||||||
fs := http.FileServer(nfs)
|
|
||||||
fs.ServeHTTP(w, r)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func removeme(
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
handler func(
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
),
|
|
||||||
onfail func(err error),
|
|
||||||
) {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
onfail(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(ctx, tx, w, r)
|
|
||||||
}
|
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/logging"
|
"git.haelnorr.com/h/golib/env"
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
"github.com/rs/zerolog"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -27,9 +27,11 @@ type Config struct {
|
|||||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||||
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
||||||
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
|
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
|
||||||
LogLevel zerolog.Level // Log level for global logging. Defaults to info
|
LogLevel hlog.Level // Log level for global logging. Defaults to info
|
||||||
LogOutput string // "file", "console", or "both". Defaults to console
|
LogOutput string // "file", "console", or "both". Defaults to console
|
||||||
LogDir string // Path to create log files
|
LogDir string // Path to create log files
|
||||||
|
TMDBToken string // Read access token for TMDB API
|
||||||
|
TMDBConfig *tmdb.Config // Config data for interfacing with TMDB
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the application configuration and get a pointer to the Config object
|
// Load the application configuration and get a pointer to the Config object
|
||||||
@@ -38,7 +40,7 @@ func GetConfig(args map[string]string) (*Config, error) {
|
|||||||
var (
|
var (
|
||||||
host string
|
host string
|
||||||
port string
|
port string
|
||||||
logLevel zerolog.Level
|
logLevel hlog.Level
|
||||||
logOutput string
|
logOutput string
|
||||||
valid bool
|
valid bool
|
||||||
)
|
)
|
||||||
@@ -46,17 +48,17 @@ func GetConfig(args map[string]string) (*Config, error) {
|
|||||||
if args["host"] != "" {
|
if args["host"] != "" {
|
||||||
host = args["host"]
|
host = args["host"]
|
||||||
} else {
|
} else {
|
||||||
host = GetEnvDefault("HOST", "127.0.0.1")
|
host = env.String("HOST", "127.0.0.1")
|
||||||
}
|
}
|
||||||
if args["port"] != "" {
|
if args["port"] != "" {
|
||||||
port = args["port"]
|
port = args["port"]
|
||||||
} else {
|
} else {
|
||||||
port = GetEnvDefault("PORT", "3010")
|
port = env.String("PORT", "3010")
|
||||||
}
|
}
|
||||||
if args["loglevel"] != "" {
|
if args["loglevel"] != "" {
|
||||||
logLevel = logging.GetLogLevel(args["loglevel"])
|
logLevel = hlog.LogLevel(args["loglevel"])
|
||||||
} else {
|
} else {
|
||||||
logLevel = logging.GetLogLevel(GetEnvDefault("LOG_LEVEL", "info"))
|
logLevel = hlog.LogLevel(env.String("LOG_LEVEL", "info"))
|
||||||
}
|
}
|
||||||
if args["logoutput"] != "" {
|
if args["logoutput"] != "" {
|
||||||
opts := map[string]string{
|
opts := map[string]string{
|
||||||
@@ -72,35 +74,44 @@ func GetConfig(args map[string]string) (*Config, error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logOutput = GetEnvDefault("LOG_OUTPUT", "console")
|
logOutput = env.String("LOG_OUTPUT", "console")
|
||||||
}
|
}
|
||||||
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
||||||
logOutput = "console"
|
logOutput = "console"
|
||||||
}
|
}
|
||||||
|
tmdbcfg, err := tmdb.GetConfig(os.Getenv("TMDB_API_TOKEN"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdb.GetConfig")
|
||||||
|
}
|
||||||
|
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
TrustedHost: GetEnvDefault("TRUSTED_HOST", "127.0.0.1"),
|
TrustedHost: env.String("TRUSTED_HOST", "127.0.0.1"),
|
||||||
SSL: GetEnvBool("SSL_MODE", false),
|
SSL: env.Bool("SSL_MODE", false),
|
||||||
GZIP: GetEnvBool("GZIP", false),
|
GZIP: env.Bool("GZIP", false),
|
||||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
ReadHeaderTimeout: env.Duration("READ_HEADER_TIMEOUT", 2),
|
||||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
WriteTimeout: env.Duration("WRITE_TIMEOUT", 10),
|
||||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
IdleTimeout: env.Duration("IDLE_TIMEOUT", 120),
|
||||||
DBName: "00001",
|
DBName: "00001",
|
||||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
DBLockTimeout: env.Duration("DB_LOCK_TIMEOUT", 60),
|
||||||
SecretKey: os.Getenv("SECRET_KEY"),
|
SecretKey: env.String("SECRET_KEY", ""),
|
||||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
AccessTokenExpiry: env.Int64("ACCESS_TOKEN_EXPIRY", 5),
|
||||||
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
RefreshTokenExpiry: env.Int64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
||||||
TokenFreshTime: GetEnvInt64("TOKEN_FRESH_TIME", 5),
|
TokenFreshTime: env.Int64("TOKEN_FRESH_TIME", 5),
|
||||||
LogLevel: logLevel,
|
LogLevel: logLevel,
|
||||||
LogOutput: logOutput,
|
LogOutput: logOutput,
|
||||||
LogDir: GetEnvDefault("LOG_DIR", ""),
|
LogDir: env.String("LOG_DIR", ""),
|
||||||
|
TMDBToken: env.String("TMDB_API_TOKEN", ""),
|
||||||
|
TMDBConfig: tmdbcfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SecretKey == "" && args["dbver"] != "true" {
|
if config.SecretKey == "" && args["dbver"] != "true" {
|
||||||
return nil, errors.New("Envar not set: SECRET_KEY")
|
return nil, errors.New("Envar not set: SECRET_KEY")
|
||||||
}
|
}
|
||||||
|
if config.TMDBToken == "" && args["dbver"] != "true" {
|
||||||
|
return nil, errors.New("Envar not set: TMDB_API_TOKEN")
|
||||||
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
@@ -2,17 +2,19 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/contexts"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"projectreshoot/cookies"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/view/component/account"
|
"projectreshoot/internal/view/component/account"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Renders the account page on the 'General' subpage
|
// Renders the account page on the 'General' subpage
|
||||||
@@ -43,8 +45,9 @@ func AccountSubpage() http.Handler {
|
|||||||
|
|
||||||
// Handles a request to change the users username
|
// Handles a request to change the users username
|
||||||
func ChangeUsername(
|
func ChangeUsername(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
conn *sql.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -52,19 +55,17 @@ func ChangeUsername(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating username")
|
server.ThrowWarn(w, hws.NewError(http.StatusServiceUnavailable, "Error updating username", err))
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
newUsername := r.FormValue("username")
|
newUsername := r.FormValue("username")
|
||||||
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
|
unique, err := models.CheckUsernameUnique(tx, newUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating username")
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Error updating username", err))
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !unique {
|
if !unique {
|
||||||
@@ -73,12 +74,11 @@ func ChangeUsername(
|
|||||||
Render(r.Context(), w)
|
Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.ChangeUsername(ctx, tx, newUsername)
|
err = user.ChangeUsername(tx, newUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating username")
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Error updating username", err))
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@@ -89,8 +89,9 @@ func ChangeUsername(
|
|||||||
|
|
||||||
// Handles a request to change the users bio
|
// Handles a request to change the users bio
|
||||||
func ChangeBio(
|
func ChangeBio(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
conn *sql.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -98,10 +99,9 @@ func ChangeBio(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating bio")
|
server.ThrowWarn(w, hws.NewError(http.StatusServiceUnavailable, "Error updating bio", err))
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -113,12 +113,11 @@ func ChangeBio(
|
|||||||
Render(r.Context(), w)
|
Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.ChangeBio(ctx, tx, newBio)
|
err = user.ChangeBio(tx, newBio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating bio")
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Error updating bio", err))
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
@@ -127,8 +126,6 @@ func ChangeBio(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
func validateChangePassword(
|
func validateChangePassword(
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
@@ -145,8 +142,9 @@ func validateChangePassword(
|
|||||||
|
|
||||||
// Handles a request to change the users password
|
// Handles a request to change the users password
|
||||||
func ChangePassword(
|
func ChangePassword(
|
||||||
logger *zerolog.Logger,
|
server *hws.Server,
|
||||||
conn *db.SafeConn,
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
conn *sql.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -154,24 +152,22 @@ func ChangePassword(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Error updating password")
|
server.ThrowWarn(w, hws.NewError(http.StatusServiceUnavailable, "Error updating password", err))
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
newPass, err := validateChangePassword(ctx, tx, r)
|
newPass, err := validateChangePassword(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := contexts.GetUser(r.Context())
|
user := auth.CurrentModel(r.Context())
|
||||||
err = user.SetPassword(ctx, tx, newPass)
|
err = user.SetPassword(tx, newPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
logger.Error().Err(err).Msg("Error updating password")
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Error updating password", err))
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tx.Commit()
|
tx.Commit()
|
||||||
42
internal/handler/errorpage.go
Normal file
42
internal/handler/errorpage.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ErrorPage(
|
||||||
|
errorCode int,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) {
|
||||||
|
message := map[int]string{
|
||||||
|
401: "You need to login to view this page.",
|
||||||
|
403: "You do not have permission to view this page.",
|
||||||
|
404: "The page or resource you have requested does not exist.",
|
||||||
|
500: `An error occured on the server. Please try again, and if this
|
||||||
|
continues to happen contact an administrator.`,
|
||||||
|
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||||
|
}
|
||||||
|
w.WriteHeader(errorCode)
|
||||||
|
page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
|
||||||
|
Render(r.Context(), w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrorPage(
|
||||||
|
errorCode int,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
) error {
|
||||||
|
message := map[int]string{
|
||||||
|
401: "You need to login to view this page.",
|
||||||
|
403: "You do not have permission to view this page.",
|
||||||
|
404: "The page or resource you have requested does not exist.",
|
||||||
|
500: `An error occured on the server. Please try again, and if this
|
||||||
|
continues to happen contact an administrator.`,
|
||||||
|
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||||
|
}
|
||||||
|
w.WriteHeader(errorCode)
|
||||||
|
return page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
|
||||||
|
Render(r.Context(), w)
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handles responses to the / path. Also serves a 404 Page for paths that
|
// Handles responses to the / path. Also serves a 404 Page for paths that
|
||||||
@@ -2,35 +2,39 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"git.haelnorr.com/h/golib/hws"
|
||||||
"projectreshoot/cookies"
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/view/component/form"
|
"projectreshoot/internal/view/component/form"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validates the username matches a user in the database and the password
|
// Validates the username matches a user in the database and the password
|
||||||
// is correct. Returns the corresponding user
|
// is correct. Returns the corresponding user
|
||||||
func validateLogin(
|
func validateLogin(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*db.User, error) {
|
) (*models.User, error) {
|
||||||
formUsername := r.FormValue("username")
|
formUsername := r.FormValue("username")
|
||||||
formPassword := r.FormValue("password")
|
formPassword := r.FormValue("password")
|
||||||
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
|
user, err := models.GetUserFromUsername(tx, formUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = user.CheckPassword(formPassword)
|
err = user.CheckPassword(tx, formPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "Username or password incorrect") {
|
||||||
|
return nil, errors.Wrap(err, "user.CheckPassword")
|
||||||
|
}
|
||||||
return nil, errors.New("Username or password incorrect")
|
return nil, errors.New("Username or password incorrect")
|
||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
@@ -50,9 +54,9 @@ func checkRememberMe(r *http.Request) bool {
|
|||||||
// and on fail will return the login form again, passing the error to the
|
// and on fail will return the login form again, passing the error to the
|
||||||
// template for user feedback
|
// template for user feedback
|
||||||
func LoginRequest(
|
func LoginRequest(
|
||||||
config *config.Config,
|
server *hws.Server,
|
||||||
logger *zerolog.Logger,
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
conn *db.SafeConn,
|
conn *sql.DB,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -60,19 +64,17 @@ func LoginRequest(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
server.ThrowWarn(w, hws.NewError(http.StatusServiceUnavailable, "Login failed", err))
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
user, err := validateLogin(ctx, tx, r)
|
user, err := validateLogin(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
if err.Error() != "Username or password incorrect" {
|
if err.Error() != "Username or password incorrect" {
|
||||||
logger.Warn().Caller().Err(err).Msg("Login request failed")
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Login failed", err))
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
} else {
|
} else {
|
||||||
form.LoginForm(err.Error()).Render(r.Context(), w)
|
form.LoginForm(err.Error()).Render(r.Context(), w)
|
||||||
}
|
}
|
||||||
@@ -80,11 +82,10 @@ func LoginRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := checkRememberMe(r)
|
rememberMe := checkRememberMe(r)
|
||||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
err = auth.Login(w, r, user, rememberMe)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
server.ThrowWarn(w, hws.NewError(http.StatusInternalServerError, "Login failed", err))
|
||||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
40
internal/handler/logout.go
Normal file
40
internal/handler/logout.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handle a logout request
|
||||||
|
func Logout(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
conn *sql.DB,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowError(w, r, hws.NewError(http.StatusInternalServerError, "Logout failed", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
err = auth.Logout(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowError(w, r, hws.NewError(http.StatusInternalServerError, "Logout failed", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
w.Header().Set("HX-Redirect", "/login")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
44
internal/handler/movie.go
Normal file
44
internal/handler/movie.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Movie(
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := r.PathValue("movie_id")
|
||||||
|
movie_id, err := strconv.ParseInt(id, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
ErrorPage(http.StatusNotFound, w, r)
|
||||||
|
logger.Error().Err(err).Str("movie_id", id).
|
||||||
|
Msg("Error occured getting the movie")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
movie, err := tmdb.GetMovie(int32(movie_id), config.TMDBToken)
|
||||||
|
if err != nil {
|
||||||
|
ErrorPage(http.StatusInternalServerError, w, r)
|
||||||
|
logger.Error().Err(err).Int32("movie_id", int32(movie_id)).
|
||||||
|
Msg("Error occured getting the movie")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
credits, err := tmdb.GetCredits(int32(movie_id), config.TMDBToken)
|
||||||
|
if err != nil {
|
||||||
|
ErrorPage(http.StatusInternalServerError, w, r)
|
||||||
|
logger.Error().Err(err).Int32("movie_id", int32(movie_id)).
|
||||||
|
Msg("Error occured getting the movie credits")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
page.Movie(movie, credits, &config.TMDBConfig.Image).Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
41
internal/handler/movie_search.go
Normal file
41
internal/handler/movie_search.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
"projectreshoot/internal/view/component/search"
|
||||||
|
"projectreshoot/internal/view/page"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/tmdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SearchMovies(
|
||||||
|
config *config.Config,
|
||||||
|
logger *hlog.Logger,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseForm()
|
||||||
|
query := r.FormValue("search")
|
||||||
|
if query == "" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
movies, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
search.MovieResults(movies, &config.TMDBConfig.Image).Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MoviesPage() http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
page.Movies().Render(r.Context(), w)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"projectreshoot/view/page"
|
"projectreshoot/internal/view/page"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ProfilePage() http.Handler {
|
func ProfilePage() http.Handler {
|
||||||
66
internal/handler/reauthenticatate.go
Normal file
66
internal/handler/reauthenticatate.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
"projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Validate the provided password
|
||||||
|
func validatePassword(
|
||||||
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
tx *sql.Tx,
|
||||||
|
r *http.Request,
|
||||||
|
) error {
|
||||||
|
r.ParseForm()
|
||||||
|
password := r.FormValue("password")
|
||||||
|
user := auth.CurrentModel(r.Context())
|
||||||
|
err := user.CheckPassword(tx, password)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "user.CheckPassword")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle request to reauthenticate (i.e. make token fresh again)
|
||||||
|
func Reauthenticate(
|
||||||
|
server *hws.Server,
|
||||||
|
auth *hwsauth.Authenticator[*models.User],
|
||||||
|
conn *sql.DB,
|
||||||
|
) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Start the transaction
|
||||||
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowError(w, r, hws.NewError(http.StatusInternalServerError, "Failed to start transaction", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
err = validatePassword(auth, tx, r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(445)
|
||||||
|
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = auth.RefreshAuthTokens(tx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
server.ThrowError(w, r, hws.NewError(http.StatusInternalServerError, "Failed to refresh user tokens", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -2,30 +2,32 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/config"
|
"projectreshoot/internal/config"
|
||||||
"projectreshoot/cookies"
|
"projectreshoot/internal/models"
|
||||||
"projectreshoot/db"
|
"projectreshoot/internal/view/component/form"
|
||||||
"projectreshoot/view/component/form"
|
"projectreshoot/internal/view/page"
|
||||||
"projectreshoot/view/page"
|
|
||||||
|
"git.haelnorr.com/h/golib/cookies"
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateRegistration(
|
func validateRegistration(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tx *db.SafeTX,
|
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (*db.User, error) {
|
) (*models.User, error) {
|
||||||
formUsername := r.FormValue("username")
|
formUsername := r.FormValue("username")
|
||||||
formPassword := r.FormValue("password")
|
formPassword := r.FormValue("password")
|
||||||
formConfirmPassword := r.FormValue("confirm-password")
|
formConfirmPassword := r.FormValue("confirm-password")
|
||||||
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
|
unique, err := models.CheckUsernameUnique(tx, formUsername)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
|
return nil, errors.Wrap(err, "models.CheckUsernameUnique")
|
||||||
}
|
}
|
||||||
if !unique {
|
if !unique {
|
||||||
return nil, errors.New("Username is taken")
|
return nil, errors.New("Username is taken")
|
||||||
@@ -36,9 +38,9 @@ func validateRegistration(
|
|||||||
if len(formPassword) > 72 {
|
if len(formPassword) > 72 {
|
||||||
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
||||||
}
|
}
|
||||||
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
|
user, err := models.CreateNewUser(tx, formUsername, formPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "db.CreateNewUser")
|
return nil, errors.Wrap(err, "models.CreateNewUser")
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
@@ -46,8 +48,9 @@ func validateRegistration(
|
|||||||
|
|
||||||
func RegisterRequest(
|
func RegisterRequest(
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
logger *zerolog.Logger,
|
logger *hlog.Logger,
|
||||||
conn *db.SafeConn,
|
conn *sql.DB,
|
||||||
|
tokenGen *jwt.TokenGenerator,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(
|
return http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -55,14 +58,14 @@ func RegisterRequest(
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.Begin(ctx)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
user, err := validateRegistration(ctx, tx, r)
|
user, err := validateRegistration(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
if err.Error() != "Username is taken" &&
|
if err.Error() != "Username is taken" &&
|
||||||
@@ -77,7 +80,7 @@ func RegisterRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rememberMe := checkRememberMe(r)
|
rememberMe := checkRememberMe(r)
|
||||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
err = jwt.SetTokenCookies(w, r, tokenGen, user.ID(), true, rememberMe, config.SSL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
24
internal/handler/static.go
Normal file
24
internal/handler/static.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.haelnorr.com/h/golib/hws"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handles requests for static files, without allowing access to the
|
||||||
|
// directory viewer and returning 404 if an exact file is not found
|
||||||
|
func StaticFS(staticFS *http.FileSystem, logger *hlog.Logger) http.Handler {
|
||||||
|
return http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fs, err := hws.SafeFileServer(staticFS)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
logger.Error().Err(err).Msg("Failed to load file system")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fs.ServeHTTP(w, r)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
73
internal/models/user.go
Normal file
73
internal/models/user.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
id int // Integer ID (index primary key)
|
||||||
|
Username string // Username (unique)
|
||||||
|
Created_at int64 // Epoch timestamp when the user was added to the database
|
||||||
|
Bio string // Short byline set by the user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u User) ID() int {
|
||||||
|
return u.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to set the users Password_hash from the given password
|
||||||
|
func (user *User) SetPassword(
|
||||||
|
tx *sql.Tx,
|
||||||
|
password string,
|
||||||
|
) error {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||||
|
}
|
||||||
|
newPassword := string(hashedPassword)
|
||||||
|
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||||
|
_, err = tx.Exec(query, newPassword, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses bcrypt to check if the given password matches the users Password_hash
|
||||||
|
func (user *User) CheckPassword(tx *sql.Tx, password string) error {
|
||||||
|
query := `SELECT password_hash FROM users WHERE id = ? LIMIT 1`
|
||||||
|
row := tx.QueryRow(query, user.id)
|
||||||
|
var hashedPassword string
|
||||||
|
err := row.Scan(&hashedPassword)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "row.Scan")
|
||||||
|
}
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's username
|
||||||
|
func (user *User) ChangeUsername(tx *sql.Tx, newUsername string) error {
|
||||||
|
query := `UPDATE users SET username = ? WHERE id = ?`
|
||||||
|
_, err := tx.Exec(query, newUsername, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change the user's bio
|
||||||
|
func (user *User) ChangeBio(tx *sql.Tx, newBio string) error {
|
||||||
|
query := `UPDATE users SET bio = ? WHERE id = ?`
|
||||||
|
_, err := tx.Exec(query, newBio, user.id)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package db
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
@@ -10,21 +9,20 @@ import (
|
|||||||
|
|
||||||
// Creates a new user in the database and returns a pointer
|
// Creates a new user in the database and returns a pointer
|
||||||
func CreateNewUser(
|
func CreateNewUser(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tx *SafeTX,
|
|
||||||
username string,
|
username string,
|
||||||
password string,
|
password string,
|
||||||
) (*User, error) {
|
) (*User, error) {
|
||||||
query := `INSERT INTO users (username) VALUES (?)`
|
query := `INSERT INTO users (username) VALUES (?)`
|
||||||
_, err := tx.Exec(ctx, query, username)
|
_, err := tx.Exec(query, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.Exec")
|
return nil, errors.Wrap(err, "tx.Exec")
|
||||||
}
|
}
|
||||||
user, err := GetUserFromUsername(ctx, tx, username)
|
user, err := GetUserFromUsername(tx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "GetUserFromUsername")
|
return nil, errors.Wrap(err, "GetUserFromUsername")
|
||||||
}
|
}
|
||||||
err = user.SetPassword(ctx, tx, password)
|
err = user.SetPassword(tx, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "user.SetPassword")
|
return nil, errors.Wrap(err, "user.SetPassword")
|
||||||
}
|
}
|
||||||
@@ -33,23 +31,21 @@ func CreateNewUser(
|
|||||||
|
|
||||||
// Fetches data from the users table using "WHERE column = 'value'"
|
// Fetches data from the users table using "WHERE column = 'value'"
|
||||||
func fetchUserData(
|
func fetchUserData(
|
||||||
ctx context.Context,
|
tx *sql.Tx,
|
||||||
tx *SafeTX,
|
|
||||||
column string,
|
column string,
|
||||||
value interface{},
|
value any,
|
||||||
) (*sql.Rows, error) {
|
) (*sql.Rows, error) {
|
||||||
query := fmt.Sprintf(
|
query := fmt.Sprintf(
|
||||||
`SELECT
|
`SELECT
|
||||||
id,
|
id,
|
||||||
username,
|
username,
|
||||||
password_hash,
|
|
||||||
created_at,
|
created_at,
|
||||||
bio
|
bio
|
||||||
FROM users
|
FROM users
|
||||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||||
column,
|
column,
|
||||||
)
|
)
|
||||||
rows, err := tx.Query(ctx, query, value)
|
rows, err := tx.Query(query, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "tx.Query")
|
return nil, errors.Wrap(err, "tx.Query")
|
||||||
}
|
}
|
||||||
@@ -63,9 +59,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
|||||||
return errors.New("User not found")
|
return errors.New("User not found")
|
||||||
}
|
}
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&user.ID,
|
&user.id,
|
||||||
&user.Username,
|
&user.Username,
|
||||||
&user.Password_hash,
|
|
||||||
&user.Created_at,
|
&user.Created_at,
|
||||||
&user.Bio,
|
&user.Bio,
|
||||||
)
|
)
|
||||||
@@ -77,8 +72,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
|
|||||||
|
|
||||||
// Queries the database for a user matching the given username.
|
// Queries the database for a user matching the given username.
|
||||||
// Query is case insensitive
|
// Query is case insensitive
|
||||||
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
|
func GetUserFromUsername(tx *sql.Tx, username string) (*User, error) {
|
||||||
rows, err := fetchUserData(ctx, tx, "username", username)
|
rows, err := fetchUserData(tx, "username", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
@@ -92,8 +87,8 @@ func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*Use
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Queries the database for a user matching the given ID.
|
// Queries the database for a user matching the given ID.
|
||||||
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
func GetUserFromID(tx *sql.Tx, id int) (*User, error) {
|
||||||
rows, err := fetchUserData(ctx, tx, "id", id)
|
rows, err := fetchUserData(tx, "id", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "fetchUserData")
|
return nil, errors.Wrap(err, "fetchUserData")
|
||||||
}
|
}
|
||||||
@@ -107,9 +102,9 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Checks if the given username is unique. Returns true if not taken
|
// Checks if the given username is unique. Returns true if not taken
|
||||||
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
|
func CheckUsernameUnique(tx *sql.Tx, username string) (bool, error) {
|
||||||
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
||||||
rows, err := tx.Query(ctx, query, username)
|
rows, err := tx.Query(query, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, errors.Wrap(err, "tx.Query")
|
return false, errors.Wrap(err, "tx.Query")
|
||||||
}
|
}
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package account
|
package account
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
templ ChangeBio(err string, bio string) {
|
templ ChangeBio(err string, bio string) {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{
|
{{
|
||||||
user := contexts.GetUser(ctx)
|
|
||||||
if bio == "" {
|
if bio == "" {
|
||||||
bio = user.Bio
|
bio = user.Bio
|
||||||
}
|
}
|
||||||
@@ -21,21 +21,20 @@ templ ChangeBio(err string, bio string) {
|
|||||||
bio: newBio,
|
bio: newBio,
|
||||||
initialBio: oldBio,
|
initialBio: oldBio,
|
||||||
err: err,
|
err: err,
|
||||||
bioLenText: '',
|
bioLenText: "",
|
||||||
updateTextArea() {
|
updateTextArea() {
|
||||||
this.$nextTick(() => {
|
this.$nextTick(() => {
|
||||||
if (this.$refs.bio) {
|
if (this.$refs.bio) {
|
||||||
this.$refs.bio.style.height = 'auto';
|
this.$refs.bio.style.height = "auto";
|
||||||
this.$refs.bio.style.height = `
|
this.$refs.bio.style.height = `
|
||||||
${this.$refs.bio.scrollHeight + 20}px`;
|
${this.$refs.bio.scrollHeight + 20}px`;
|
||||||
};
|
}
|
||||||
this.bioLenText = `${this.bio.length}/128`;
|
this.bioLenText = `${this.bio.length}/128`;
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
resetBio() {
|
resetBio() {
|
||||||
this.bio = this.initialBio;
|
this.bio = this.initialBio;
|
||||||
this.err = "",
|
((this.err = ""), this.updateTextArea());
|
||||||
this.updateTextArea();
|
|
||||||
},
|
},
|
||||||
init() {
|
init() {
|
||||||
this.$nextTick(() => {
|
this.$nextTick(() => {
|
||||||
@@ -46,7 +45,7 @@ templ ChangeBio(err string, bio string) {
|
|||||||
this.updateTextArea();
|
this.updateTextArea();
|
||||||
}, 20);
|
}, 20);
|
||||||
});
|
});
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package account
|
package account
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
templ ChangeUsername(err string, username string) {
|
templ ChangeUsername(err string, username string) {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{
|
{{
|
||||||
user := contexts.GetUser(ctx)
|
|
||||||
if username == "" {
|
if username == "" {
|
||||||
username = user.Username
|
username = user.Username
|
||||||
}
|
}
|
||||||
@@ -3,12 +3,12 @@ package account
|
|||||||
templ AccountContainer(subpage string) {
|
templ AccountContainer(subpage string) {
|
||||||
<div
|
<div
|
||||||
id="account-container"
|
id="account-container"
|
||||||
class="flex max-w-200 min-h-100 mx-auto bg-mantle mt-10 rounded-xl"
|
class="flex max-w-200 min-h-100 mx-5 md:mx-auto bg-mantle mt-5 rounded-xl"
|
||||||
x-data="{big:window.innerWidth >=768, open:false}"
|
x-data="{big:window.innerWidth >=768, open:false}"
|
||||||
@resize.window="big = window.innerWidth >= 768"
|
@resize.window="big = window.innerWidth >= 768"
|
||||||
>
|
>
|
||||||
@SelectMenu(subpage)
|
@SelectMenu(subpage)
|
||||||
<div class="mt-5 w-full md:ml-[200px] ml-[40px] transition-all duration-300">
|
<div class="mt-5 w-full md:ml-[200px] ml-10 transition-all duration-300">
|
||||||
<div
|
<div
|
||||||
class="pl-5 text-2xl text-subtext1 border-b
|
class="pl-5 text-2xl text-subtext1 border-b
|
||||||
border-overlay0 w-[90%] mx-auto"
|
border-overlay0 w-[90%] mx-auto"
|
||||||
@@ -75,8 +75,13 @@ templ Footer() {
|
|||||||
</div>
|
</div>
|
||||||
<div class="lg:flex lg:items-end lg:justify-between">
|
<div class="lg:flex lg:items-end lg:justify-between">
|
||||||
<div>
|
<div>
|
||||||
<p class="mt-4 text-center text-sm text-subtext0">
|
<p class="mt-4 text-center text-sm text-overlay0">
|
||||||
by Haelnorr
|
by Haelnorr |
|
||||||
|
<a href="#">Film data</a> from
|
||||||
|
<a
|
||||||
|
href="https://www.themoviedb.org/"
|
||||||
|
class="underline hover:text-subtext0 transition"
|
||||||
|
>TMDB</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
@@ -44,7 +44,7 @@ templ RegisterForm(registerError string) {
|
|||||||
<div class="relative">
|
<div class="relative">
|
||||||
<input
|
<input
|
||||||
type="text"
|
type="text"
|
||||||
idnutanix="username"
|
id="username"
|
||||||
name="username"
|
name="username"
|
||||||
class="py-3 px-4 block w-full rounded-lg text-sm
|
class="py-3 px-4 block w-full rounded-lg text-sm
|
||||||
focus:border-blue focus:ring-blue bg-base
|
focus:border-blue focus:ring-blue bg-base
|
||||||
@@ -21,7 +21,7 @@ templ Navbar() {
|
|||||||
<div x-data="{ open: false }">
|
<div x-data="{ open: false }">
|
||||||
<header class="bg-crust">
|
<header class="bg-crust">
|
||||||
<div
|
<div
|
||||||
class="mx-auto flex h-16 max-w-screen-xl items-center gap-8
|
class="mx-auto flex h-16 max-w-7xl items-center gap-8
|
||||||
px-4 sm:px-6 lg:px-8"
|
px-4 sm:px-6 lg:px-8"
|
||||||
>
|
>
|
||||||
<a class="block" href="/">
|
<a class="block" href="/">
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package nav
|
package nav
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
type ProfileItem struct {
|
type ProfileItem struct {
|
||||||
name string // Label to display
|
name string // Label to display
|
||||||
@@ -23,7 +23,7 @@ func getProfileItems() []ProfileItem {
|
|||||||
|
|
||||||
// Returns the right portion of the navbar
|
// Returns the right portion of the navbar
|
||||||
templ navRight() {
|
templ navRight() {
|
||||||
{{ user := contexts.GetUser(ctx) }}
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
{{ items := getProfileItems() }}
|
{{ items := getProfileItems() }}
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
<div class="sm:flex sm:gap-2">
|
<div class="sm:flex sm:gap-2">
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package nav
|
package nav
|
||||||
|
|
||||||
import "projectreshoot/contexts"
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
// Returns the mobile version of the navbar thats only visible when activated
|
// Returns the mobile version of the navbar thats only visible when activated
|
||||||
templ sideNav(navItems []NavItem) {
|
templ sideNav(navItems []NavItem) {
|
||||||
{{ user := contexts.GetUser(ctx) }}
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
<div
|
<div
|
||||||
x-show="open"
|
x-show="open"
|
||||||
x-transition
|
x-transition
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
package popup
|
package popup
|
||||||
|
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
templ ConfirmPasswordModal() {
|
templ ConfirmPasswordModal() {
|
||||||
<div
|
<div
|
||||||
44
internal/view/component/search/movies_results.templ
Normal file
44
internal/view/component/search/movies_results.templ
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package search
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
import "git.haelnorr.com/h/golib/tmdb"
|
||||||
|
|
||||||
|
templ MovieResults(movies *tmdb.ResultMovies, image *tmdb.Image) {
|
||||||
|
for _, movie := range movies.Results {
|
||||||
|
<div
|
||||||
|
class="bg-surface0 p-4 rounded-lg shadow-lg flex
|
||||||
|
items-start space-x-4"
|
||||||
|
>
|
||||||
|
<img
|
||||||
|
src={ movie.GetPoster(image, "w92") }
|
||||||
|
alt="Movie Poster"
|
||||||
|
class="rounded-lg object-cover"
|
||||||
|
width="96"
|
||||||
|
height="144"
|
||||||
|
onerror="this.onerror=null; setFallbackColor(this);"
|
||||||
|
/>
|
||||||
|
<script>
|
||||||
|
function setFallbackColor(img) {
|
||||||
|
const baseColor = getComputedStyle(document.documentElement).
|
||||||
|
getPropertyValue('--base').trim();
|
||||||
|
img.src = `data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='96' height='144'%3E%3Crect width='100%' height='100%' fill='${baseColor}'/%3E%3C/svg%3E`;
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
<div>
|
||||||
|
<a
|
||||||
|
href={ templ.SafeURL(fmt.Sprintf("/movie/%v", movie.ID)) }
|
||||||
|
class="text-xl font-semibold transition hover:text-green"
|
||||||
|
>{ movie.Title } { movie.ReleaseYear() }</a>
|
||||||
|
<p class="text-subtext0">
|
||||||
|
Released:
|
||||||
|
<span class="font-medium">{ movie.ReleaseDate }</span>
|
||||||
|
</p>
|
||||||
|
<p class="text-subtext0">
|
||||||
|
Original Title:
|
||||||
|
<span class="font-medium">{ movie.OriginalTitle }</span>
|
||||||
|
</p>
|
||||||
|
<p class="text-subtext0">{ movie.Overview }</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
109
internal/view/layout/global.templ
Normal file
109
internal/view/layout/global.templ
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package layout
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/component/nav"
|
||||||
|
import "projectreshoot/internal/view/component/footer"
|
||||||
|
import "projectreshoot/internal/view/component/popup"
|
||||||
|
|
||||||
|
// Global page layout. Includes HTML document settings, header tags
|
||||||
|
// navbar and footer
|
||||||
|
templ Global(title string) {
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html
|
||||||
|
lang="en"
|
||||||
|
x-data="{
|
||||||
|
theme: localStorage.getItem('theme')
|
||||||
|
|| 'system'}"
|
||||||
|
x-init="$watch('theme', (val) => localStorage.setItem('theme', val))"
|
||||||
|
x-bind:class="{'dark': theme === 'dark' || (theme === 'system' &&
|
||||||
|
window.matchMedia('(prefers-color-scheme: dark)').matches)}"
|
||||||
|
>
|
||||||
|
<head>
|
||||||
|
<script>
|
||||||
|
(function () {
|
||||||
|
let theme = localStorage.getItem("theme") || "system";
|
||||||
|
if (theme === "system") {
|
||||||
|
theme = window.matchMedia("(prefers-color-scheme: dark)").matches
|
||||||
|
? "dark"
|
||||||
|
: "light";
|
||||||
|
}
|
||||||
|
if (theme === "dark") {
|
||||||
|
document.documentElement.classList.add("dark");
|
||||||
|
} else {
|
||||||
|
document.documentElement.classList.remove("dark");
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
</script>
|
||||||
|
<meta charset="UTF-8"/>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||||
|
<title>{ title }</title>
|
||||||
|
<link rel="icon" type="image/x-icon" href="/static/favicon.ico"/>
|
||||||
|
<link href="/static/css/output.css" rel="stylesheet"/>
|
||||||
|
<script src="https://unpkg.com/htmx.org@2.0.4" integrity="sha384-HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+" crossorigin="anonymous"></script>
|
||||||
|
<script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/persist@3.x.x/dist/cdn.min.js"></script>
|
||||||
|
<script src="https://unpkg.com/alpinejs" defer></script>
|
||||||
|
<script>
|
||||||
|
// uncomment this line to enable logging of htmx events
|
||||||
|
// htmx.logAll();
|
||||||
|
</script>
|
||||||
|
<script>
|
||||||
|
const bodyData = {
|
||||||
|
showError500: false,
|
||||||
|
showError503: false,
|
||||||
|
showConfirmPasswordModal: false,
|
||||||
|
handleHtmxBeforeOnLoad(event) {
|
||||||
|
const requestPath = event.detail.pathInfo.requestPath;
|
||||||
|
if (requestPath === "/reauthenticate") {
|
||||||
|
// handle password incorrect on refresh attempt
|
||||||
|
if (event.detail.xhr.status === 445) {
|
||||||
|
event.detail.shouldSwap = true;
|
||||||
|
event.detail.isError = false;
|
||||||
|
} else if (event.detail.xhr.status === 200) {
|
||||||
|
this.showConfirmPasswordModal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// handle errors from the server on HTMX requests
|
||||||
|
handleHtmxError(event) {
|
||||||
|
const errorCode = event.detail.errorInfo.error;
|
||||||
|
|
||||||
|
// internal server error
|
||||||
|
if (errorCode.includes("Code 500")) {
|
||||||
|
this.showError500 = true;
|
||||||
|
setTimeout(() => (this.showError500 = false), 6000);
|
||||||
|
}
|
||||||
|
// service not available error
|
||||||
|
if (errorCode.includes("Code 503")) {
|
||||||
|
this.showError503 = true;
|
||||||
|
setTimeout(() => (this.showError503 = false), 6000);
|
||||||
|
}
|
||||||
|
|
||||||
|
// user is authorized but needs to refresh their login
|
||||||
|
if (errorCode.includes("Code 444")) {
|
||||||
|
this.showConfirmPasswordModal = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body
|
||||||
|
class="bg-base text-text ubuntu-mono-regular overflow-x-hidden"
|
||||||
|
x-data="bodyData"
|
||||||
|
x-on:htmx:error="handleHtmxError($event)"
|
||||||
|
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
|
||||||
|
>
|
||||||
|
@popup.Error500Popup()
|
||||||
|
@popup.Error503Popup()
|
||||||
|
@popup.ConfirmPasswordModal()
|
||||||
|
<div
|
||||||
|
id="main-content"
|
||||||
|
class="flex flex-col h-screen justify-between"
|
||||||
|
>
|
||||||
|
@nav.Navbar()
|
||||||
|
<div id="page-content" class="mb-auto md:px-5 md:pt-5">
|
||||||
|
{ children... }
|
||||||
|
</div>
|
||||||
|
@footer.Footer()
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
}
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
// Returns the about page content
|
// Returns the about page content
|
||||||
templ About() {
|
templ About() {
|
||||||
@layout.Global() {
|
@layout.Global("About") {
|
||||||
<div class="text-center max-w-150 m-auto">
|
<div class="text-center max-w-150 m-auto">
|
||||||
<div class="text-4xl mt-8">About</div>
|
<div class="text-4xl mt-8">About</div>
|
||||||
<div class="text-xl font-bold mt-4">What is Project Reshoot?</div>
|
<div class="text-xl font-bold mt-4">What is Project Reshoot?</div>
|
||||||
10
internal/view/page/account.templ
Normal file
10
internal/view/page/account.templ
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
import "projectreshoot/internal/view/component/account"
|
||||||
|
|
||||||
|
templ Account(subpage string) {
|
||||||
|
@layout.Global("Account - " + subpage) {
|
||||||
|
@account.AccountContainer(subpage)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "strconv"
|
import "strconv"
|
||||||
|
|
||||||
// Page template for Error pages. Error code should be a HTTP status code as
|
// Page template for Error pages. Error code should be a HTTP status code as
|
||||||
// a string, and err should be the corresponding response title.
|
// a string, and err should be the corresponding response title.
|
||||||
// Message is a custom error message displayed below the code and error.
|
// Message is a custom error message displayed below the code and error.
|
||||||
templ Error(code int, err string, message string) {
|
templ Error(code int, err string, message string) {
|
||||||
@layout.Global() {
|
@layout.Global(err) {
|
||||||
<div
|
<div
|
||||||
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
||||||
place-content-center bg-base px-4"
|
place-content-center bg-base px-4"
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
// Page content for the index page
|
// Page content for the index page
|
||||||
templ Index() {
|
templ Index() {
|
||||||
@layout.Global() {
|
@layout.Global("Project Reshoot") {
|
||||||
<div class="text-center mt-24">
|
<div class="text-center mt-24">
|
||||||
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
|
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
|
||||||
<div>A better way to discover and rate films</div>
|
<div>A better way to discover and rate films</div>
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
// Returns the login page
|
// Returns the login page
|
||||||
templ Login() {
|
templ Login() {
|
||||||
@layout.Global() {
|
@layout.Global("Login") {
|
||||||
<div class="max-w-100 mx-auto px-2">
|
<div class="max-w-100 mx-auto px-2">
|
||||||
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
||||||
<div class="p-4 sm:p-7">
|
<div class="p-4 sm:p-7">
|
||||||
90
internal/view/page/movie.templ
Normal file
90
internal/view/page/movie.templ
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "git.haelnorr.com/h/golib/tmdb"
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
|
templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) {
|
||||||
|
@layout.Global(movie.Title) {
|
||||||
|
<div class="md:bg-surface0 md:p-2 md:rounded-lg transition-all">
|
||||||
|
<div id="billedcrew" class="hidden">
|
||||||
|
for _, billedcrew := range credits.BilledCrew() {
|
||||||
|
<span class="flex flex-col text-left w-[130px] md:w-[180px]">
|
||||||
|
<span class="font-bold">{ billedcrew.Name }</span>
|
||||||
|
<span class="text-subtext1">{ billedcrew.FRoles() }</span>
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
</div>
|
||||||
|
<div class="flex items-start">
|
||||||
|
<div class="w-[154px] md:w-[300px] flex-col">
|
||||||
|
<img
|
||||||
|
class="object-cover aspect-2/3 w-[154px] md:w-[300px]
|
||||||
|
transition-all md:rounded-md shadow-black shadow-2xl"
|
||||||
|
src={ movie.GetPoster(image, "w300") }
|
||||||
|
alt="Poster"
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
id="billedcrew-sm"
|
||||||
|
class="text-sm md:text-lg text-subtext1 flex gap-6
|
||||||
|
mt-5 flex-wrap justify-around flex-col px-5 md:hidden"
|
||||||
|
></div>
|
||||||
|
<script>
|
||||||
|
function moveBilledCrew() {
|
||||||
|
const billedCrewMd = document.getElementById('billedcrew-md');
|
||||||
|
const billedCrewSm = document.getElementById('billedcrew-sm');
|
||||||
|
const billedCrew = document.getElementById('billedcrew');
|
||||||
|
|
||||||
|
if (window.innerWidth < 768) {
|
||||||
|
billedCrewSm.innerHTML = billedCrew.innerHTML;
|
||||||
|
billedCrewMd.innerHTML = "";
|
||||||
|
} else {
|
||||||
|
billedCrewMd.innerHTML = billedCrew.innerHTML;
|
||||||
|
billedCrewSm.innerHTML = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('load', moveBilledCrew);
|
||||||
|
|
||||||
|
const resizeObs = new ResizeObserver(() => {
|
||||||
|
moveBilledCrew();
|
||||||
|
});
|
||||||
|
resizeObs.observe(document.body);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
<div class="flex flex-col flex-1 text-center px-4">
|
||||||
|
<span class="text-xl md:text-3xl font-semibold">
|
||||||
|
{ movie.Title }
|
||||||
|
</span>
|
||||||
|
<span class="text-sm md:text-lg text-subtext1">
|
||||||
|
{ movie.FGenres() }
|
||||||
|
• { movie.FRuntime() }
|
||||||
|
• { movie.ReleaseYear() }
|
||||||
|
</span>
|
||||||
|
<div class="flex justify-center gap-2 mt-2">
|
||||||
|
<div
|
||||||
|
class="w-20 h-20 md:w-30 md:h-30 bg-overlay2
|
||||||
|
transition-all rounded-sm"
|
||||||
|
></div>
|
||||||
|
<div
|
||||||
|
class="w-20 h-20 md:w-30 md:h-30 bg-overlay2
|
||||||
|
transition-all rounded-sm"
|
||||||
|
></div>
|
||||||
|
</div>
|
||||||
|
<div class="flex flex-col mt-4">
|
||||||
|
<span class="text-sm md:text-lg text-overlay2 italic">
|
||||||
|
{ movie.Tagline }
|
||||||
|
</span>
|
||||||
|
<div
|
||||||
|
id="billedcrew-md"
|
||||||
|
class="hidden text-sm md:text-lg text-subtext1 md:flex gap-6
|
||||||
|
mt-5 flex-wrap justify-around"
|
||||||
|
></div>
|
||||||
|
<span class="text-lg mt-5 font-semibold">Overview</span>
|
||||||
|
<span class="text-sm md:text-lg text-subtext1">
|
||||||
|
{ movie.Overview }
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
31
internal/view/page/movie_search.templ
Normal file
31
internal/view/page/movie_search.templ
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
|
||||||
|
templ Movies() {
|
||||||
|
@layout.Global("Search movies") {
|
||||||
|
<div class="max-w-4xl mx-auto md:mt-0 mt-2 px-2 md:px-0">
|
||||||
|
<form hx-post="/search-movies" hx-target="#search-movies-results">
|
||||||
|
<div
|
||||||
|
class="max-w-100 flex items-center space-x-2 mb-2"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
id="search"
|
||||||
|
name="search"
|
||||||
|
type="text"
|
||||||
|
placeholder="Search movies..."
|
||||||
|
class="flex-grow p-2 border rounded-lg
|
||||||
|
bg-mantle border-surface2 shadow-sm
|
||||||
|
focus:outline-none focus:ring-2 focus:ring-blue"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
class="py-2 px-4 bg-green text-mantle rounded-lg transition
|
||||||
|
hover:cursor-pointer hover:bg-green/75"
|
||||||
|
>Search</button>
|
||||||
|
</div>
|
||||||
|
<div id="search-movies-results" class="space-y-4"></div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
13
internal/view/page/profile.templ
Normal file
13
internal/view/page/profile.templ
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package page
|
||||||
|
|
||||||
|
import "projectreshoot/internal/view/layout"
|
||||||
|
import "projectreshoot/pkg/contexts"
|
||||||
|
|
||||||
|
templ Profile() {
|
||||||
|
{{ user := contexts.CurrentUser(ctx) }}
|
||||||
|
@layout.Global("Profile - " + user.Username) {
|
||||||
|
<div class="">
|
||||||
|
Hello, { user.Username }
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
package page
|
package page
|
||||||
|
|
||||||
import "projectreshoot/view/layout"
|
import "projectreshoot/internal/view/layout"
|
||||||
import "projectreshoot/view/component/form"
|
import "projectreshoot/internal/view/component/form"
|
||||||
|
|
||||||
// Returns the login page
|
// Returns the login page
|
||||||
templ Register() {
|
templ Register() {
|
||||||
@layout.Global() {
|
@layout.Global("Register") {
|
||||||
<div class="max-w-100 mx-auto px-2">
|
<div class="max-w-100 mx-auto px-2">
|
||||||
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
<div class="mt-7 bg-mantle border border-surface1 rounded-xl">
|
||||||
<div class="p-4 sm:p-7">
|
<div class="p-4 sm:p-7">
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Generates an access token for the provided user
|
|
||||||
func GenerateAccessToken(
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
fresh bool,
|
|
||||||
rememberMe bool,
|
|
||||||
) (tokenStr string, exp int64, err error) {
|
|
||||||
issuedAt := time.Now().Unix()
|
|
||||||
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
|
|
||||||
var freshExpiresAt int64
|
|
||||||
if fresh {
|
|
||||||
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
|
|
||||||
} else {
|
|
||||||
freshExpiresAt = issuedAt
|
|
||||||
}
|
|
||||||
var ttl string
|
|
||||||
if rememberMe {
|
|
||||||
ttl = "exp"
|
|
||||||
} else {
|
|
||||||
ttl = "session"
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
|
||||||
jwt.MapClaims{
|
|
||||||
"iss": config.TrustedHost,
|
|
||||||
"scope": "access",
|
|
||||||
"ttl": ttl,
|
|
||||||
"jti": uuid.New(),
|
|
||||||
"iat": issuedAt,
|
|
||||||
"exp": expiresAt,
|
|
||||||
"fresh": freshExpiresAt,
|
|
||||||
"sub": user.ID,
|
|
||||||
})
|
|
||||||
|
|
||||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
|
||||||
}
|
|
||||||
return signedToken, expiresAt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generates a refresh token for the provided user
|
|
||||||
func GenerateRefreshToken(
|
|
||||||
config *config.Config,
|
|
||||||
user *db.User,
|
|
||||||
rememberMe bool,
|
|
||||||
) (tokenStr string, exp int64, err error) {
|
|
||||||
issuedAt := time.Now().Unix()
|
|
||||||
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
|
|
||||||
var ttl string
|
|
||||||
if rememberMe {
|
|
||||||
ttl = "exp"
|
|
||||||
} else {
|
|
||||||
ttl = "session"
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
|
||||||
jwt.MapClaims{
|
|
||||||
"iss": config.TrustedHost,
|
|
||||||
"scope": "refresh",
|
|
||||||
"ttl": ttl,
|
|
||||||
"jti": uuid.New(),
|
|
||||||
"iat": issuedAt,
|
|
||||||
"exp": expiresAt,
|
|
||||||
"sub": user.ID,
|
|
||||||
})
|
|
||||||
|
|
||||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
|
||||||
}
|
|
||||||
return signedToken, expiresAt, nil
|
|
||||||
}
|
|
||||||
268
jwt/parse.go
268
jwt/parse.go
@@ -1,268 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Parse an access token and return a struct with all the claims. Does validation on
|
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
|
||||||
// has the correct scope.
|
|
||||||
func ParseAccessToken(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
tokenString string,
|
|
||||||
) (*AccessToken, error) {
|
|
||||||
if tokenString == "" {
|
|
||||||
return nil, errors.New("Access token string not provided")
|
|
||||||
}
|
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "parseToken")
|
|
||||||
}
|
|
||||||
expiry, err := checkTokenExpired(claims["exp"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
|
||||||
}
|
|
||||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
|
||||||
}
|
|
||||||
ttl, err := getTokenTTL(claims["ttl"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenTTL")
|
|
||||||
}
|
|
||||||
scope, err := getTokenScope(claims["scope"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenScope")
|
|
||||||
}
|
|
||||||
if scope != "access" {
|
|
||||||
return nil, errors.New("Token is not an Access token")
|
|
||||||
}
|
|
||||||
issuedAt, err := getIssuedTime(claims["iat"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getIssuedTime")
|
|
||||||
}
|
|
||||||
subject, err := getTokenSubject(claims["sub"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenSubject")
|
|
||||||
}
|
|
||||||
fresh, err := getFreshTime(claims["fresh"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getFreshTime")
|
|
||||||
}
|
|
||||||
jti, err := getTokenJTI(claims["jti"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenJTI")
|
|
||||||
}
|
|
||||||
|
|
||||||
token := &AccessToken{
|
|
||||||
ISS: issuer,
|
|
||||||
TTL: ttl,
|
|
||||||
EXP: expiry,
|
|
||||||
IAT: issuedAt,
|
|
||||||
SUB: subject,
|
|
||||||
Fresh: fresh,
|
|
||||||
JTI: jti,
|
|
||||||
Scope: scope,
|
|
||||||
}
|
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
|
||||||
}
|
|
||||||
if !valid {
|
|
||||||
return nil, errors.New("Token has been revoked")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
|
||||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
|
||||||
// has the correct scope.
|
|
||||||
func ParseRefreshToken(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
tokenString string,
|
|
||||||
) (*RefreshToken, error) {
|
|
||||||
if tokenString == "" {
|
|
||||||
return nil, errors.New("Refresh token string not provided")
|
|
||||||
}
|
|
||||||
claims, err := parseToken(config.SecretKey, tokenString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "parseToken")
|
|
||||||
}
|
|
||||||
expiry, err := checkTokenExpired(claims["exp"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
|
||||||
}
|
|
||||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
|
||||||
}
|
|
||||||
ttl, err := getTokenTTL(claims["ttl"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenTTL")
|
|
||||||
}
|
|
||||||
scope, err := getTokenScope(claims["scope"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenScope")
|
|
||||||
}
|
|
||||||
if scope != "refresh" {
|
|
||||||
return nil, errors.New("Token is not an Refresh token")
|
|
||||||
}
|
|
||||||
issuedAt, err := getIssuedTime(claims["iat"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getIssuedTime")
|
|
||||||
}
|
|
||||||
subject, err := getTokenSubject(claims["sub"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenSubject")
|
|
||||||
}
|
|
||||||
jti, err := getTokenJTI(claims["jti"])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "getTokenJTI")
|
|
||||||
}
|
|
||||||
|
|
||||||
token := &RefreshToken{
|
|
||||||
ISS: issuer,
|
|
||||||
TTL: ttl,
|
|
||||||
EXP: expiry,
|
|
||||||
IAT: issuedAt,
|
|
||||||
SUB: subject,
|
|
||||||
JTI: jti,
|
|
||||||
Scope: scope,
|
|
||||||
}
|
|
||||||
|
|
||||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
|
||||||
}
|
|
||||||
if !valid {
|
|
||||||
return nil, errors.New("Token has been revoked")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse a token, validating its signing sigature and returning the claims
|
|
||||||
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(secretKey), nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.Parse")
|
|
||||||
}
|
|
||||||
// Token decoded, parse the claims
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("Failed to parse claims")
|
|
||||||
}
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token is expired. Returns the expiry if not expired
|
|
||||||
func checkTokenExpired(expiry interface{}) (int64, error) {
|
|
||||||
// Coerce the expiry to a float64 to avoid scientific notation
|
|
||||||
expFloat, ok := expiry.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'exp' claim")
|
|
||||||
}
|
|
||||||
// Convert to the int64 time we expect :)
|
|
||||||
expiryTime := int64(expFloat)
|
|
||||||
|
|
||||||
// Check if its expired
|
|
||||||
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
|
||||||
if isExpired {
|
|
||||||
return 0, errors.New("Token has expired")
|
|
||||||
}
|
|
||||||
return expiryTime, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token has a valid issuer. Returns the issuer if valid
|
|
||||||
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
|
||||||
issuerVal, ok := issuer.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'iss' claim")
|
|
||||||
}
|
|
||||||
if issuer != trustedHost {
|
|
||||||
return "", errors.New("Issuer does not matched trusted host")
|
|
||||||
}
|
|
||||||
return issuerVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the scope matches the expected scope. Returns scope if true
|
|
||||||
func getTokenScope(scope interface{}) (string, error) {
|
|
||||||
scopeStr, ok := scope.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'scope' claim")
|
|
||||||
}
|
|
||||||
return scopeStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the TTL of the token, either "session" or "exp"
|
|
||||||
func getTokenTTL(ttl interface{}) (string, error) {
|
|
||||||
ttlStr, ok := ttl.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Missing or invalid 'ttl' claim")
|
|
||||||
}
|
|
||||||
if ttlStr != "exp" && ttlStr != "session" {
|
|
||||||
return "", errors.New("TTL value is not recognised")
|
|
||||||
}
|
|
||||||
return ttlStr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the time the token was issued at
|
|
||||||
func getIssuedTime(issued interface{}) (int64, error) {
|
|
||||||
// Same float64 -> int64 trick as expiry
|
|
||||||
issuedFloat, ok := issued.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'iat' claim")
|
|
||||||
}
|
|
||||||
issuedAt := int64(issuedFloat)
|
|
||||||
return issuedAt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the freshness expiry timestamp
|
|
||||||
func getFreshTime(fresh interface{}) (int64, error) {
|
|
||||||
freshUntil, ok := fresh.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'fresh' claim")
|
|
||||||
}
|
|
||||||
return int64(freshUntil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the subject of the token
|
|
||||||
func getTokenSubject(sub interface{}) (int, error) {
|
|
||||||
subject, ok := sub.(float64)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("Missing or invalid 'sub' claim")
|
|
||||||
}
|
|
||||||
return int(subject), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the JTI of the token
|
|
||||||
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
|
||||||
jtiStr, ok := jti.(string)
|
|
||||||
if !ok {
|
|
||||||
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
|
||||||
}
|
|
||||||
jtiUUID, err := uuid.Parse(jtiStr)
|
|
||||||
if err != nil {
|
|
||||||
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
|
||||||
}
|
|
||||||
return jtiUUID, nil
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Revoke a token by adding it to the database
|
|
||||||
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
|
|
||||||
jti := t.GetJTI()
|
|
||||||
exp := t.GetEXP()
|
|
||||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
|
||||||
_, err := tx.Exec(ctx, query, jti, exp)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tx.Exec")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if a token has been revoked. Returns true if not revoked.
|
|
||||||
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
|
|
||||||
jti := t.GetJTI()
|
|
||||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
|
||||||
rows, err := tx.Query(ctx, query, jti)
|
|
||||||
if err != nil {
|
|
||||||
return false, errors.Wrap(err, "tx.Query")
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
revoked := rows.Next()
|
|
||||||
return !revoked, nil
|
|
||||||
}
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
package jwt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"projectreshoot/db"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Token interface {
|
|
||||||
GetJTI() uuid.UUID
|
|
||||||
GetEXP() int64
|
|
||||||
GetScope() string
|
|
||||||
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Access token
|
|
||||||
type AccessToken struct {
|
|
||||||
ISS string // Issuer, generally TrustedHost
|
|
||||||
IAT int64 // Time issued at
|
|
||||||
EXP int64 // Time expiring at
|
|
||||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
|
||||||
SUB int // Subject (user) ID
|
|
||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
|
||||||
Fresh int64 // Time freshness expiring at
|
|
||||||
Scope string // Should be "access"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh token
|
|
||||||
type RefreshToken struct {
|
|
||||||
ISS string // Issuer, generally TrustedHost
|
|
||||||
IAT int64 // Time issued at
|
|
||||||
EXP int64 // Time expiring at
|
|
||||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
|
||||||
SUB int // Subject (user) ID
|
|
||||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
|
||||||
Scope string // Should be "refresh"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
|
||||||
user, err := db.GetUserFromID(ctx, tx, a.SUB)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
|
||||||
user, err := db.GetUserFromID(ctx, tx, r.SUB)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a AccessToken) GetJTI() uuid.UUID {
|
|
||||||
return a.JTI
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetJTI() uuid.UUID {
|
|
||||||
return r.JTI
|
|
||||||
}
|
|
||||||
func (a AccessToken) GetEXP() int64 {
|
|
||||||
return a.EXP
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetEXP() int64 {
|
|
||||||
return r.EXP
|
|
||||||
}
|
|
||||||
func (a AccessToken) GetScope() string {
|
|
||||||
return a.Scope
|
|
||||||
}
|
|
||||||
func (r RefreshToken) GetScope() string {
|
|
||||||
return r.Scope
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/pkgerrors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Takes a log level as string and converts it to a zerolog.Level interface.
|
|
||||||
// If the string is not a valid input it will return zerolog.InfoLevel
|
|
||||||
func GetLogLevel(level string) zerolog.Level {
|
|
||||||
levels := map[string]zerolog.Level{
|
|
||||||
"trace": zerolog.TraceLevel,
|
|
||||||
"debug": zerolog.DebugLevel,
|
|
||||||
"info": zerolog.InfoLevel,
|
|
||||||
"warn": zerolog.WarnLevel,
|
|
||||||
"error": zerolog.ErrorLevel,
|
|
||||||
"fatal": zerolog.FatalLevel,
|
|
||||||
"panic": zerolog.PanicLevel,
|
|
||||||
}
|
|
||||||
logLevel, valid := levels[level]
|
|
||||||
if !valid {
|
|
||||||
return zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
return logLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a pointer to a new log file with the specified path.
|
|
||||||
// Remember to call file.Close() when finished writing to the log file
|
|
||||||
func GetLogFile(path string) (*os.File, error) {
|
|
||||||
logPath := filepath.Join(path, "server.log")
|
|
||||||
file, err := os.OpenFile(
|
|
||||||
logPath,
|
|
||||||
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
|
|
||||||
0663,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "os.OpenFile")
|
|
||||||
}
|
|
||||||
return file, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get a pointer to a new zerolog.Logger with the specified level and output
|
|
||||||
// Can provide a file, writer or both. Must provide at least one of the two
|
|
||||||
func GetLogger(
|
|
||||||
logLevel zerolog.Level,
|
|
||||||
w io.Writer,
|
|
||||||
logFile *os.File,
|
|
||||||
logDir string,
|
|
||||||
) (*zerolog.Logger, error) {
|
|
||||||
if w == nil && logFile == nil {
|
|
||||||
return nil, errors.New("No Writer provided for log output.")
|
|
||||||
}
|
|
||||||
|
|
||||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
|
||||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
|
||||||
|
|
||||||
var consoleWriter zerolog.ConsoleWriter
|
|
||||||
if w != nil {
|
|
||||||
consoleWriter = zerolog.ConsoleWriter{Out: w}
|
|
||||||
}
|
|
||||||
|
|
||||||
var output io.Writer
|
|
||||||
if logFile != nil {
|
|
||||||
if w != nil {
|
|
||||||
output = zerolog.MultiLevelWriter(logFile, consoleWriter)
|
|
||||||
} else {
|
|
||||||
output = logFile
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
output = consoleWriter
|
|
||||||
}
|
|
||||||
logger := zerolog.New(output).
|
|
||||||
With().
|
|
||||||
Timestamp().
|
|
||||||
Logger().
|
|
||||||
Level(logLevel)
|
|
||||||
|
|
||||||
return &logger, nil
|
|
||||||
}
|
|
||||||
233
main.go
233
main.go
@@ -1,233 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"embed"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/fs"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/logging"
|
|
||||||
"projectreshoot/server"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed static/*
|
|
||||||
var embeddedStatic embed.FS
|
|
||||||
|
|
||||||
// Gets the static files
|
|
||||||
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
|
|
||||||
if _, err := os.Stat("static"); err == nil {
|
|
||||||
// Use actual filesystem in development
|
|
||||||
logger.Debug().Msg("Using filesystem for static files")
|
|
||||||
return http.Dir("static"), nil
|
|
||||||
} else {
|
|
||||||
// Use embedded filesystem in production
|
|
||||||
logger.Debug().Msg("Using embedded static files")
|
|
||||||
subFS, err := fs.Sub(embeddedStatic, "static")
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "fs.Sub")
|
|
||||||
}
|
|
||||||
return http.FS(subFS), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var maint uint32 // atomic: 1 if in maintenance mode
|
|
||||||
|
|
||||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
|
||||||
func handleMaintSignals(
|
|
||||||
conn *db.SafeConn,
|
|
||||||
srv *http.Server,
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
) {
|
|
||||||
logger.Debug().Msg("Starting signal listener")
|
|
||||||
ch := make(chan os.Signal, 1)
|
|
||||||
srv.RegisterOnShutdown(func() {
|
|
||||||
logger.Debug().Msg("Shutting down signal listener")
|
|
||||||
close(ch)
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for sig := range ch {
|
|
||||||
switch sig {
|
|
||||||
case syscall.SIGUSR1:
|
|
||||||
if atomic.LoadUint32(&maint) != 1 {
|
|
||||||
atomic.StoreUint32(&maint, 1)
|
|
||||||
logger.Info().Msg("Signal received: Starting maintenance")
|
|
||||||
logger.Info().Msg("Attempting to acquire database lock")
|
|
||||||
conn.Pause(config.DBLockTimeout * time.Second)
|
|
||||||
}
|
|
||||||
case syscall.SIGUSR2:
|
|
||||||
if atomic.LoadUint32(&maint) != 0 {
|
|
||||||
logger.Info().Msg("Signal received: Maintenance over")
|
|
||||||
logger.Info().Msg("Releasing database lock")
|
|
||||||
conn.Resume()
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initializes and runs the server
|
|
||||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
config, err := config.GetConfig(args)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "server.GetConfig")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the version of the database required
|
|
||||||
if args["dbver"] == "true" {
|
|
||||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var logfile *os.File = nil
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
|
||||||
logfile, err = logging.GetLogFile(config.LogDir)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "logging.GetLogFile")
|
|
||||||
}
|
|
||||||
defer logfile.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
var consoleWriter io.Writer
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "console" {
|
|
||||||
consoleWriter = w
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := logging.GetLogger(
|
|
||||||
config.LogLevel,
|
|
||||||
consoleWriter,
|
|
||||||
logfile,
|
|
||||||
config.LogDir,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "logging.GetLogger")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug().Msg("Config loaded and logger started")
|
|
||||||
logger.Debug().Msg("Connecting to database")
|
|
||||||
var conn *db.SafeConn
|
|
||||||
if args["test"] == "true" {
|
|
||||||
logger.Debug().Msg("Server in test mode, using test database")
|
|
||||||
ver, err := strconv.ParseInt(config.DBName, 10, 0)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "strconv.ParseInt")
|
|
||||||
}
|
|
||||||
testconn, err := tests.SetupTestDB(ver)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "tests.SetupTestDB")
|
|
||||||
}
|
|
||||||
conn = db.MakeSafe(testconn, logger)
|
|
||||||
} else {
|
|
||||||
conn, err = db.ConnectToDatabase(config.DBName, logger)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
logger.Debug().Msg("Getting static files")
|
|
||||||
staticFS, err := getStaticFiles(logger)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "getStaticFiles")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug().Msg("Setting up HTTP server")
|
|
||||||
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
|
||||||
Handler: srv,
|
|
||||||
ReadHeaderTimeout: config.ReadHeaderTimeout * time.Second,
|
|
||||||
WriteTimeout: config.WriteTimeout * time.Second,
|
|
||||||
IdleTimeout: config.IdleTimeout * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs function for testing in dev if --test flag true
|
|
||||||
if args["tester"] == "true" {
|
|
||||||
logger.Debug().Msg("Running tester function")
|
|
||||||
test(config, logger, conn, httpServer)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setups a channel to listen for os.Signal
|
|
||||||
handleMaintSignals(conn, httpServer, logger, config)
|
|
||||||
|
|
||||||
// Runs the http server
|
|
||||||
logger.Debug().Msg("Starting up the HTTP server")
|
|
||||||
go func() {
|
|
||||||
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
logger.Error().Err(err).Msg("Error listening and serving")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Handles graceful shutdown
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
<-ctx.Done()
|
|
||||||
shutdownCtx := context.Background()
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
||||||
logger.Error().Err(err).Msg("Error shutting down server")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
wg.Wait()
|
|
||||||
logger.Info().Msg("Shutting down")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start of runtime. Parse commandline arguments & flags, Initializes context
|
|
||||||
// and starts the server
|
|
||||||
func main() {
|
|
||||||
// Parse commandline args
|
|
||||||
host := flag.String("host", "", "Override host to listen on")
|
|
||||||
port := flag.String("port", "", "Override port to listen on")
|
|
||||||
test := flag.Bool("test", false, "Run server in test mode")
|
|
||||||
tester := flag.Bool("tester", false, "Run tester function instead of main program")
|
|
||||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
|
||||||
loglevel := flag.String("loglevel", "", "Set log level")
|
|
||||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
// Map the args for easy access
|
|
||||||
args := map[string]string{
|
|
||||||
"host": *host,
|
|
||||||
"port": *port,
|
|
||||||
"test": strconv.FormatBool(*test),
|
|
||||||
"tester": strconv.FormatBool(*tester),
|
|
||||||
"dbver": strconv.FormatBool(*dbver),
|
|
||||||
"loglevel": *loglevel,
|
|
||||||
"logoutput": *logoutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
ctx := context.Background()
|
|
||||||
if err := run(ctx, os.Stdout, args); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
144
main_test.go
144
main_test.go
@@ -1,144 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_main(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
t.Cleanup(cancel)
|
|
||||||
args := map[string]string{"test": "true"}
|
|
||||||
var stdout bytes.Buffer
|
|
||||||
os.Setenv("SECRET_KEY", ".")
|
|
||||||
os.Setenv("HOST", "127.0.0.1")
|
|
||||||
os.Setenv("PORT", "3232")
|
|
||||||
runSrvErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
if err := run(ctx, &stdout, args); err != nil {
|
|
||||||
runSrvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
|
|
||||||
if err != nil {
|
|
||||||
runSrvErr <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
runSrvErr <- nil
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case err := <-runSrvErr:
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error starting test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Log("Test server started")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
|
|
||||||
done := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
expected := "Global database lock acquired"
|
|
||||||
for {
|
|
||||||
if strings.Contains(stdout.String(), expected) {
|
|
||||||
done <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proc, err := os.FindProcess(os.Getpid())
|
|
||||||
require.NoError(t, err)
|
|
||||||
proc.Signal(syscall.SIGUSR1)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Log("found")
|
|
||||||
case <-time.After(250 * time.Millisecond):
|
|
||||||
t.Errorf("Not found")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
|
|
||||||
done := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
expected := "Global database lock released"
|
|
||||||
for {
|
|
||||||
if strings.Contains(stdout.String(), expected) {
|
|
||||||
done <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
proc, err := os.FindProcess(os.Getpid())
|
|
||||||
require.NoError(t, err)
|
|
||||||
proc.Signal(syscall.SIGUSR2)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Log("found")
|
|
||||||
case <-time.After(250 * time.Millisecond):
|
|
||||||
t.Errorf("Not found")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForReady(
|
|
||||||
ctx context.Context,
|
|
||||||
timeout time.Duration,
|
|
||||||
endpoint string,
|
|
||||||
) error {
|
|
||||||
client := http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequestWithContext(
|
|
||||||
ctx,
|
|
||||||
http.MethodGet,
|
|
||||||
endpoint,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Error making request: %s\n", err.Error())
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
fmt.Println("Endpoint is ready!")
|
|
||||||
resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
if time.Since(startTime) >= timeout {
|
|
||||||
return fmt.Errorf("timeout reached while waiting for endpoint")
|
|
||||||
}
|
|
||||||
// wait a little while between checks
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"projectreshoot/config"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/cookies"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
"projectreshoot/jwt"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Attempt to use a valid refresh token to generate a new token pair
|
|
||||||
func refreshAuthTokens(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
req *http.Request,
|
|
||||||
ref *jwt.RefreshToken,
|
|
||||||
) (*db.User, error) {
|
|
||||||
user, err := ref.GetUser(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "ref.GetUser")
|
|
||||||
}
|
|
||||||
|
|
||||||
rememberMe := map[string]bool{
|
|
||||||
"session": false,
|
|
||||||
"exp": true,
|
|
||||||
}[ref.TTL]
|
|
||||||
|
|
||||||
// Set fresh to true because new tokens coming from refresh request
|
|
||||||
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
|
||||||
}
|
|
||||||
// New tokens sent, revoke the used refresh token
|
|
||||||
err = jwt.RevokeToken(ctx, tx, ref)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
|
||||||
}
|
|
||||||
// Return the authorized user
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the cookies for token strings and attempt to authenticate them
|
|
||||||
func getAuthenticatedUser(
|
|
||||||
config *config.Config,
|
|
||||||
ctx context.Context,
|
|
||||||
tx *db.SafeTX,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
) (*contexts.AuthenticatedUser, error) {
|
|
||||||
// Get token strings from cookies
|
|
||||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
|
||||||
// Attempt to parse the access token
|
|
||||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
|
||||||
if err != nil {
|
|
||||||
// Access token invalid, attempt to parse refresh token
|
|
||||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
|
||||||
}
|
|
||||||
// Refresh token valid, attempt to get a new token pair
|
|
||||||
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "refreshAuthTokens")
|
|
||||||
}
|
|
||||||
// New token pair sent, return the authorized user
|
|
||||||
authUser := contexts.AuthenticatedUser{
|
|
||||||
User: user,
|
|
||||||
Fresh: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
return &authUser, nil
|
|
||||||
}
|
|
||||||
// Access token valid
|
|
||||||
user, err := aT.GetUser(ctx, tx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "aT.GetUser")
|
|
||||||
}
|
|
||||||
authUser := contexts.AuthenticatedUser{
|
|
||||||
User: user,
|
|
||||||
Fresh: aT.Fresh,
|
|
||||||
}
|
|
||||||
return &authUser, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to authenticate the user and add their account details
|
|
||||||
// to the request context
|
|
||||||
func Authentication(
|
|
||||||
logger *zerolog.Logger,
|
|
||||||
config *config.Config,
|
|
||||||
conn *db.SafeConn,
|
|
||||||
next http.Handler,
|
|
||||||
maint *uint32,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
|
||||||
r.URL.Path == "/static/favicon.ico" {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if atomic.LoadUint32(maint) == 1 {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the transaction
|
|
||||||
tx, err := conn.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// Failed to start transaction, skip auth
|
|
||||||
logger.Warn().Err(err).
|
|
||||||
Msg("Skipping Auth - unable to start a transaction")
|
|
||||||
handler.ErrorPage(http.StatusServiceUnavailable, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
// User auth failed, delete the cookies to avoid repeat requests
|
|
||||||
cookies.DeleteCookie(w, "access", "/")
|
|
||||||
cookies.DeleteCookie(w, "refresh", "/")
|
|
||||||
logger.Debug().
|
|
||||||
Str("remote_addr", r.RemoteAddr).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to authenticate user")
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
uctx := contexts.SetUser(r.Context(), user)
|
|
||||||
newReq := r.WithContext(uctx)
|
|
||||||
next.ServeHTTP(w, newReq)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthenticationMiddleware(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := db.MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user == nil {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
w.Write([]byte(strconv.Itoa(0)))
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte(strconv.Itoa(user.ID)))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
var maint uint32
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
// Add the middleware and create the server
|
|
||||||
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
|
|
||||||
require.NoError(t, err)
|
|
||||||
server := httptest.NewServer(authHandler)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
tokens := getTokens()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id int
|
|
||||||
accessToken string
|
|
||||||
refreshToken string
|
|
||||||
expectedCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid Access Token (Fresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessFresh"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid Access Token (Unfresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessUnfresh"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid Refresh Token (Triggers Refresh)",
|
|
||||||
id: 1,
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshValid"],
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Both tokens expired",
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Access token revoked",
|
|
||||||
accessToken: tokens["accessRevoked"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Refresh token revoked",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: tokens["refreshRevoked"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid Tokens",
|
|
||||||
accessToken: tokens["invalid"],
|
|
||||||
refreshToken: tokens["invalid"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No Tokens",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
client := &http.Client{}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
|
|
||||||
|
|
||||||
// Add cookies if provided
|
|
||||||
if tt.accessToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
|
|
||||||
}
|
|
||||||
if tt.refreshToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectedCode, resp.StatusCode)
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, strconv.Itoa(tt.id), string(body))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the tokens to test with
|
|
||||||
func getTokens() map[string]string {
|
|
||||||
tokens := map[string]string{
|
|
||||||
"accessFresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzIyMTAsImZyZXNoIjo0ODk1NjcyMjEwLCJpYXQiOjE3Mzk2NzIyMTAsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImE4Njk2YWM4LTg3OWMtNDdkNC1iZWM2LTRlY2Y4MTRiZThiZiIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.6nAquDY0JBLPdaJ9q_sMpKj1ISG4Vt2U05J57aoPue8",
|
|
||||||
"accessUnfresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJmcmVzaCI6MTczOTY3NTY3MSwiaWF0IjoxNzM5Njc1NjcxLCJpc3MiOiIxMjcuMC4wLjEiLCJqdGkiOiJjOGNhZmFjNy0yODkzLTQzNzMtOTI4ZS03MGUwODJkYmM2MGIiLCJzY29wZSI6ImFjY2VzcyIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.plWQVFwHlhXUYI5utS7ny1JfXjJSFrigkq-PnTHD5VY",
|
|
||||||
"accessExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImZyZXNoIjoxNzM5NjcyMjQ4LCJpYXQiOjE3Mzk2NzIyNDgsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjgxYzA1YzBjLTJhOGItNGQ2MC04Yzc4LWY2ZTQxODYxZDFmNCIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.iI1f17kKTuFDEMEYltJRIwRYgYQ-_nF9Wsn0KR6x77Q",
|
|
||||||
"refreshValid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImlhdCI6MTczOTY3MTkyMiwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTUxMTY3ZWEtNDA3OS00ZTczLTkzZDQtNTgwZDMzODRjZDU4Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.tvtqQ8Z4WrYWHHb0MaEPdsU2FT2KLRE1zHOv3ipoFyc",
|
|
||||||
"refreshExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImlhdCI6MTczOTY3MjI0OCwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTg5YTc5MTYtZGEzYi00YmJhLWI3ZDMtOWI1N2ViNjRhMmU0Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.rH_fytC7Duxo598xacu820pQKF9ELbG8674h_bK_c4I",
|
|
||||||
"accessRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImZyZXNoIjoxNzM5NjcxOTIyLCJpYXQiOjE3Mzk2NzE5MjIsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjBhNmIzMzhlLTkzMGEtNDNmZS04ZjcwLTFhNmRhZWQyNTZmYSIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.mZLuCp9amcm2_CqYvbHPlk86nfiuy_Or8TlntUCw4Qs",
|
|
||||||
"refreshRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJpYXQiOjE3Mzk2NzU2NzEsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImI3ZmE1MWRjLTg1MzItNDJlMS04NzU2LTVkMjViZmIyMDAzYSIsInNjb3BlIjoicmVmcmVzaCIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.5Q9yDZN5FubfCWHclUUZEkJPOUHcOEpVpgcUK-ameHo",
|
|
||||||
"invalid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0ODUxNDA5ODQsImlhdCI6MTQ4NTEzNzM4NCwiaXNzIjoiYWNtZS5jb20iLCJzdWIiOiIyOWFjMGMxOC0wYjRhLTQyY2YtODJmYy0wM2Q1NzAzMThhMWQiLCJhcHBsaWNhdGlvbklkIjoiNzkxMDM3MzQtOTdhYi00ZDFhLWFmMzctZTAwNmQwNWQyOTUyIiwicm9sZXMiOltdfQ.Mp0Pcwsz5VECK11Kf2ZZNF_SMKu5CgBeLN9ZOP04kZo",
|
|
||||||
}
|
|
||||||
return tokens
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Gzip(next http.Handler, useGzip bool) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") ||
|
|
||||||
!useGzip {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Encoding", "gzip")
|
|
||||||
gz := gzip.NewWriter(w)
|
|
||||||
defer gz.Close()
|
|
||||||
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
|
||||||
next.ServeHTTP(gzw, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type gzipResponseWriter struct {
|
|
||||||
io.Writer
|
|
||||||
http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w gzipResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
return w.Writer.Write(b)
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wraps the http.ResponseWriter, adding a statusCode field
|
|
||||||
type wrappedWriter struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
statusCode int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extends WriteHeader to the ResponseWriter to add the status code
|
|
||||||
func (w *wrappedWriter) WriteHeader(statusCode int) {
|
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
w.statusCode = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
// Middleware to add logs to console with details of the request
|
|
||||||
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
|
||||||
r.URL.Path == "/static/favicon.ico" {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
start, err := contexts.GetStartTime(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
handler.ErrorPage(http.StatusInternalServerError, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wrapped := &wrappedWriter{
|
|
||||||
ResponseWriter: w,
|
|
||||||
statusCode: http.StatusOK,
|
|
||||||
}
|
|
||||||
next.ServeHTTP(wrapped, r)
|
|
||||||
logger.Info().
|
|
||||||
Int("status", wrapped.statusCode).
|
|
||||||
Str("method", r.Method).
|
|
||||||
Str("resource", r.URL.Path).
|
|
||||||
Dur("time_elapsed", time.Since(start)).
|
|
||||||
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
|
|
||||||
Msg("Served")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"projectreshoot/handler"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Checks if the user is set in the context and shows 401 page if not logged in
|
|
||||||
func LoginReq(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user == nil {
|
|
||||||
handler.ErrorPage(http.StatusUnauthorized, w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if the user is set in the context and redirects them to profile if
|
|
||||||
// they are logged in
|
|
||||||
func LogoutReq(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
if user != nil {
|
|
||||||
http.Redirect(w, r, "/profile", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPageLoginRequired(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := db.MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
var maint uint32
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
// Add the middleware and create the server
|
|
||||||
loginRequiredHandler := LoginReq(testHandler)
|
|
||||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
|
||||||
server := httptest.NewServer(authHandler)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
tokens := getTokens()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accessToken string
|
|
||||||
refreshToken string
|
|
||||||
expectedCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid Login",
|
|
||||||
accessToken: tokens["accessFresh"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Expired login",
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No login",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
client := &http.Client{}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
|
|
||||||
|
|
||||||
// Add cookies if provided
|
|
||||||
if tt.accessToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
|
|
||||||
}
|
|
||||||
if tt.refreshToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectedCode, resp.StatusCode)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func FreshReq(
|
|
||||||
next http.Handler,
|
|
||||||
) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := contexts.GetUser(r.Context())
|
|
||||||
isFresh := time.Now().Before(time.Unix(user.Fresh, 0))
|
|
||||||
if !isFresh {
|
|
||||||
w.WriteHeader(444)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"projectreshoot/db"
|
|
||||||
"projectreshoot/tests"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReauthRequired(t *testing.T) {
|
|
||||||
cfg, err := tests.TestConfig()
|
|
||||||
require.NoError(t, err)
|
|
||||||
logger := tests.NilLogger()
|
|
||||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn, err := tests.SetupTestDB(ver)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sconn := db.MakeSafe(conn, logger)
|
|
||||||
defer sconn.Close()
|
|
||||||
|
|
||||||
// Handler to check outcome of Authentication middleware
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
var maint uint32
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
// Add the middleware and create the server
|
|
||||||
reauthRequiredHandler := FreshReq(testHandler)
|
|
||||||
loginRequiredHandler := LoginReq(reauthRequiredHandler)
|
|
||||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
|
||||||
server := httptest.NewServer(authHandler)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
tokens := getTokens()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accessToken string
|
|
||||||
refreshToken string
|
|
||||||
expectedCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Fresh Login",
|
|
||||||
accessToken: tokens["accessFresh"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Unfresh Login",
|
|
||||||
accessToken: tokens["accessUnfresh"],
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: 444,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Expired login",
|
|
||||||
accessToken: tokens["accessExpired"],
|
|
||||||
refreshToken: tokens["refreshExpired"],
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No login",
|
|
||||||
accessToken: "",
|
|
||||||
refreshToken: "",
|
|
||||||
expectedCode: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
client := &http.Client{}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
|
|
||||||
|
|
||||||
// Add cookies if provided
|
|
||||||
if tt.accessToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
|
|
||||||
}
|
|
||||||
if tt.refreshToken != "" {
|
|
||||||
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectedCode, resp.StatusCode)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"projectreshoot/contexts"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func StartTimer(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
start := time.Now()
|
|
||||||
ctx := contexts.SetStart(r.Context(), start)
|
|
||||||
newReq := r.WithContext(ctx)
|
|
||||||
next.ServeHTTP(w, newReq)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
9
pkg/contexts/currentuser.go
Normal file
9
pkg/contexts/currentuser.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package contexts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"projectreshoot/internal/models"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hwsauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
var CurrentUser hwsauth.ContextLoader[*models.User]
|
||||||
20
pkg/embedfs/embedfs.go
Normal file
20
pkg/embedfs/embedfs.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package embedfs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed files/*
|
||||||
|
var embeddedFiles embed.FS
|
||||||
|
|
||||||
|
// Gets the embedded files
|
||||||
|
func GetEmbeddedFS() (fs.FS, error) {
|
||||||
|
subFS, err := fs.Sub(embeddedFiles, "files")
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "fs.Sub")
|
||||||
|
}
|
||||||
|
return subFS, nil
|
||||||
|
}
|
||||||
BIN
pkg/embedfs/files/assets/error.png
Normal file
BIN
pkg/embedfs/files/assets/error.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
1625
pkg/embedfs/files/css/output.css
Normal file
1625
pkg/embedfs/files/css/output.css
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user