Compare commits

3 Commits

106 changed files with 2327 additions and 2530 deletions

View File

@@ -5,7 +5,7 @@ tmp_dir = "tmp"
[build] [build]
args_bin = [] args_bin = []
bin = "./tmp/main" bin = "./tmp/main"
cmd = "go build -o ./tmp/main ." cmd = "go build -o ./tmp/main ./cmd/projectreshoot"
delay = 1000 delay = 1000
exclude_dir = [] exclude_dir = []
exclude_file = [] exclude_file = []

View File

@@ -53,10 +53,10 @@ jobs:
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
scp -i ~/.ssh/id_ed25519 prmigrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 .bin/migrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR

View File

@@ -53,10 +53,10 @@ jobs:
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR scp -i ~/.ssh/id_ed25519 ./bin/projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
scp -i ~/.ssh/id_ed25519 prmigrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 ./bin/migrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR

10
.gitignore vendored
View File

@@ -1,11 +1,9 @@
.env .env
query.sql *.db*
*.db
.logs/ .logs/
server.log server.log
bin/
tmp/ tmp/
prmigrate
projectreshoot
static/css/output.css static/css/output.css
view/**/*_templ.go internal/view/**/*_templ.go
view/**/*_templ.txt internal/view/**/*_templ.txt

View File

@@ -5,33 +5,25 @@
BINARY_NAME=projectreshoot BINARY_NAME=projectreshoot
build: build:
tailwindcss -i ./static/css/input.css -o ./static/css/output.css && \ tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css && \
go mod tidy && \ go mod tidy && \
templ generate && \ templ generate && \
go generate && \ go generate ./cmd/${BINARY_NAME} && \
go build -ldflags="-w -s" -o ${BINARY_NAME}${SUFFIX} go build -ldflags="-w -s" -o ./bin/${BINARY_NAME}${SUFFIX} ./cmd/${BINARY_NAME}
run:
make build
./bin/${BINARY_NAME}${SUFFIX}
dev: dev:
templ generate --watch &\ templ generate --watch &\
air &\ air &\
tailwindcss -i ./static/css/input.css -o ./static/css/output.css --watch tailwindcss -i ./pkg/embedfs/files/css/input.css -o ./pkg/embedfs/files/css/output.css --watch
tester:
go mod tidy && \
go run . --port 3232 --tester --loglevel trace
test:
go mod tidy && \
templ generate && \
go generate && \
go test .
go test ./db
go test ./middleware
clean: clean:
go clean go clean
migrate: migrate:
go mod tidy && \ go mod tidy && \
go generate && \ go generate ./cmd/migrate && \
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate go build -ldflags="-w -s" -o ./bin/migrate${SUFFIX} ./cmd/migrate

View File

@@ -20,7 +20,7 @@ var migrationsFS embed.FS
func main() { func main() {
if len(os.Args) != 4 { if len(os.Args) != 4 {
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>") fmt.Println("Usage: migrate <file_path> up-to|down-to <version>")
os.Exit(1) os.Exit(1)
} }

View File

@@ -1,4 +1,4 @@
package db package main
import ( import (
"database/sql" "database/sql"
@@ -6,35 +6,30 @@ import (
"strconv" "strconv"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
_ "modernc.org/sqlite" _ "github.com/mattn/go-sqlite3"
) )
// Returns a database connection handle for the DB func setupDBConn(dbName string) (*sql.DB, error) {
func ConnectToDatabase( opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
dbName string, file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
logger *zerolog.Logger, conn, err := sql.Open("sqlite3", file)
) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite", file)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sql.Open") return nil, errors.Wrap(err, "sql.Open")
} }
version, err := strconv.Atoi(dbName) err = checkDBVersion(conn, dbName)
if err != nil {
return nil, errors.Wrap(err, "strconv.Atoi")
}
err = checkDBVersion(db, version)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "checkDBVersion") return nil, errors.Wrap(err, "checkDBVersion")
} }
conn := MakeSafe(db, logger)
return conn, nil return conn, nil
} }
// Check the database version // Check the database version
func checkDBVersion(db *sql.DB, expectVer int) error { func checkDBVersion(db *sql.DB, dbName string) error {
expectVer, err := strconv.Atoi(dbName)
if err != nil {
return errors.Wrap(err, "strconv.Atoi")
}
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1 query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
ORDER BY version_id DESC LIMIT 1` ORDER BY version_id DESC LIMIT 1`
rows, err := db.Query(query) rows, err := db.Query(query)

View File

@@ -0,0 +1,30 @@
package main
import (
"flag"
"strconv"
)
func setupFlags() map[string]string {
// Parse commandline args
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run server in test mode")
tester := flag.Bool("tester", false, "Run tester function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
flag.Parse()
// Map the args for easy access
args := map[string]string{
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"tester": strconv.FormatBool(*tester),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,
}
return args
}

View File

@@ -0,0 +1,16 @@
package main
import (
"context"
"fmt"
"os"
)
func main() {
args := setupFlags()
ctx := context.Background()
if err := run(ctx, os.Stdout, args); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
}

123
cmd/projectreshoot/run.go Normal file
View File

@@ -0,0 +1,123 @@
package main
import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"projectreshoot/internal/httpserver"
"projectreshoot/pkg/config"
"projectreshoot/pkg/embedfs"
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
var maint uint32 // atomic: 1 if in maintenance mode
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := config.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil
}
// Setup the logfile
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = hlog.NewLogFile(config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogFile")
}
defer logfile.Close()
}
// Setup the console writer
var consoleWriter io.Writer
if config.LogOutput == "both" || config.LogOutput == "console" {
consoleWriter = w
}
// Setup the logger
logger, err := hlog.NewLogger(
config.LogLevel,
consoleWriter,
logfile,
config.LogDir,
)
if err != nil {
return errors.Wrap(err, "logging.GetLogger")
}
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
conn, err := setupDBConn(config.DBName)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
defer conn.Close()
// Setup embedded files
logger.Debug().Msg("Getting embedded files")
staticFS, err := embedfs.GetEmbeddedFS()
if err != nil {
return errors.Wrap(err, "getStaticFiles")
}
// Setup TokenGenerator
logger.Debug().Msg("Creating TokenGenerator")
tokenGen, err := jwt.CreateGenerator(
config.AccessTokenExpiry,
config.RefreshTokenExpiry,
config.TokenFreshTime,
config.TrustedHost,
config.SecretKey,
conn,
)
logger.Debug().Msg("Setting up HTTP server")
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint)
// Setups a channel to listen for os.Signal
handleMaintSignals(httpServer, logger)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")
go func() {
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error().Err(err).Msg("Error listening and serving")
}
}()
// Handles graceful shutdown
var wg sync.WaitGroup
wg.Go(func() {
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error().Err(err).Msg("Error shutting down server")
}
})
wg.Wait()
logger.Info().Msg("Shutting down")
return nil
}

View File

@@ -0,0 +1,41 @@
package main
import (
"net/http"
"os"
"os/signal"
"sync/atomic"
"syscall"
"git.haelnorr.com/h/golib/hlog"
)
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
srv *http.Server,
logger *hlog.Logger,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
logger.Debug().Msg("Shutting down signal listener")
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}

View File

@@ -1,129 +0,0 @@
package db
import (
"context"
"database/sql"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
type SafeConn struct {
db *sql.DB
readLockCount uint32
globalLockStatus uint32
globalLockRequested uint32
logger *zerolog.Logger
}
// Make the provided db handle safe and attach a logger to it
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{db: db, logger: logger}
}
// Attempts to acquire a global lock on the database connection
func (conn *SafeConn) acquireGlobalLock() bool {
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false
}
conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true
}
// Releases a global lock on the database connection
func (conn *SafeConn) releaseGlobalLock() {
conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
}
// Acquire a read lock on the connection. Multiple read locks can be acquired
// at the same time
func (conn *SafeConn) acquireReadLock() bool {
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
return false
}
conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true
}
// Release a read lock. Decrements read lock count by 1
func (conn *SafeConn) releaseReadLock() {
conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
}
// Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-lockCtx.Done():
return
default:
if conn.acquireReadLock() {
close(lockAcquired)
}
}
}()
select {
case <-lockAcquired:
tx, err := conn.db.BeginTx(ctx, nil)
if err != nil {
conn.releaseReadLock()
return nil, err
}
return &SafeTX{tx: tx, sc: conn}, nil
case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock")
}
}
// Acquire a global lock, preventing all transactions
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
conn.logger.Info().Msg("Attempting to acquire global database lock")
conn.globalLockRequested = 1
defer func() { conn.globalLockRequested = 0 }()
timeout := time.After(timeoutAfter)
attempt := 0
for {
if conn.acquireGlobalLock() {
conn.logger.Info().Msg("Global database lock acquired")
return
}
select {
case <-timeout:
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
return
case <-time.After(100 * time.Millisecond):
attempt++
}
}
}
// Release the global lock
func (conn *SafeConn) Resume() {
conn.releaseGlobalLock()
conn.logger.Info().Msg("Global database lock released")
}
// Close the database connection
func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.db.Close()
}

View File

@@ -1,143 +0,0 @@
package db
import (
"context"
"projectreshoot/tests"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeConn(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
tx.Commit()
engaged.Wait()
assert.Equal(t, uint32(1), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
sconn.Resume()
})
t.Run("Lock abandons after timeout", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
sconn.Pause(250 * time.Millisecond)
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
tx.Commit()
})
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
_, err = sconn.Begin(ctx)
require.Error(t, err)
tx.Commit()
engaged.Wait()
_, err = sconn.Begin(ctx)
require.Error(t, err)
sconn.Resume()
tx, err = sconn.Begin(t.Context())
require.NoError(t, err)
tx.Commit()
})
}
func TestSafeTX(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Commit releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Commit()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Rollback releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Rollback()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
tx1, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx2, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx3, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx1.Commit()
tx2.Commit()
tx3.Commit()
})
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
defer sconn.releaseGlobalLock()
_, err := sconn.Begin(ctx)
require.Error(t, err)
})
t.Run("Lock acquires if lock released", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
var wg sync.WaitGroup
wg.Add(1)
go func() {
tx, err := sconn.Begin(ctx)
require.NoError(t, err)
tx.Commit()
wg.Done()
}()
sconn.releaseGlobalLock()
wg.Wait()
})
}

View File

@@ -1,61 +0,0 @@
package db
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
// Extends sql.Tx for use with SafeConn
type SafeTX struct {
tx *sql.Tx
sc *SafeConn
}
// Query the database inside the transaction
func (stx *SafeTX) Query(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Rows, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
return stx.tx.QueryContext(ctx, query, args...)
}
// Exec a statement on the database inside the transaction
func (stx *SafeTX) Exec(
ctx context.Context,
query string,
args ...interface{},
) (sql.Result, error) {
if stx.tx == nil {
return nil, errors.New("Cannot exec without a transaction")
}
return stx.tx.ExecContext(ctx, query, args...)
}
// Commit the current transaction and release the read lock
func (stx *SafeTX) Commit() error {
if stx.tx == nil {
return errors.New("Cannot commit without a transaction")
}
err := stx.tx.Commit()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Abort the current transaction, releasing the read lock
func (stx *SafeTX) Rollback() error {
if stx.tx == nil {
return errors.New("Cannot rollback without a transaction")
}
err := stx.tx.Rollback()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}

View File

@@ -1,60 +0,0 @@
package db
import (
"context"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
type User struct {
ID int // Integer ID (index primary key)
Username string // Username (unique)
Password_hash string // Bcrypt password hash
Created_at int64 // Epoch timestamp when the user was added to the database
Bio string // Short byline set by the user
}
// Uses bcrypt to set the users Password_hash from the given password
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
}
user.Password_hash = string(hashedPassword)
query := `UPDATE users SET password_hash = ? WHERE id = ?`
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Uses bcrypt to check if the given password matches the users Password_hash
func (user *User) CheckPassword(password string) error {
err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password))
if err != nil {
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
}
return nil
}
// Change the user's username
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
query := `UPDATE users SET username = ? WHERE id = ?`
_, err := tx.Exec(ctx, query, newUsername, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Change the user's bio
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
query := `UPDATE users SET bio = ? WHERE id = ?`
_, err := tx.Exec(ctx, query, newBio, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}

View File

@@ -64,7 +64,7 @@ failed_cleanup() {
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
echo "Migration in progress from $CUR_VER to $TGT_VER" echo "Migration in progress from $CUR_VER to $TGT_VER"
${MIGRATION_BIN}/prmigrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER ${MIGRATION_BIN}/migrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Migration failed" echo "Migration failed"
exit 1 exit 1

25
go.mod
View File

@@ -1,36 +1,37 @@
module projectreshoot module projectreshoot
go 1.24.0 go 1.25.5
require ( require (
github.com/a-h/templ v0.3.833 git.haelnorr.com/h/golib/hlog v0.9.0
github.com/golang-jwt/jwt v3.2.2+incompatible git.haelnorr.com/h/golib/jwt v0.9.0
github.com/google/uuid v1.6.0 git.haelnorr.com/h/golib/tmdb v0.8.0
github.com/a-h/templ v0.3.977
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.24
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose/v3 v3.24.1 github.com/pressly/goose/v3 v3.24.1
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
modernc.org/sqlite v1.35.0 modernc.org/sqlite v1.35.0
) )
replace git.haelnorr.com/h/golib/jwt v0.9.0 => /home/haelnorr/projects/golib/jwt
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect github.com/mfridman/interpolate v0.0.2 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
golang.org/x/sync v0.11.0 // indirect golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.34.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.61.13 // indirect modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.8.2 // indirect modernc.org/memory v1.8.2 // indirect

46
go.sum
View File

@@ -1,5 +1,11 @@
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU= git.haelnorr.com/h/golib/hlog v0.9.0 h1:ib8n2MdmiRK2TF067p220kXmhDe9aAnlcsgpuv+QpvE=
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk= git.haelnorr.com/h/golib/hlog v0.9.0/go.mod h1:oOlzb8UVHUYP1k7dN5PSJXVskAB2z8EYgRN85jAi0Zk=
git.haelnorr.com/h/golib/tmdb v0.8.0 h1:OQ6M2TB8FHm8fJD7/ebfWm63Duzfp0kmFX9genEig34=
git.haelnorr.com/h/golib/tmdb v0.8.0/go.mod h1:mGKYa3o3z0IsQ5EO3MPmnL2Bwl2sSMsUHXVgaIGR7Z0=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/a-h/templ v0.3.977 h1:kiKAPXTZE2Iaf8JbtM21r54A8bCNsncrfnokZZSrSDg=
github.com/a-h/templ v0.3.977/go.mod h1:oCZcnKRf5jjsGpf2yELzQfodLphd2mwecwG4Crk5HBo=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -16,11 +22,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -28,6 +29,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
@@ -40,33 +43,30 @@ github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug= github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=

View File

@@ -2,17 +2,19 @@ package handler
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"time" "time"
"projectreshoot/contexts" "projectreshoot/internal/models"
"projectreshoot/cookies" "projectreshoot/internal/view/component/account"
"projectreshoot/db" "projectreshoot/internal/view/page"
"projectreshoot/view/component/account" "projectreshoot/pkg/contexts"
"projectreshoot/view/page" "projectreshoot/pkg/cookies"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
// Renders the account page on the 'General' subpage // Renders the account page on the 'General' subpage
@@ -43,8 +45,8 @@ func AccountSubpage() http.Handler {
// Handles a request to change the users username // Handles a request to change the users username
func ChangeUsername( func ChangeUsername(
logger *zerolog.Logger, logger *hlog.Logger,
conn *db.SafeConn, conn *sql.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -52,7 +54,7 @@ func ChangeUsername(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Error updating username") logger.Warn().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
@@ -60,7 +62,7 @@ func ChangeUsername(
} }
r.ParseForm() r.ParseForm()
newUsername := r.FormValue("username") newUsername := r.FormValue("username")
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername) unique, err := models.CheckUsernameUnique(tx, newUsername)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
logger.Error().Err(err).Msg("Error updating username") logger.Error().Err(err).Msg("Error updating username")
@@ -74,7 +76,7 @@ func ChangeUsername(
return return
} }
user := contexts.GetUser(r.Context()) user := contexts.GetUser(r.Context())
err = user.ChangeUsername(ctx, tx, newUsername) err = user.ChangeUsername(tx, newUsername)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
logger.Error().Err(err).Msg("Error updating username") logger.Error().Err(err).Msg("Error updating username")
@@ -89,8 +91,8 @@ func ChangeUsername(
// Handles a request to change the users bio // Handles a request to change the users bio
func ChangeBio( func ChangeBio(
logger *zerolog.Logger, logger *hlog.Logger,
conn *db.SafeConn, conn *sql.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -98,7 +100,7 @@ func ChangeBio(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Error updating bio") logger.Warn().Err(err).Msg("Error updating bio")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
@@ -114,7 +116,7 @@ func ChangeBio(
return return
} }
user := contexts.GetUser(r.Context()) user := contexts.GetUser(r.Context())
err = user.ChangeBio(ctx, tx, newBio) err = user.ChangeBio(tx, newBio)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
logger.Error().Err(err).Msg("Error updating bio") logger.Error().Err(err).Msg("Error updating bio")
@@ -127,8 +129,6 @@ func ChangeBio(
) )
} }
func validateChangePassword( func validateChangePassword(
ctx context.Context,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) (string, error) { ) (string, error) {
r.ParseForm() r.ParseForm()
@@ -145,8 +145,8 @@ func validateChangePassword(
// Handles a request to change the users password // Handles a request to change the users password
func ChangePassword( func ChangePassword(
logger *zerolog.Logger, logger *hlog.Logger,
conn *db.SafeConn, conn *sql.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -154,20 +154,20 @@ func ChangePassword(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Error updating password") logger.Warn().Err(err).Msg("Error updating password")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return
} }
newPass, err := validateChangePassword(ctx, tx, r) newPass, err := validateChangePassword(r)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
account.ChangePassword(err.Error()).Render(r.Context(), w) account.ChangePassword(err.Error()).Render(r.Context(), w)
return return
} }
user := contexts.GetUser(r.Context()) user := contexts.GetUser(r.Context())
err = user.SetPassword(ctx, tx, newPass) err = user.SetPassword(tx, newPass)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
logger.Error().Err(err).Msg("Error updating password") logger.Error().Err(err).Msg("Error updating password")

View File

@@ -2,7 +2,7 @@ package handler
import ( import (
"net/http" "net/http"
"projectreshoot/view/page" "projectreshoot/internal/view/page"
) )
func ErrorPage( func ErrorPage(

View File

@@ -3,7 +3,7 @@ package handler
import ( import (
"net/http" "net/http"
"projectreshoot/view/page" "projectreshoot/internal/view/page"
) )
// Handles responses to the / path. Also serves a 404 Page for paths that // Handles responses to the / path. Also serves a 404 Page for paths that

View File

@@ -2,34 +2,35 @@ package handler
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"time" "time"
"projectreshoot/config" "projectreshoot/internal/models"
"projectreshoot/cookies" "projectreshoot/internal/view/component/form"
"projectreshoot/db" "projectreshoot/internal/view/page"
"projectreshoot/view/component/form" "projectreshoot/pkg/config"
"projectreshoot/view/page" "projectreshoot/pkg/cookies"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
// Validates the username matches a user in the database and the password // Validates the username matches a user in the database and the password
// is correct. Returns the corresponding user // is correct. Returns the corresponding user
func validateLogin( func validateLogin(
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) (*db.User, error) { ) (*models.User, error) {
formUsername := r.FormValue("username") formUsername := r.FormValue("username")
formPassword := r.FormValue("password") formPassword := r.FormValue("password")
user, err := db.GetUserFromUsername(ctx, tx, formUsername) user, err := models.GetUserFromUsername(tx, formUsername)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromUsername") return nil, errors.Wrap(err, "db.GetUserFromUsername")
} }
err = user.CheckPassword(formPassword) err = user.CheckPassword(tx, formPassword)
if err != nil { if err != nil {
return nil, errors.New("Username or password incorrect") return nil, errors.New("Username or password incorrect")
} }
@@ -51,8 +52,9 @@ func checkRememberMe(r *http.Request) bool {
// template for user feedback // template for user feedback
func LoginRequest( func LoginRequest(
config *config.Config, config *config.Config,
logger *zerolog.Logger, logger *hlog.Logger,
conn *db.SafeConn, conn *sql.DB,
tokenGen *jwt.TokenGenerator,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -60,14 +62,14 @@ func LoginRequest(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Failed to set token cookies") logger.Warn().Err(err).Msg("Failed to set token cookies")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return
} }
r.ParseForm() r.ParseForm()
user, err := validateLogin(ctx, tx, r) user, err := validateLogin(tx, r)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
if err.Error() != "Username or password incorrect" { if err.Error() != "Username or password incorrect" {
@@ -80,7 +82,7 @@ func LoginRequest(
} }
rememberMe := checkRememberMe(r) rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) err = cookies.SetTokenCookies(w, r, config, tokenGen, user, true, rememberMe)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@@ -2,26 +2,25 @@ package handler
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"projectreshoot/config" "projectreshoot/pkg/cookies"
"projectreshoot/cookies"
"projectreshoot/db" "git.haelnorr.com/h/golib/hlog"
"projectreshoot/jwt" "git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
func revokeAccess( func revokeAccess(
config *config.Config, tokenGen *jwt.TokenGenerator,
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
atStr string, atStr string,
) error { ) error {
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) aT, err := tokenGen.ValidateAccess(tx, atStr)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "Token is expired") || if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") { strings.Contains(err.Error(), "Token has been revoked") {
@@ -29,7 +28,7 @@ func revokeAccess(
} }
return errors.Wrap(err, "jwt.ParseAccessToken") return errors.Wrap(err, "jwt.ParseAccessToken")
} }
err = jwt.RevokeToken(ctx, tx, aT) err = aT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "jwt.RevokeToken")
} }
@@ -37,12 +36,11 @@ func revokeAccess(
} }
func revokeRefresh( func revokeRefresh(
config *config.Config, tokenGen *jwt.TokenGenerator,
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
rtStr string, rtStr string,
) error { ) error {
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) rT, err := tokenGen.ValidateRefresh(tx, rtStr)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "Token is expired") || if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") { strings.Contains(err.Error(), "Token has been revoked") {
@@ -50,7 +48,7 @@ func revokeRefresh(
} }
return errors.Wrap(err, "jwt.ParseRefreshToken") return errors.Wrap(err, "jwt.ParseRefreshToken")
} }
err = jwt.RevokeToken(ctx, tx, rT) err = rT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "jwt.RevokeToken")
} }
@@ -59,20 +57,19 @@ func revokeRefresh(
// Retrieve and revoke the user's tokens // Retrieve and revoke the user's tokens
func revokeTokens( func revokeTokens(
config *config.Config, tokenGen *jwt.TokenGenerator,
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) error { ) error {
// get the tokens from the cookies // get the tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
// revoke the refresh token first as the access token expires quicker // revoke the refresh token first as the access token expires quicker
// only matters if there is an error revoking the tokens // only matters if there is an error revoking the tokens
err := revokeRefresh(config, ctx, tx, rtStr) err := revokeRefresh(tokenGen, tx, rtStr)
if err != nil { if err != nil {
return errors.Wrap(err, "revokeRefresh") return errors.Wrap(err, "revokeRefresh")
} }
err = revokeAccess(config, ctx, tx, atStr) err = revokeAccess(tokenGen, tx, atStr)
if err != nil { if err != nil {
return errors.Wrap(err, "revokeAccess") return errors.Wrap(err, "revokeAccess")
} }
@@ -81,25 +78,25 @@ func revokeTokens(
// Handle a logout request // Handle a logout request
func Logout( func Logout(
config *config.Config, conn *sql.DB,
logger *zerolog.Logger, tokenGen *jwt.TokenGenerator,
conn *db.SafeConn, logger *hlog.Logger,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel() defer cancel()
// Start the transaction tx, err := conn.BeginTx(ctx, nil)
tx, err := conn.Begin(ctx)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Error occured on user logout") logger.Error().Err(err).Msg("Failed to start database transaction")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusInternalServerError)
return return
} }
err = revokeTokens(config, ctx, tx, r) defer tx.Rollback()
err = revokeTokens(tokenGen, tx, r)
if err != nil { if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error occured on user logout") logger.Error().Err(err).Msg("Error occured on user logout")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return

View File

@@ -2,16 +2,16 @@ package handler
import ( import (
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/internal/view/page"
"projectreshoot/tmdb" "projectreshoot/pkg/config"
"projectreshoot/view/page"
"strconv" "strconv"
"github.com/rs/zerolog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/tmdb"
) )
func Movie( func Movie(
logger *zerolog.Logger, logger *hlog.Logger,
config *config.Config, config *config.Config,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(

View File

@@ -2,16 +2,16 @@ package handler
import ( import (
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/internal/view/component/search"
"projectreshoot/tmdb" "projectreshoot/internal/view/page"
"projectreshoot/view/component/search" "projectreshoot/pkg/config"
"projectreshoot/view/page"
"github.com/rs/zerolog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/tmdb"
) )
func SearchMovies( func SearchMovies(
logger *zerolog.Logger, logger *hlog.Logger,
config *config.Config, config *config.Config,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(

View File

@@ -2,7 +2,7 @@ package handler
import ( import (
"net/http" "net/http"
"projectreshoot/view/page" "projectreshoot/internal/view/page"
) )
func ProfilePage() http.Handler { func ProfilePage() http.Handler {

View File

@@ -2,54 +2,53 @@ package handler
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"time" "time"
"projectreshoot/config" "projectreshoot/internal/view/component/form"
"projectreshoot/contexts" "projectreshoot/pkg/config"
"projectreshoot/cookies" "projectreshoot/pkg/contexts"
"projectreshoot/db" "projectreshoot/pkg/cookies"
"projectreshoot/jwt"
"projectreshoot/view/component/form" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
// Get the tokens from the request // Get the tokens from the request
func getTokens( func getTokens(
config *config.Config, tokenGen *jwt.TokenGenerator,
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) (*jwt.AccessToken, *jwt.RefreshToken, error) { ) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies // get the existing tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) aT, err := tokenGen.ValidateAccess(tx, atStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken") return nil, nil, errors.Wrap(err, "tokenGen.ValidateAccess")
} }
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) rT, err := tokenGen.ValidateRefresh(tx, rtStr)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken") return nil, nil, errors.Wrap(err, "tokenGen.ValidateRefresh")
} }
return aT, rT, nil return aT, rT, nil
} }
// Revoke the given token pair // Revoke the given token pair
func revokeTokenPair( func revokeTokenPair(
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
aT *jwt.AccessToken, aT *jwt.AccessToken,
rT *jwt.RefreshToken, rT *jwt.RefreshToken,
) error { ) error {
err := jwt.RevokeToken(ctx, tx, aT) err := aT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "aT.Revoke")
} }
err = jwt.RevokeToken(ctx, tx, rT) err = rT.Revoke(tx)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.RevokeToken") return errors.Wrap(err, "rT.Revoke")
} }
return nil return nil
} }
@@ -57,12 +56,12 @@ func revokeTokenPair(
// Issue new tokens for the user, invalidating the old ones // Issue new tokens for the user, invalidating the old ones
func refreshTokens( func refreshTokens(
config *config.Config, config *config.Config,
ctx context.Context, tokenGen *jwt.TokenGenerator,
tx *db.SafeTX, tx *sql.Tx,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
) error { ) error {
aT, rT, err := getTokens(config, ctx, tx, r) aT, rT, err := getTokens(tokenGen, tx, r)
if err != nil { if err != nil {
return errors.Wrap(err, "getTokens") return errors.Wrap(err, "getTokens")
} }
@@ -72,11 +71,11 @@ func refreshTokens(
}[aT.TTL] }[aT.TTL]
// issue new tokens for the user // issue new tokens for the user
user := contexts.GetUser(r.Context()) user := contexts.GetUser(r.Context())
err = cookies.SetTokenCookies(w, r, config, user.User, true, rememberMe) err = cookies.SetTokenCookies(w, r, config, tokenGen, user.User, true, rememberMe)
if err != nil { if err != nil {
return errors.Wrap(err, "cookies.SetTokenCookies") return errors.Wrap(err, "cookies.SetTokenCookies")
} }
err = revokeTokenPair(ctx, tx, aT, rT) err = revokeTokenPair(tx, aT, rT)
if err != nil { if err != nil {
return errors.Wrap(err, "revokeTokenPair") return errors.Wrap(err, "revokeTokenPair")
} }
@@ -86,12 +85,13 @@ func refreshTokens(
// Validate the provided password // Validate the provided password
func validatePassword( func validatePassword(
tx *sql.Tx,
r *http.Request, r *http.Request,
) error { ) error {
r.ParseForm() r.ParseForm()
password := r.FormValue("password") password := r.FormValue("password")
user := contexts.GetUser(r.Context()) user := contexts.GetUser(r.Context())
err := user.CheckPassword(password) err := user.CheckPassword(tx, password)
if err != nil { if err != nil {
return errors.Wrap(err, "user.CheckPassword") return errors.Wrap(err, "user.CheckPassword")
} }
@@ -100,9 +100,10 @@ func validatePassword(
// Handle request to reauthenticate (i.e. make token fresh again) // Handle request to reauthenticate (i.e. make token fresh again)
func Reauthenticate( func Reauthenticate(
logger *zerolog.Logger, logger *hlog.Logger,
config *config.Config, config *config.Config,
conn *db.SafeConn, conn *sql.DB,
tokenGen *jwt.TokenGenerator,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -110,22 +111,21 @@ func Reauthenticate(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Failed to refresh user tokens") logger.Error().Err(err).Msg("Failed to start transaction")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusInternalServerError)
return return
} }
err = validatePassword(r) defer tx.Rollback()
err = validatePassword(tx, r)
if err != nil { if err != nil {
tx.Rollback()
w.WriteHeader(445) w.WriteHeader(445)
form.ConfirmPassword("Incorrect password").Render(r.Context(), w) form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
return return
} }
err = refreshTokens(config, ctx, tx, w, r) err = refreshTokens(config, tokenGen, tx, w, r)
if err != nil { if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Failed to refresh user tokens") logger.Error().Err(err).Msg("Failed to refresh user tokens")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return

View File

@@ -2,30 +2,32 @@ package handler
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"time" "time"
"projectreshoot/config" "projectreshoot/internal/models"
"projectreshoot/cookies" "projectreshoot/internal/view/component/form"
"projectreshoot/db" "projectreshoot/internal/view/page"
"projectreshoot/view/component/form" "projectreshoot/pkg/config"
"projectreshoot/view/page" "projectreshoot/pkg/cookies"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
func validateRegistration( func validateRegistration(
ctx context.Context, tx *sql.Tx,
tx *db.SafeTX,
r *http.Request, r *http.Request,
) (*db.User, error) { ) (*models.User, error) {
formUsername := r.FormValue("username") formUsername := r.FormValue("username")
formPassword := r.FormValue("password") formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password") formConfirmPassword := r.FormValue("confirm-password")
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername) unique, err := models.CheckUsernameUnique(tx, formUsername)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "db.CheckUsernameUnique") return nil, errors.Wrap(err, "models.CheckUsernameUnique")
} }
if !unique { if !unique {
return nil, errors.New("Username is taken") return nil, errors.New("Username is taken")
@@ -36,9 +38,9 @@ func validateRegistration(
if len(formPassword) > 72 { if len(formPassword) > 72 {
return nil, errors.New("Password exceeds maximum length of 72 bytes") return nil, errors.New("Password exceeds maximum length of 72 bytes")
} }
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword) user, err := models.CreateNewUser(tx, formUsername, formPassword)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "db.CreateNewUser") return nil, errors.Wrap(err, "models.CreateNewUser")
} }
return user, nil return user, nil
@@ -46,8 +48,9 @@ func validateRegistration(
func RegisterRequest( func RegisterRequest(
config *config.Config, config *config.Config,
logger *zerolog.Logger, tokenGen *jwt.TokenGenerator,
conn *db.SafeConn, logger *hlog.Logger,
conn *sql.DB,
) http.Handler { ) http.Handler {
return http.HandlerFunc( return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
@@ -55,14 +58,14 @@ func RegisterRequest(
defer cancel() defer cancel()
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("Failed to set token cookies") logger.Warn().Err(err).Msg("Failed to set token cookies")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
return return
} }
r.ParseForm() r.ParseForm()
user, err := validateRegistration(ctx, tx, r) user, err := validateRegistration(tx, r)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
if err.Error() != "Username is taken" && if err.Error() != "Username is taken" &&
@@ -77,7 +80,7 @@ func RegisterRequest(
} }
rememberMe := checkRememberMe(r) rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe) err = cookies.SetTokenCookies(w, r, config, tokenGen, user, true, rememberMe)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@@ -1,23 +1,25 @@
package server package httpserver
import ( import (
"database/sql"
"net/http" "net/http"
"projectreshoot/config" "projectreshoot/internal/handler"
"projectreshoot/db" "projectreshoot/internal/middleware"
"projectreshoot/handler" "projectreshoot/internal/view/page"
"projectreshoot/middleware" "projectreshoot/pkg/config"
"projectreshoot/view/page"
"github.com/rs/zerolog" "git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
) )
// Add all the handled routes to the mux // Add all the handled routes to the mux
func addRoutes( func addRoutes(
mux *http.ServeMux, mux *http.ServeMux,
logger *zerolog.Logger, logger *hlog.Logger,
config *config.Config, config *config.Config,
conn *db.SafeConn, tokenGen *jwt.TokenGenerator,
conn *sql.DB,
staticFS *http.FileSystem, staticFS *http.FileSystem,
) { ) {
route := mux.Handle route := mux.Handle
@@ -39,17 +41,17 @@ func addRoutes(
// Login page and handlers // Login page and handlers
route("GET /login", loggedOut(handler.LoginPage(config.TrustedHost))) route("GET /login", loggedOut(handler.LoginPage(config.TrustedHost)))
route("POST /login", loggedOut(handler.LoginRequest(config, logger, conn))) route("POST /login", loggedOut(handler.LoginRequest(config, logger, conn, tokenGen)))
// Register page and handlers // Register page and handlers
route("GET /register", loggedOut(handler.RegisterPage(config.TrustedHost))) route("GET /register", loggedOut(handler.RegisterPage(config.TrustedHost)))
route("POST /register", loggedOut(handler.RegisterRequest(config, logger, conn))) route("POST /register", loggedOut(handler.RegisterRequest(config, tokenGen, logger, conn)))
// Logout // Logout
route("POST /logout", handler.Logout(config, logger, conn)) route("POST /logout", handler.Logout(conn, tokenGen, logger))
// Reauthentication request // Reauthentication request
route("POST /reauthenticate", loggedIn(handler.Reauthenticate(logger, config, conn))) route("POST /reauthenticate", loggedIn(handler.Reauthenticate(logger, config, conn, tokenGen)))
// Profile page // Profile page
route("GET /profile", loggedIn(handler.ProfilePage())) route("GET /profile", loggedIn(handler.ProfilePage()))

View File

@@ -0,0 +1,67 @@
package httpserver
import (
"database/sql"
"io/fs"
"net"
"net/http"
"time"
"projectreshoot/internal/middleware"
"projectreshoot/pkg/config"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
)
func NewServer(
config *config.Config,
logger *hlog.Logger,
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
staticFS *fs.FS,
maint *uint32,
) *http.Server {
fs := http.FS(*staticFS)
srv := createServer(config, logger, conn, tokenGen, &fs, maint)
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv,
ReadHeaderTimeout: config.ReadHeaderTimeout * time.Second,
WriteTimeout: config.WriteTimeout * time.Second,
IdleTimeout: config.IdleTimeout * time.Second,
}
return httpServer
}
// Returns a new http.Handler with all the routes and middleware added
func createServer(
config *config.Config,
logger *hlog.Logger,
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
staticFS *http.FileSystem,
maint *uint32,
) http.Handler {
mux := http.NewServeMux()
addRoutes(
mux,
logger,
config,
tokenGen,
conn,
staticFS,
)
var handler http.Handler = mux
// Add middleware here, must be added in reverse order of execution
// i.e. First in list will get executed last during the request handling
handler = middleware.Logging(logger, handler)
handler = middleware.Authentication(logger, config, conn, tokenGen, handler, maint)
// Gzip
handler = middleware.Gzip(handler, config.GZIP)
// Start the timer for the request chain so logger can have accurate info
handler = middleware.StartTimer(handler)
return handler
}

View File

@@ -2,33 +2,34 @@ package middleware
import ( import (
"context" "context"
"database/sql"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
"time" "time"
"projectreshoot/config" "projectreshoot/internal/handler"
"projectreshoot/contexts" "projectreshoot/internal/models"
"projectreshoot/cookies" "projectreshoot/pkg/config"
"projectreshoot/db" "projectreshoot/pkg/contexts"
"projectreshoot/handler" "projectreshoot/pkg/cookies"
"projectreshoot/jwt"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
// Attempt to use a valid refresh token to generate a new token pair // Attempt to use a valid refresh token to generate a new token pair
func refreshAuthTokens( func refreshAuthTokens(
config *config.Config, config *config.Config,
ctx context.Context, tokenGen *jwt.TokenGenerator,
tx *db.SafeTX, tx *sql.Tx,
w http.ResponseWriter, w http.ResponseWriter,
req *http.Request, req *http.Request,
ref *jwt.RefreshToken, ref *jwt.RefreshToken,
) (*db.User, error) { ) (*models.User, error) {
user, err := ref.GetUser(ctx, tx) user, err := models.GetUserFromID(tx, ref.SUB)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "ref.GetUser") return nil, errors.Wrap(err, "models.GetUser")
} }
rememberMe := map[string]bool{ rememberMe := map[string]bool{
@@ -37,14 +38,14 @@ func refreshAuthTokens(
}[ref.TTL] }[ref.TTL]
// Set fresh to true because new tokens coming from refresh request // Set fresh to true because new tokens coming from refresh request
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe) err = cookies.SetTokenCookies(w, req, config, tokenGen, user, false, rememberMe)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cookies.SetTokenCookies") return nil, errors.Wrap(err, "cookies.SetTokenCookies")
} }
// New tokens sent, revoke the used refresh token // New tokens sent, revoke the used refresh token
err = jwt.RevokeToken(ctx, tx, ref) err = ref.Revoke(tx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "jwt.RevokeToken") return nil, errors.Wrap(err, "ref.Revoke")
} }
// Return the authorized user // Return the authorized user
return user, nil return user, nil
@@ -53,23 +54,26 @@ func refreshAuthTokens(
// Check the cookies for token strings and attempt to authenticate them // Check the cookies for token strings and attempt to authenticate them
func getAuthenticatedUser( func getAuthenticatedUser(
config *config.Config, config *config.Config,
ctx context.Context, tokenGen *jwt.TokenGenerator,
tx *db.SafeTX, tx *sql.Tx,
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
) (*contexts.AuthenticatedUser, error) { ) (*contexts.AuthenticatedUser, error) {
// Get token strings from cookies // Get token strings from cookies
atStr, rtStr := cookies.GetTokenStrings(r) atStr, rtStr := cookies.GetTokenStrings(r)
if atStr == "" && rtStr == "" {
return nil, errors.New("No token strings provided")
}
// Attempt to parse the access token // Attempt to parse the access token
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr) aT, err := tokenGen.ValidateAccess(tx, atStr)
if err != nil { if err != nil {
// Access token invalid, attempt to parse refresh token // Access token invalid, attempt to parse refresh token
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr) rT, err := tokenGen.ValidateRefresh(tx, rtStr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "jwt.ParseRefreshToken") return nil, errors.Wrap(err, "tokenGen.ValidateRefresh")
} }
// Refresh token valid, attempt to get a new token pair // Refresh token valid, attempt to get a new token pair
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT) user, err := refreshAuthTokens(config, tokenGen, tx, w, r, rT)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "refreshAuthTokens") return nil, errors.Wrap(err, "refreshAuthTokens")
} }
@@ -81,9 +85,9 @@ func getAuthenticatedUser(
return &authUser, nil return &authUser, nil
} }
// Access token valid // Access token valid
user, err := aT.GetUser(ctx, tx) user, err := models.GetUserFromID(tx, aT.SUB)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "aT.GetUser") return nil, errors.Wrap(err, "models.GetUser")
} }
authUser := contexts.AuthenticatedUser{ authUser := contexts.AuthenticatedUser{
User: user, User: user,
@@ -95,9 +99,10 @@ func getAuthenticatedUser(
// Attempt to authenticate the user and add their account details // Attempt to authenticate the user and add their account details
// to the request context // to the request context
func Authentication( func Authentication(
logger *zerolog.Logger, logger *hlog.Logger,
config *config.Config, config *config.Config,
conn *db.SafeConn, conn *sql.DB,
tokenGen *jwt.TokenGenerator,
next http.Handler, next http.Handler,
maint *uint32, maint *uint32,
) http.Handler { ) http.Handler {
@@ -114,7 +119,7 @@ func Authentication(
} }
// Start the transaction // Start the transaction
tx, err := conn.Begin(ctx) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
// Failed to start transaction, skip auth // Failed to start transaction, skip auth
logger.Warn().Err(err). logger.Warn().Err(err).
@@ -122,7 +127,7 @@ func Authentication(
handler.ErrorPage(http.StatusServiceUnavailable, w, r) handler.ErrorPage(http.StatusServiceUnavailable, w, r)
return return
} }
user, err := getAuthenticatedUser(config, ctx, tx, w, r) user, err := getAuthenticatedUser(config, tokenGen, tx, w, r)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
// User auth failed, delete the cookies to avoid repeat requests // User auth failed, delete the cookies to avoid repeat requests

View File

@@ -2,11 +2,11 @@ package middleware
import ( import (
"net/http" "net/http"
"projectreshoot/contexts" "projectreshoot/internal/handler"
"projectreshoot/handler" "projectreshoot/pkg/contexts"
"time" "time"
"github.com/rs/zerolog" "git.haelnorr.com/h/golib/hlog"
) )
// Wraps the http.ResponseWriter, adding a statusCode field // Wraps the http.ResponseWriter, adding a statusCode field
@@ -22,7 +22,7 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
} }
// Middleware to add logs to console with details of the request // Middleware to add logs to console with details of the request
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler { func Logging(logger *hlog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" || if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" { r.URL.Path == "/static/favicon.ico" {

View File

@@ -2,8 +2,8 @@ package middleware
import ( import (
"net/http" "net/http"
"projectreshoot/contexts" "projectreshoot/internal/handler"
"projectreshoot/handler" "projectreshoot/pkg/contexts"
) )
// Checks if the user is set in the context and shows 401 page if not logged in // Checks if the user is set in the context and shows 401 page if not logged in

View File

@@ -2,7 +2,7 @@ package middleware
import ( import (
"net/http" "net/http"
"projectreshoot/contexts" "projectreshoot/pkg/contexts"
"time" "time"
) )

View File

@@ -2,7 +2,7 @@ package middleware
import ( import (
"net/http" "net/http"
"projectreshoot/contexts" "projectreshoot/pkg/contexts"
"time" "time"
) )

69
internal/models/user.go Normal file
View File

@@ -0,0 +1,69 @@
package models
import (
"database/sql"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
type User struct {
ID int // Integer ID (index primary key)
Username string // Username (unique)
Created_at int64 // Epoch timestamp when the user was added to the database
Bio string // Short byline set by the user
}
// Uses bcrypt to set the users Password_hash from the given password
func (user *User) SetPassword(
tx *sql.Tx,
password string,
) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
}
newPassword := string(hashedPassword)
query := `UPDATE users SET password_hash = ? WHERE id = ?`
_, err = tx.Exec(query, newPassword, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Uses bcrypt to check if the given password matches the users Password_hash
func (user *User) CheckPassword(tx *sql.Tx, password string) error {
query := `SELECT password_hash FROM users WHERE id = ? LIMIT 1`
row := tx.QueryRow(query, user.ID)
hashedPassword := ""
err := row.Scan(&hashedPassword)
if err != nil {
return errors.Wrap(err, "row.Scan")
}
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
if err != nil {
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
}
return nil
}
// Change the user's username
func (user *User) ChangeUsername(tx *sql.Tx, newUsername string) error {
query := `UPDATE users SET username = ? WHERE id = ?`
_, err := tx.Exec(query, newUsername, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Change the user's bio
func (user *User) ChangeBio(tx *sql.Tx, newBio string) error {
query := `UPDATE users SET bio = ? WHERE id = ?`
_, err := tx.Exec(query, newBio, user.ID)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}

View File

@@ -1,7 +1,6 @@
package db package models
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
@@ -10,21 +9,20 @@ import (
// Creates a new user in the database and returns a pointer // Creates a new user in the database and returns a pointer
func CreateNewUser( func CreateNewUser(
ctx context.Context, tx *sql.Tx,
tx *SafeTX,
username string, username string,
password string, password string,
) (*User, error) { ) (*User, error) {
query := `INSERT INTO users (username) VALUES (?)` query := `INSERT INTO users (username) VALUES (?)`
_, err := tx.Exec(ctx, query, username) _, err := tx.Exec(query, username)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.Exec") return nil, errors.Wrap(err, "tx.Exec")
} }
user, err := GetUserFromUsername(ctx, tx, username) user, err := GetUserFromUsername(tx, username)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "GetUserFromUsername") return nil, errors.Wrap(err, "GetUserFromUsername")
} }
err = user.SetPassword(ctx, tx, password) err = user.SetPassword(tx, password)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "user.SetPassword") return nil, errors.Wrap(err, "user.SetPassword")
} }
@@ -33,23 +31,21 @@ func CreateNewUser(
// Fetches data from the users table using "WHERE column = 'value'" // Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData( func fetchUserData(
ctx context.Context, tx *sql.Tx,
tx *SafeTX,
column string, column string,
value interface{}, value any,
) (*sql.Rows, error) { ) (*sql.Rows, error) {
query := fmt.Sprintf( query := fmt.Sprintf(
`SELECT `SELECT
id, id,
username, username,
password_hash,
created_at, created_at,
bio bio
FROM users FROM users
WHERE %s = ? COLLATE NOCASE LIMIT 1`, WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column, column,
) )
rows, err := tx.Query(ctx, query, value) rows, err := tx.Query(query, value)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "tx.Query") return nil, errors.Wrap(err, "tx.Query")
} }
@@ -65,7 +61,6 @@ func scanUserRow(user *User, rows *sql.Rows) error {
err := rows.Scan( err := rows.Scan(
&user.ID, &user.ID,
&user.Username, &user.Username,
&user.Password_hash,
&user.Created_at, &user.Created_at,
&user.Bio, &user.Bio,
) )
@@ -77,8 +72,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username. // Queries the database for a user matching the given username.
// Query is case insensitive // Query is case insensitive
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) { func GetUserFromUsername(tx *sql.Tx, username string) (*User, error) {
rows, err := fetchUserData(ctx, tx, "username", username) rows, err := fetchUserData(tx, "username", username)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetchUserData") return nil, errors.Wrap(err, "fetchUserData")
} }
@@ -92,8 +87,8 @@ func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*Use
} }
// Queries the database for a user matching the given ID. // Queries the database for a user matching the given ID.
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) { func GetUserFromID(tx *sql.Tx, id int) (*User, error) {
rows, err := fetchUserData(ctx, tx, "id", id) rows, err := fetchUserData(tx, "id", id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetchUserData") return nil, errors.Wrap(err, "fetchUserData")
} }
@@ -107,9 +102,9 @@ func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
} }
// Checks if the given username is unique. Returns true if not taken // Checks if the given username is unique. Returns true if not taken
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) { func CheckUsernameUnique(tx *sql.Tx, username string) (bool, error) {
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1` query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
rows, err := tx.Query(ctx, query, username) rows, err := tx.Query(query, username)
if err != nil { if err != nil {
return false, errors.Wrap(err, "tx.Query") return false, errors.Wrap(err, "tx.Query")
} }

View File

@@ -1,6 +1,6 @@
package account package account
import "projectreshoot/contexts" import "projectreshoot/pkg/contexts"
templ ChangeBio(err string, bio string) { templ ChangeBio(err string, bio string) {
{{ {{

View File

@@ -1,6 +1,6 @@
package account package account
import "projectreshoot/contexts" import "projectreshoot/pkg/contexts"
templ ChangeUsername(err string, username string) { templ ChangeUsername(err string, username string) {
{{ {{

View File

@@ -1,6 +1,6 @@
package nav package nav
import "projectreshoot/contexts" import "projectreshoot/pkg/contexts"
type ProfileItem struct { type ProfileItem struct {
name string // Label to display name string // Label to display

View File

@@ -1,6 +1,6 @@
package nav package nav
import "projectreshoot/contexts" import "projectreshoot/pkg/contexts"
// Returns the mobile version of the navbar thats only visible when activated // Returns the mobile version of the navbar thats only visible when activated
templ sideNav(navItems []NavItem) { templ sideNav(navItems []NavItem) {

View File

@@ -1,7 +1,7 @@
package popup package popup
import "projectreshoot/view/component/form" import "projectreshoot/internal/view/component/form"
templ ConfirmPasswordModal() { templ ConfirmPasswordModal() {
<div <div

View File

@@ -1,7 +1,7 @@
package search package search
import "projectreshoot/tmdb"
import "fmt" import "fmt"
import "git.haelnorr.com/h/golib/tmdb"
templ MovieResults(movies *tmdb.ResultMovies, image *tmdb.Image) { templ MovieResults(movies *tmdb.ResultMovies, image *tmdb.Image) {
for _, movie := range movies.Results { for _, movie := range movies.Results {
@@ -18,23 +18,23 @@ templ MovieResults(movies *tmdb.ResultMovies, image *tmdb.Image) {
onerror="this.onerror=null; setFallbackColor(this);" onerror="this.onerror=null; setFallbackColor(this);"
/> />
<script> <script>
function setFallbackColor(img) { function setFallbackColor(img) {
const baseColor = getComputedStyle(document.documentElement). const baseColor = getComputedStyle(document.documentElement).
getPropertyValue('--base').trim(); getPropertyValue('--base').trim();
img.src = `data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='96' height='144'%3E%3Crect width='100%' height='100%' fill='${baseColor}'/%3E%3C/svg%3E`; img.src = `data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='96' height='144'%3E%3Crect width='100%' height='100%' fill='${baseColor}'/%3E%3C/svg%3E`;
} }
</script> </script>
<div> <div>
<a <a
href={ templ.SafeURL(fmt.Sprintf("/movie/%v", movie.ID)) } href={ templ.SafeURL(fmt.Sprintf("/movie/%v", movie.ID)) }
class="text-xl font-semibold transition hover:text-green" class="text-xl font-semibold transition hover:text-green"
>{ movie.Title } { movie.ReleaseYear() }</a> >{ movie.Title } { movie.ReleaseYear() }</a>
<p class="text-subtext0"> <p class="text-subtext0">
Released: Released:
<span class="font-medium">{ movie.ReleaseDate }</span> <span class="font-medium">{ movie.ReleaseDate }</span>
</p> </p>
<p class="text-subtext0"> <p class="text-subtext0">
Original Title: Original Title:
<span class="font-medium">{ movie.OriginalTitle }</span> <span class="font-medium">{ movie.OriginalTitle }</span>
</p> </p>
<p class="text-subtext0">{ movie.Overview }</p> <p class="text-subtext0">{ movie.Overview }</p>

View File

@@ -1,8 +1,8 @@
package layout package layout
import "projectreshoot/view/component/nav" import "projectreshoot/internal/view/component/nav"
import "projectreshoot/view/component/footer" import "projectreshoot/internal/view/component/footer"
import "projectreshoot/view/component/popup" import "projectreshoot/internal/view/component/popup"
// Global page layout. Includes HTML document settings, header tags // Global page layout. Includes HTML document settings, header tags
// navbar and footer // navbar and footer

View File

@@ -1,6 +1,6 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
// Returns the about page content // Returns the about page content
templ About() { templ About() {

View File

@@ -1,7 +1,7 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
import "projectreshoot/view/component/account" import "projectreshoot/internal/view/component/account"
templ Account(subpage string) { templ Account(subpage string) {
@layout.Global("Account - " + subpage) { @layout.Global("Account - " + subpage) {

View File

@@ -1,6 +1,6 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
import "strconv" import "strconv"
// Page template for Error pages. Error code should be a HTTP status code as // Page template for Error pages. Error code should be a HTTP status code as

View File

@@ -1,6 +1,6 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
// Page content for the index page // Page content for the index page
templ Index() { templ Index() {

View File

@@ -1,7 +1,7 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
import "projectreshoot/view/component/form" import "projectreshoot/internal/view/component/form"
// Returns the login page // Returns the login page
templ Login() { templ Login() {

View File

@@ -1,15 +1,12 @@
package page package page
import "projectreshoot/tmdb" import "git.haelnorr.com/h/golib/tmdb"
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) { templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) {
@layout.Global(movie.Title) { @layout.Global(movie.Title) {
<div class="md:bg-surface0 md:p-2 md:rounded-lg transition-all"> <div class="md:bg-surface0 md:p-2 md:rounded-lg transition-all">
<div <div id="billedcrew" class="hidden">
id="billedcrew"
class="hidden"
>
for _, billedcrew := range credits.BilledCrew() { for _, billedcrew := range credits.BilledCrew() {
<span class="flex flex-col text-left w-[130px] md:w-[180px]"> <span class="flex flex-col text-left w-[130px] md:w-[180px]">
<span class="font-bold">{ billedcrew.Name }</span> <span class="font-bold">{ billedcrew.Name }</span>
@@ -20,7 +17,7 @@ templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) {
<div class="flex items-start"> <div class="flex items-start">
<div class="w-[154px] md:w-[300px] flex-col"> <div class="w-[154px] md:w-[300px] flex-col">
<img <img
class="object-cover aspect-[2/3] w-[154px] md:w-[300px] class="object-cover aspect-2/3 w-[154px] md:w-[300px]
transition-all md:rounded-md shadow-black shadow-2xl" transition-all md:rounded-md shadow-black shadow-2xl"
src={ movie.GetPoster(image, "w300") } src={ movie.GetPoster(image, "w300") }
alt="Poster" alt="Poster"
@@ -31,27 +28,27 @@ templ Movie(movie *tmdb.Movie, credits *tmdb.Credits, image *tmdb.Image) {
mt-5 flex-wrap justify-around flex-col px-5 md:hidden" mt-5 flex-wrap justify-around flex-col px-5 md:hidden"
></div> ></div>
<script> <script>
function moveBilledCrew() { function moveBilledCrew() {
const billedCrewMd = document.getElementById('billedcrew-md'); const billedCrewMd = document.getElementById('billedcrew-md');
const billedCrewSm = document.getElementById('billedcrew-sm'); const billedCrewSm = document.getElementById('billedcrew-sm');
const billedCrew = document.getElementById('billedcrew'); const billedCrew = document.getElementById('billedcrew');
if (window.innerWidth < 768) { if (window.innerWidth < 768) {
billedCrewSm.innerHTML = billedCrew.innerHTML; billedCrewSm.innerHTML = billedCrew.innerHTML;
billedCrewMd.innerHTML = ""; billedCrewMd.innerHTML = "";
} else { } else {
billedCrewMd.innerHTML = billedCrew.innerHTML; billedCrewMd.innerHTML = billedCrew.innerHTML;
billedCrewSm.innerHTML = ""; billedCrewSm.innerHTML = "";
} }
} }
window.addEventListener('load', moveBilledCrew); window.addEventListener('load', moveBilledCrew);
const resizeObs = new ResizeObserver(() => { const resizeObs = new ResizeObserver(() => {
moveBilledCrew(); moveBilledCrew();
}); });
resizeObs.observe(document.body); resizeObs.observe(document.body);
</script> </script>
</div> </div>
<div class="flex flex-col flex-1 text-center px-4"> <div class="flex flex-col flex-1 text-center px-4">
<span class="text-xl md:text-3xl font-semibold"> <span class="text-xl md:text-3xl font-semibold">

View File

@@ -1,6 +1,6 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
templ Movies() { templ Movies() {
@layout.Global("Search movies") { @layout.Global("Search movies") {

View File

@@ -1,7 +1,7 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
import "projectreshoot/contexts" import "projectreshoot/pkg/contexts"
templ Profile() { templ Profile() {
{{ user := contexts.GetUser(ctx) }} {{ user := contexts.GetUser(ctx) }}

View File

@@ -1,7 +1,7 @@
package page package page
import "projectreshoot/view/layout" import "projectreshoot/internal/view/layout"
import "projectreshoot/view/component/form" import "projectreshoot/internal/view/component/form"
// Returns the login page // Returns the login page
templ Register() { templ Register() {

View File

@@ -1,84 +0,0 @@
package jwt
import (
"time"
"projectreshoot/config"
"projectreshoot/db"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pkg/errors"
)
// Generates an access token for the provided user
func GenerateAccessToken(
config *config.Config,
user *db.User,
fresh bool,
rememberMe bool,
) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
var freshExpiresAt int64
if fresh {
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
} else {
freshExpiresAt = issuedAt
}
var ttl string
if rememberMe {
ttl = "exp"
} else {
ttl = "session"
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.MapClaims{
"iss": config.TrustedHost,
"scope": "access",
"ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt,
"exp": expiresAt,
"fresh": freshExpiresAt,
"sub": user.ID,
})
signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil {
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, expiresAt, nil
}
// Generates a refresh token for the provided user
func GenerateRefreshToken(
config *config.Config,
user *db.User,
rememberMe bool,
) (tokenStr string, exp int64, err error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
var ttl string
if rememberMe {
ttl = "exp"
} else {
ttl = "session"
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.MapClaims{
"iss": config.TrustedHost,
"scope": "refresh",
"ttl": ttl,
"jti": uuid.New(),
"iat": issuedAt,
"exp": expiresAt,
"sub": user.ID,
})
signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil {
return "", 0, errors.Wrap(err, "token.SignedString")
}
return signedToken, expiresAt, nil
}

View File

@@ -1,268 +0,0 @@
package jwt
import (
"context"
"fmt"
"time"
"projectreshoot/config"
"projectreshoot/db"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pkg/errors"
)
// Parse an access token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseAccessToken(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
return nil, errors.New("Access token string not provided")
}
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return nil, errors.Wrap(err, "parseToken")
}
expiry, err := checkTokenExpired(claims["exp"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenExpired")
}
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenIssuer")
}
ttl, err := getTokenTTL(claims["ttl"])
if err != nil {
return nil, errors.Wrap(err, "getTokenTTL")
}
scope, err := getTokenScope(claims["scope"])
if err != nil {
return nil, errors.Wrap(err, "getTokenScope")
}
if scope != "access" {
return nil, errors.New("Token is not an Access token")
}
issuedAt, err := getIssuedTime(claims["iat"])
if err != nil {
return nil, errors.Wrap(err, "getIssuedTime")
}
subject, err := getTokenSubject(claims["sub"])
if err != nil {
return nil, errors.Wrap(err, "getTokenSubject")
}
fresh, err := getFreshTime(claims["fresh"])
if err != nil {
return nil, errors.Wrap(err, "getFreshTime")
}
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return nil, errors.Wrap(err, "getTokenJTI")
}
token := &AccessToken{
ISS: issuer,
TTL: ttl,
EXP: expiry,
IAT: issuedAt,
SUB: subject,
Fresh: fresh,
JTI: jti,
Scope: scope,
}
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}
if !valid {
return nil, errors.New("Token has been revoked")
}
return token, nil
}
// Parse a refresh token and return a struct with all the claims. Does validation on
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func ParseRefreshToken(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
return nil, errors.New("Refresh token string not provided")
}
claims, err := parseToken(config.SecretKey, tokenString)
if err != nil {
return nil, errors.Wrap(err, "parseToken")
}
expiry, err := checkTokenExpired(claims["exp"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenExpired")
}
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
if err != nil {
return nil, errors.Wrap(err, "checkTokenIssuer")
}
ttl, err := getTokenTTL(claims["ttl"])
if err != nil {
return nil, errors.Wrap(err, "getTokenTTL")
}
scope, err := getTokenScope(claims["scope"])
if err != nil {
return nil, errors.Wrap(err, "getTokenScope")
}
if scope != "refresh" {
return nil, errors.New("Token is not an Refresh token")
}
issuedAt, err := getIssuedTime(claims["iat"])
if err != nil {
return nil, errors.Wrap(err, "getIssuedTime")
}
subject, err := getTokenSubject(claims["sub"])
if err != nil {
return nil, errors.Wrap(err, "getTokenSubject")
}
jti, err := getTokenJTI(claims["jti"])
if err != nil {
return nil, errors.Wrap(err, "getTokenJTI")
}
token := &RefreshToken{
ISS: issuer,
TTL: ttl,
EXP: expiry,
IAT: issuedAt,
SUB: subject,
JTI: jti,
Scope: scope,
}
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}
if !valid {
return nil, errors.New("Token has been revoked")
}
return token, nil
}
// Parse a token, validating its signing sigature and returning the claims
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
return nil, errors.Wrap(err, "jwt.Parse")
}
// Token decoded, parse the claims
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("Failed to parse claims")
}
return claims, nil
}
// Check if a token is expired. Returns the expiry if not expired
func checkTokenExpired(expiry interface{}) (int64, error) {
// Coerce the expiry to a float64 to avoid scientific notation
expFloat, ok := expiry.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'exp' claim")
}
// Convert to the int64 time we expect :)
expiryTime := int64(expFloat)
// Check if its expired
isExpired := time.Now().After(time.Unix(expiryTime, 0))
if isExpired {
return 0, errors.New("Token has expired")
}
return expiryTime, nil
}
// Check if a token has a valid issuer. Returns the issuer if valid
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
issuerVal, ok := issuer.(string)
if !ok {
return "", errors.New("Missing or invalid 'iss' claim")
}
if issuer != trustedHost {
return "", errors.New("Issuer does not matched trusted host")
}
return issuerVal, nil
}
// Check the scope matches the expected scope. Returns scope if true
func getTokenScope(scope interface{}) (string, error) {
scopeStr, ok := scope.(string)
if !ok {
return "", errors.New("Missing or invalid 'scope' claim")
}
return scopeStr, nil
}
// Get the TTL of the token, either "session" or "exp"
func getTokenTTL(ttl interface{}) (string, error) {
ttlStr, ok := ttl.(string)
if !ok {
return "", errors.New("Missing or invalid 'ttl' claim")
}
if ttlStr != "exp" && ttlStr != "session" {
return "", errors.New("TTL value is not recognised")
}
return ttlStr, nil
}
// Get the time the token was issued at
func getIssuedTime(issued interface{}) (int64, error) {
// Same float64 -> int64 trick as expiry
issuedFloat, ok := issued.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'iat' claim")
}
issuedAt := int64(issuedFloat)
return issuedAt, nil
}
// Get the freshness expiry timestamp
func getFreshTime(fresh interface{}) (int64, error) {
freshUntil, ok := fresh.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'fresh' claim")
}
return int64(freshUntil), nil
}
// Get the subject of the token
func getTokenSubject(sub interface{}) (int, error) {
subject, ok := sub.(float64)
if !ok {
return 0, errors.New("Missing or invalid 'sub' claim")
}
return int(subject), nil
}
// Get the JTI of the token
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
jtiStr, ok := jti.(string)
if !ok {
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
}
jtiUUID, err := uuid.Parse(jtiStr)
if err != nil {
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
}
return jtiUUID, nil
}

View File

@@ -1,33 +0,0 @@
package jwt
import (
"context"
"projectreshoot/db"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := tx.Exec(ctx, query, jti, exp)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(ctx, query, jti)
if err != nil {
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
revoked := rows.Next()
return !revoked, nil
}

View File

@@ -1,73 +0,0 @@
package jwt
import (
"context"
"projectreshoot/db"
"github.com/google/uuid"
"github.com/pkg/errors"
)
type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
}
// Access token
type AccessToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at
Scope string // Should be "access"
}
// Refresh token
type RefreshToken struct {
ISS string // Issuer, generally TrustedHost
IAT int64 // Time issued at
EXP int64 // Time expiring at
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh"
}
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, a.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return user, nil
}
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return user, nil
}
func (a AccessToken) GetJTI() uuid.UUID {
return a.JTI
}
func (r RefreshToken) GetJTI() uuid.UUID {
return r.JTI
}
func (a AccessToken) GetEXP() int64 {
return a.EXP
}
func (r RefreshToken) GetEXP() int64 {
return r.EXP
}
func (a AccessToken) GetScope() string {
return a.Scope
}
func (r RefreshToken) GetScope() string {
return r.Scope
}

View File

@@ -1,84 +0,0 @@
package logging
import (
"io"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/pkgerrors"
)
// Takes a log level as string and converts it to a zerolog.Level interface.
// If the string is not a valid input it will return zerolog.InfoLevel
func GetLogLevel(level string) zerolog.Level {
levels := map[string]zerolog.Level{
"trace": zerolog.TraceLevel,
"debug": zerolog.DebugLevel,
"info": zerolog.InfoLevel,
"warn": zerolog.WarnLevel,
"error": zerolog.ErrorLevel,
"fatal": zerolog.FatalLevel,
"panic": zerolog.PanicLevel,
}
logLevel, valid := levels[level]
if !valid {
return zerolog.InfoLevel
}
return logLevel
}
// Returns a pointer to a new log file with the specified path.
// Remember to call file.Close() when finished writing to the log file
func GetLogFile(path string) (*os.File, error) {
logPath := filepath.Join(path, "server.log")
file, err := os.OpenFile(
logPath,
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
0663,
)
if err != nil {
return nil, errors.Wrap(err, "os.OpenFile")
}
return file, nil
}
// Get a pointer to a new zerolog.Logger with the specified level and output
// Can provide a file, writer or both. Must provide at least one of the two
func GetLogger(
logLevel zerolog.Level,
w io.Writer,
logFile *os.File,
logDir string,
) (*zerolog.Logger, error) {
if w == nil && logFile == nil {
return nil, errors.New("No Writer provided for log output.")
}
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
var consoleWriter zerolog.ConsoleWriter
if w != nil {
consoleWriter = zerolog.ConsoleWriter{Out: w}
}
var output io.Writer
if logFile != nil {
if w != nil {
output = zerolog.MultiLevelWriter(logFile, consoleWriter)
} else {
output = logFile
}
} else {
output = consoleWriter
}
logger := zerolog.New(output).
With().
Timestamp().
Logger().
Level(logLevel)
return &logger, nil
}

233
main.go
View File

@@ -1,233 +0,0 @@
package main
import (
"context"
"embed"
"flag"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/logging"
"projectreshoot/server"
"projectreshoot/tests"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
//go:embed static/*
var embeddedStatic embed.FS
// Gets the static files
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
if _, err := os.Stat("static"); err == nil {
// Use actual filesystem in development
logger.Debug().Msg("Using filesystem for static files")
return http.Dir("static"), nil
} else {
// Use embedded filesystem in production
logger.Debug().Msg("Using embedded static files")
subFS, err := fs.Sub(embeddedStatic, "static")
if err != nil {
return nil, errors.Wrap(err, "fs.Sub")
}
return http.FS(subFS), nil
}
}
var maint uint32 // atomic: 1 if in maintenance mode
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
conn *db.SafeConn,
srv *http.Server,
logger *zerolog.Logger,
config *config.Config,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
logger.Debug().Msg("Shutting down signal listener")
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
logger.Info().Msg("Attempting to acquire database lock")
conn.Pause(config.DBLockTimeout * time.Second)
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
logger.Info().Msg("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := config.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil
}
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = logging.GetLogFile(config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogFile")
}
defer logfile.Close()
}
var consoleWriter io.Writer
if config.LogOutput == "both" || config.LogOutput == "console" {
consoleWriter = w
}
logger, err := logging.GetLogger(
config.LogLevel,
consoleWriter,
logfile,
config.LogDir,
)
if err != nil {
return errors.Wrap(err, "logging.GetLogger")
}
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
var conn *db.SafeConn
if args["test"] == "true" {
logger.Debug().Msg("Server in test mode, using test database")
ver, err := strconv.ParseInt(config.DBName, 10, 0)
if err != nil {
return errors.Wrap(err, "strconv.ParseInt")
}
testconn, err := tests.SetupTestDB(ver)
if err != nil {
return errors.Wrap(err, "tests.SetupTestDB")
}
conn = db.MakeSafe(testconn, logger)
} else {
conn, err = db.ConnectToDatabase(config.DBName, logger)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
}
defer conn.Close()
logger.Debug().Msg("Getting static files")
staticFS, err := getStaticFiles(logger)
if err != nil {
return errors.Wrap(err, "getStaticFiles")
}
logger.Debug().Msg("Setting up HTTP server")
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv,
ReadHeaderTimeout: config.ReadHeaderTimeout * time.Second,
WriteTimeout: config.WriteTimeout * time.Second,
IdleTimeout: config.IdleTimeout * time.Second,
}
// Runs function for testing in dev if --test flag true
if args["tester"] == "true" {
logger.Debug().Msg("Running tester function")
test(config, logger, conn, httpServer)
return nil
}
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer, logger, config)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")
go func() {
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error().Err(err).Msg("Error listening and serving")
}
}()
// Handles graceful shutdown
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error().Err(err).Msg("Error shutting down server")
}
}()
wg.Wait()
logger.Info().Msg("Shutting down")
return nil
}
// Start of runtime. Parse commandline arguments & flags, Initializes context
// and starts the server
func main() {
// Parse commandline args
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run server in test mode")
tester := flag.Bool("tester", false, "Run tester function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
flag.Parse()
// Map the args for easy access
args := map[string]string{
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"tester": strconv.FormatBool(*tester),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,
}
// Start the server
ctx := context.Background()
if err := run(ctx, os.Stdout, args); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
}

View File

@@ -1,145 +0,0 @@
package main
import (
"bytes"
"context"
"fmt"
"net/http"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_main(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
args := map[string]string{"test": "true"}
var stdout bytes.Buffer
os.Setenv("SECRET_KEY", ".")
os.Setenv("TMDB_API_TOKEN", ".")
os.Setenv("HOST", "127.0.0.1")
os.Setenv("PORT", "3232")
runSrvErr := make(chan error)
go func() {
if err := run(ctx, &stdout, args); err != nil {
runSrvErr <- err
return
}
}()
go func() {
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
if err != nil {
runSrvErr <- err
return
}
runSrvErr <- nil
}()
select {
case err := <-runSrvErr:
if err != nil {
t.Fatalf("Error starting test server: %s", err)
return
}
t.Log("Test server started")
}
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock acquired"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR1)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock released"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR2)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
}
func waitForReady(
ctx context.Context,
timeout time.Duration,
endpoint string,
) error {
client := http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
endpoint,
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error making request: %s\n", err.Error())
time.Sleep(250 * time.Millisecond)
continue
}
if resp.StatusCode == http.StatusOK {
fmt.Println("Endpoint is ready!")
resp.Body.Close()
return nil
}
resp.Body.Close()
select {
case <-ctx.Done():
return ctx.Err()
default:
if time.Since(startTime) >= timeout {
return fmt.Errorf("timeout reached while waiting for endpoint")
}
// wait a little while between checks
time.Sleep(250 * time.Millisecond)
}
}
}

View File

@@ -1,148 +0,0 @@
package middleware
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/contexts"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthenticationMiddleware(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(strconv.Itoa(0)))
return
} else {
w.WriteHeader(http.StatusOK)
w.Write([]byte(strconv.Itoa(user.ID)))
}
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
id int
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Valid Access Token (Fresh)",
id: 1,
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Valid Access Token (Unfresh)",
id: 1,
accessToken: tokens["accessUnfresh"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusOK,
},
{
name: "Valid Refresh Token (Triggers Refresh)",
id: 1,
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshValid"],
expectedCode: http.StatusOK,
},
{
name: "Both tokens expired",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "Access token revoked",
accessToken: tokens["accessRevoked"],
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
{
name: "Refresh token revoked",
accessToken: "",
refreshToken: tokens["refreshRevoked"],
expectedCode: http.StatusUnauthorized,
},
{
name: "Invalid Tokens",
accessToken: tokens["invalid"],
refreshToken: tokens["invalid"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No Tokens",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, strconv.Itoa(tt.id), string(body))
})
}
}
// get the tokens to test with
func getTokens() map[string]string {
tokens := map[string]string{
"accessFresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzIyMTAsImZyZXNoIjo0ODk1NjcyMjEwLCJpYXQiOjE3Mzk2NzIyMTAsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImE4Njk2YWM4LTg3OWMtNDdkNC1iZWM2LTRlY2Y4MTRiZThiZiIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.6nAquDY0JBLPdaJ9q_sMpKj1ISG4Vt2U05J57aoPue8",
"accessUnfresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJmcmVzaCI6MTczOTY3NTY3MSwiaWF0IjoxNzM5Njc1NjcxLCJpc3MiOiIxMjcuMC4wLjEiLCJqdGkiOiJjOGNhZmFjNy0yODkzLTQzNzMtOTI4ZS03MGUwODJkYmM2MGIiLCJzY29wZSI6ImFjY2VzcyIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.plWQVFwHlhXUYI5utS7ny1JfXjJSFrigkq-PnTHD5VY",
"accessExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImZyZXNoIjoxNzM5NjcyMjQ4LCJpYXQiOjE3Mzk2NzIyNDgsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjgxYzA1YzBjLTJhOGItNGQ2MC04Yzc4LWY2ZTQxODYxZDFmNCIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.iI1f17kKTuFDEMEYltJRIwRYgYQ-_nF9Wsn0KR6x77Q",
"refreshValid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImlhdCI6MTczOTY3MTkyMiwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTUxMTY3ZWEtNDA3OS00ZTczLTkzZDQtNTgwZDMzODRjZDU4Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.tvtqQ8Z4WrYWHHb0MaEPdsU2FT2KLRE1zHOv3ipoFyc",
"refreshExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImlhdCI6MTczOTY3MjI0OCwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTg5YTc5MTYtZGEzYi00YmJhLWI3ZDMtOWI1N2ViNjRhMmU0Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.rH_fytC7Duxo598xacu820pQKF9ELbG8674h_bK_c4I",
"accessRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImZyZXNoIjoxNzM5NjcxOTIyLCJpYXQiOjE3Mzk2NzE5MjIsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjBhNmIzMzhlLTkzMGEtNDNmZS04ZjcwLTFhNmRhZWQyNTZmYSIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.mZLuCp9amcm2_CqYvbHPlk86nfiuy_Or8TlntUCw4Qs",
"refreshRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJpYXQiOjE3Mzk2NzU2NzEsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImI3ZmE1MWRjLTg1MzItNDJlMS04NzU2LTVkMjViZmIyMDAzYSIsInNjb3BlIjoicmVmcmVzaCIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.5Q9yDZN5FubfCWHclUUZEkJPOUHcOEpVpgcUK-ameHo",
"invalid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0ODUxNDA5ODQsImlhdCI6MTQ4NTEzNzM4NCwiaXNzIjoiYWNtZS5jb20iLCJzdWIiOiIyOWFjMGMxOC0wYjRhLTQyY2YtODJmYy0wM2Q1NzAzMThhMWQiLCJhcHBsaWNhdGlvbklkIjoiNzkxMDM3MzQtOTdhYi00ZDFhLWFmMzctZTAwNmQwNWQyOTUyIiwicm9sZXMiOltdfQ.Mp0Pcwsz5VECK11Kf2ZZNF_SMKu5CgBeLN9ZOP04kZo",
}
return tokens
}

View File

@@ -1,87 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPageLoginRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
loginRequiredHandler := LoginReq(testHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Valid Login",
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Expired login",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No login",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
})
}
}

View File

@@ -1,94 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReauthRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
conn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
reauthRequiredHandler := FreshReq(testHandler)
loginRequiredHandler := LoginReq(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Fresh Login",
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Unfresh Login",
accessToken: tokens["accessUnfresh"],
refreshToken: "",
expectedCode: 444,
},
{
name: "Expired login",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No login",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
})
}
}

View File

@@ -5,12 +5,10 @@ import (
"os" "os"
"time" "time"
"projectreshoot/logging" "git.haelnorr.com/h/golib/hlog"
"projectreshoot/tmdb" "git.haelnorr.com/h/golib/tmdb"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog"
) )
type Config struct { type Config struct {
@@ -28,7 +26,7 @@ type Config struct {
AccessTokenExpiry int64 // Access token expiry in minutes AccessTokenExpiry int64 // Access token expiry in minutes
RefreshTokenExpiry int64 // Refresh token expiry in minutes RefreshTokenExpiry int64 // Refresh token expiry in minutes
TokenFreshTime int64 // Time for tokens to stay fresh in minutes TokenFreshTime int64 // Time for tokens to stay fresh in minutes
LogLevel zerolog.Level // Log level for global logging. Defaults to info LogLevel hlog.Level // Log level for global logging. Defaults to info
LogOutput string // "file", "console", or "both". Defaults to console LogOutput string // "file", "console", or "both". Defaults to console
LogDir string // Path to create log files LogDir string // Path to create log files
TMDBToken string // Read access token for TMDB API TMDBToken string // Read access token for TMDB API
@@ -41,7 +39,7 @@ func GetConfig(args map[string]string) (*Config, error) {
var ( var (
host string host string
port string port string
logLevel zerolog.Level logLevel hlog.Level
logOutput string logOutput string
valid bool valid bool
) )
@@ -57,9 +55,9 @@ func GetConfig(args map[string]string) (*Config, error) {
port = GetEnvDefault("PORT", "3010") port = GetEnvDefault("PORT", "3010")
} }
if args["loglevel"] != "" { if args["loglevel"] != "" {
logLevel = logging.GetLogLevel(args["loglevel"]) logLevel = hlog.LogLevel(args["loglevel"])
} else { } else {
logLevel = logging.GetLogLevel(GetEnvDefault("LOG_LEVEL", "info")) logLevel = hlog.LogLevel(GetEnvDefault("LOG_LEVEL", "info"))
} }
if args["logoutput"] != "" { if args["logoutput"] != "" {
opts := map[string]string{ opts := map[string]string{

View File

@@ -2,11 +2,11 @@ package contexts
import ( import (
"context" "context"
"projectreshoot/db" "projectreshoot/internal/models"
) )
type AuthenticatedUser struct { type AuthenticatedUser struct {
*db.User *models.User
Fresh int64 Fresh int64
} }

View File

@@ -4,10 +4,10 @@ import (
"net/http" "net/http"
"time" "time"
"projectreshoot/config" "projectreshoot/internal/models"
"projectreshoot/db" "projectreshoot/pkg/config"
"projectreshoot/jwt"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -58,15 +58,16 @@ func SetTokenCookies(
w http.ResponseWriter, w http.ResponseWriter,
r *http.Request, r *http.Request,
config *config.Config, config *config.Config,
user *db.User, tokenGen *jwt.TokenGenerator,
user *models.User,
fresh bool, fresh bool,
rememberMe bool, rememberMe bool,
) error { ) error {
at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe) at, atexp, err := tokenGen.NewAccess(user.ID, fresh, rememberMe)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.GenerateAccessToken") return errors.Wrap(err, "jwt.GenerateAccessToken")
} }
rt, rtexp, err := jwt.GenerateRefreshToken(config, user, rememberMe) rt, rtexp, err := tokenGen.NewRefresh(user.ID, rememberMe)
if err != nil { if err != nil {
return errors.Wrap(err, "jwt.GenerateRefreshToken") return errors.Wrap(err, "jwt.GenerateRefreshToken")
} }

20
pkg/embedfs/embedfs.go Normal file
View File

@@ -0,0 +1,20 @@
package embedfs
import (
"embed"
"io/fs"
"github.com/pkg/errors"
)
//go:embed files/*
var embeddedFiles embed.FS
// Gets the embedded files
func GetEmbeddedFS() (fs.FS, error) {
subFS, err := fs.Sub(embeddedFiles, "files")
if err != nil {
return nil, errors.Wrap(err, "fs.Sub")
}
return subFS, nil
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

File diff suppressed because it is too large Load Diff

View File

Before

Width:  |  Height:  |  Size: 834 B

After

Width:  |  Height:  |  Size: 834 B

View File

@@ -1,41 +0,0 @@
package server
import (
"net/http"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/middleware"
"github.com/rs/zerolog"
)
// Returns a new http.Handler with all the routes and middleware added
func NewServer(
config *config.Config,
logger *zerolog.Logger,
conn *db.SafeConn,
staticFS *http.FileSystem,
maint *uint32,
) http.Handler {
mux := http.NewServeMux()
addRoutes(
mux,
logger,
config,
conn,
staticFS,
)
var handler http.Handler = mux
// Add middleware here, must be added in reverse order of execution
// i.e. First in list will get executed last during the request handling
handler = middleware.Logging(logger, handler)
handler = middleware.Authentication(logger, config, conn, handler, maint)
// Gzip
handler = middleware.Gzip(handler, config.GZIP)
// Start the timer for the request chain so logger can have accurate info
handler = middleware.StartTimer(handler)
return handler
}

View File

@@ -1,18 +0,0 @@
package tests
import (
"os"
"projectreshoot/config"
"github.com/pkg/errors"
)
func TestConfig() (*config.Config, error) {
os.Setenv("SECRET_KEY", ".")
os.Setenv("TMDB_API_TOKEN", ".")
cfg, err := config.GetConfig(map[string]string{})
if err != nil {
return nil, errors.Wrap(err, "config.GetConfig")
}
return cfg, nil
}

View File

@@ -1,90 +0,0 @@
package tests
import (
"context"
"database/sql"
"io/fs"
"os"
"path/filepath"
"github.com/pkg/errors"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
func findMigrations() (*fs.FS, error) {
dir, err := os.Getwd()
if err != nil {
return nil, err
}
for {
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
return &migrationsdir, nil
}
parent := filepath.Dir(dir)
if parent == dir { // Reached root
return nil, errors.New("Unable to locate migrations directory")
}
dir = parent
}
}
func findTestData() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
return filepath.Join(dir, "tests", "testdata.sql"), nil
}
parent := filepath.Dir(dir)
if parent == dir { // Reached root
return "", errors.New("Unable to locate test data")
}
dir = parent
}
}
func SetupTestDB(version int64) (*sql.DB, error) {
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
migrations, err := findMigrations()
if err != nil {
return nil, errors.Wrap(err, "findMigrations")
}
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
if err != nil {
return nil, errors.Wrap(err, "goose.NewProvider")
}
ctx := context.Background()
if _, err := provider.UpTo(ctx, version); err != nil {
return nil, errors.Wrap(err, "provider.UpTo")
}
// Load the test data
dataPath, err := findTestData()
if err != nil {
return nil, errors.Wrap(err, "findSchema")
}
sqlBytes, err := os.ReadFile(dataPath)
if err != nil {
return nil, errors.Wrap(err, "os.ReadFile")
}
dataSQL := string(sqlBytes)
_, err = conn.Exec(dataSQL)
if err != nil {
return nil, errors.Wrap(err, "tx.Exec")
}
return conn, nil
}

View File

@@ -1,33 +0,0 @@
package tests
import (
"testing"
"github.com/rs/zerolog"
)
type TLogWriter struct {
t *testing.T
}
// Write implements the io.Writer interface for TLogWriter.
func (w *TLogWriter) Write(p []byte) (n int, err error) {
w.t.Logf("%s", p)
return len(p), nil
}
// Return a fake logger to satisfy functions that expect one
func NilLogger() *zerolog.Logger {
logger := zerolog.New(nil)
return &logger
}
// Return a logger that makes use of the T.Log method to enable debugging tests
func DebugLogger(t *testing.T) *zerolog.Logger {
logger := zerolog.New(GetTLogWriter(t))
return &logger
}
func GetTLogWriter(t *testing.T) *TLogWriter {
return &TLogWriter{t: t}
}

View File

@@ -1,3 +0,0 @@
INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274, 'bio');
INSERT INTO jwtblacklist VALUES('0a6b338e-930a-43fe-8f70-1a6daed256fa', 33299675344);
INSERT INTO jwtblacklist VALUES('b7fa51dc-8532-42e1-8756-5d25bfb2003a', 33299675344);

View File

@@ -1,32 +0,0 @@
package tmdb
import (
"encoding/json"
"github.com/pkg/errors"
)
type Config struct {
Image Image `json:"images"`
}
type Image struct {
BaseURL string `json:"base_url"`
SecureBaseURL string `json:"secure_base_url"`
BackdropSizes []string `json:"backdrop_sizes"`
LogoSizes []string `json:"logo_sizes"`
PosterSizes []string `json:"poster_sizes"`
ProfileSizes []string `json:"profile_sizes"`
StillSizes []string `json:"still_sizes"`
}
func GetConfig(token string) (*Config, error) {
url := "https://api.themoviedb.org/3/configuration"
data, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
config := Config{}
json.Unmarshal(data, &config)
return &config, nil
}

View File

@@ -1,54 +0,0 @@
package tmdb
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
type Credits struct {
ID int32 `json:"id"`
Cast []Cast `json:"cast"`
Crew []Crew `json:"crew"`
}
type Cast struct {
Adult bool `json:"adult"`
Gender int `json:"gender"`
ID int32 `json:"id"`
KnownFor string `json:"known_for_department"`
Name string `json:"name"`
OriginalName string `json:"original_name"`
Popularity int `json:"popularity"`
Profile string `json:"profile_path"`
CastID int32 `json:"cast_id"`
Character string `json:"character"`
CreditID string `json:"credit_id"`
Order int `json:"order"`
}
type Crew struct {
Adult bool `json:"adult"`
Gender int `json:"gender"`
ID int32 `json:"id"`
KnownFor string `json:"known_for_department"`
Name string `json:"name"`
OriginalName string `json:"original_name"`
Popularity int `json:"popularity"`
Profile string `json:"profile_path"`
CreditID string `json:"credit_id"`
Department string `json:"department"`
Job string `json:"job"`
}
func GetCredits(movieid int32, token string) (*Credits, error) {
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
data, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
credits := Credits{}
json.Unmarshal(data, &credits)
return &credits, nil
}

Some files were not shown because too many files have changed in this diff Show More