Merge pull request #21 from Haelnorr/staging
Updates to database integration
This commit is contained in:
11
.githooks/pre-push
Normal file
11
.githooks/pre-push
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/sh
|
||||
protected_branches=("master" "staging")
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
|
||||
for branch in "${protected_branches[@]}"; do
|
||||
if [ "$current_branch" = "$branch" ]; then
|
||||
echo "Direct pushes to '$branch' are not allowed. Use a pull request instead."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
exit 0
|
||||
14
.github/workflows/deploy_production.yaml
vendored
14
.github/workflows/deploy_production.yaml
vendored
@@ -33,11 +33,15 @@ jobs:
|
||||
- name: Build the binary
|
||||
run: make build SUFFIX=-production-$GITHUB_SHA
|
||||
|
||||
- name: Build the migration binary
|
||||
run: make migrate SUFFIX=-production-$GITHUB_SHA
|
||||
|
||||
- name: Deploy to Server
|
||||
env:
|
||||
USER: deploy
|
||||
HOST: projectreshoot.com
|
||||
DIR: /home/deploy/releases/production
|
||||
MIG_DIR: /home/deploy/migration-bin
|
||||
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
@@ -49,7 +53,13 @@ jobs:
|
||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 projectreshoot-production-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_production.sh $GITHUB_SHA
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 prmigrate-production-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA production
|
||||
|
||||
14
.github/workflows/deploy_staging.yaml
vendored
14
.github/workflows/deploy_staging.yaml
vendored
@@ -33,11 +33,15 @@ jobs:
|
||||
- name: Build the binary
|
||||
run: make build SUFFIX=-staging-$GITHUB_SHA
|
||||
|
||||
- name: Build the migration binary
|
||||
run: make migrate SUFFIX=-staging-$GITHUB_SHA
|
||||
|
||||
- name: Deploy to Server
|
||||
env:
|
||||
USER: deploy
|
||||
HOST: projectreshoot.com
|
||||
DIR: /home/deploy/releases/staging
|
||||
MIG_DIR: /home/deploy/migration-bin
|
||||
DEPLOY_SSH_PRIVATE_KEY: ${{ secrets.DEPLOY_SSH_PRIVATE_KEY }}
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
@@ -49,7 +53,13 @@ jobs:
|
||||
echo " UserKnownHostsFile /dev/null" >> ~/.ssh/config
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 projectreshoot-staging-${GITHUB_SHA} $USER@$HOST:$DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy_staging.sh $GITHUB_SHA
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST mkdir -p $MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 prmigrate-staging-${GITHUB_SHA} $USER@$HOST:$MIG_DIR
|
||||
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/backup.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrate.sh $USER@$HOST:$MIG_DIR
|
||||
scp -i ~/.ssh/id_ed25519 ./deploy/db/migrationcleanup.sh $USER@$HOST:$MIG_DIR
|
||||
|
||||
ssh -i ~/.ssh/id_ed25519 $USER@$HOST 'bash -s' < ./deploy/deploy.sh $GITHUB_SHA staging
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,7 +2,9 @@
|
||||
query.sql
|
||||
*.db
|
||||
.logs/
|
||||
server.log
|
||||
tmp/
|
||||
prmigrate
|
||||
projectreshoot
|
||||
static/css/output.css
|
||||
view/**/*_templ.go
|
||||
|
||||
11
Makefile
11
Makefile
@@ -1,5 +1,6 @@
|
||||
# Makefile
|
||||
.PHONY: build
|
||||
.PHONY: migrate
|
||||
|
||||
BINARY_NAME=projectreshoot
|
||||
|
||||
@@ -17,14 +18,20 @@ dev:
|
||||
|
||||
tester:
|
||||
go mod tidy && \
|
||||
go run . --port 3232 --test --loglevel trace
|
||||
go run . --port 3232 --tester --loglevel trace
|
||||
|
||||
test:
|
||||
rm -f **/.projectreshoot-test-database.db && \
|
||||
go mod tidy && \
|
||||
templ generate && \
|
||||
go generate && \
|
||||
go test .
|
||||
go test ./db
|
||||
go test ./middleware
|
||||
|
||||
clean:
|
||||
go clean
|
||||
|
||||
migrate:
|
||||
go mod tidy && \
|
||||
go generate && \
|
||||
go build -ldflags="-w -s" -o prmigrate${SUFFIX} ./migrate
|
||||
|
||||
@@ -21,7 +21,8 @@ type Config struct {
|
||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||
DBName string // Filename of the db (doesnt include file extension)
|
||||
DBName string // Filename of the db - hardcoded and doubles as DB version
|
||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
||||
SecretKey string // Secret key for signing tokens
|
||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
||||
@@ -33,10 +34,7 @@ type Config struct {
|
||||
|
||||
// Load the application configuration and get a pointer to the Config object
|
||||
func GetConfig(args map[string]string) (*Config, error) {
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
godotenv.Load(".env")
|
||||
var (
|
||||
host string
|
||||
port string
|
||||
@@ -89,7 +87,8 @@ func GetConfig(args map[string]string) (*Config, error) {
|
||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
||||
DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
|
||||
DBName: "00001",
|
||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
||||
SecretKey: os.Getenv("SECRET_KEY"),
|
||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
||||
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
||||
@@ -99,7 +98,7 @@ func GetConfig(args map[string]string) (*Config, error) {
|
||||
LogDir: GetEnvDefault("LOG_DIR", ""),
|
||||
}
|
||||
|
||||
if config.SecretKey == "" {
|
||||
if config.SecretKey == "" && args["dbver"] != "true" {
|
||||
return nil, errors.New("Envar not set: SECRET_KEY")
|
||||
}
|
||||
|
||||
|
||||
@@ -3,19 +3,56 @@ package db
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// Returns a database connection handle for the Turso DB
|
||||
func ConnectToDatabase(dbName string) (*sql.DB, error) {
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(
|
||||
dbName string,
|
||||
logger *zerolog.Logger,
|
||||
) (*SafeConn, error) {
|
||||
file := fmt.Sprintf("file:%s.db", dbName)
|
||||
db, err := sql.Open("sqlite3", file)
|
||||
|
||||
db, err := sql.Open("sqlite", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
return db, nil
|
||||
version, err := strconv.Atoi(dbName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "strconv.Atoi")
|
||||
}
|
||||
err = checkDBVersion(db, version)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
conn := MakeSafe(db, logger)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check the database version
|
||||
func checkDBVersion(db *sql.DB, expectVer int) error {
|
||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||
ORDER BY version_id DESC LIMIT 1`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
var version int
|
||||
err = rows.Scan(&version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if version != expectVer {
|
||||
return errors.New("Version mismatch")
|
||||
}
|
||||
} else {
|
||||
return errors.New("No version found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
129
db/safeconn.go
Normal file
129
db/safeconn.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type SafeConn struct {
|
||||
db *sql.DB
|
||||
readLockCount uint32
|
||||
globalLockStatus uint32
|
||||
globalLockRequested uint32
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
// Make the provided db handle safe and attach a logger to it
|
||||
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||
return &SafeConn{db: db, logger: logger}
|
||||
}
|
||||
|
||||
// Attempts to acquire a global lock on the database connection
|
||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
conn.globalLockStatus = 1
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Releases a global lock on the database connection
|
||||
func (conn *SafeConn) releaseGlobalLock() {
|
||||
conn.globalLockStatus = 0
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock released")
|
||||
}
|
||||
|
||||
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
||||
// at the same time
|
||||
func (conn *SafeConn) acquireReadLock() bool {
|
||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||
return false
|
||||
}
|
||||
conn.readLockCount += 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Release a read lock. Decrements read lock count by 1
|
||||
func (conn *SafeConn) releaseReadLock() {
|
||||
conn.readLockCount -= 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock released")
|
||||
}
|
||||
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
|
||||
lockAcquired := make(chan struct{})
|
||||
lockCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-lockCtx.Done():
|
||||
return
|
||||
default:
|
||||
if conn.acquireReadLock() {
|
||||
close(lockAcquired)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lockAcquired:
|
||||
tx, err := conn.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
conn.releaseReadLock()
|
||||
return nil, err
|
||||
}
|
||||
return &SafeTX{tx: tx, sc: conn}, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return nil, errors.New("Transaction time out due to database lock")
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire a global lock, preventing all transactions
|
||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||
conn.globalLockRequested = 1
|
||||
defer func() { conn.globalLockRequested = 0 }()
|
||||
timeout := time.After(timeoutAfter)
|
||||
attempt := 0
|
||||
for {
|
||||
if conn.acquireGlobalLock() {
|
||||
conn.logger.Info().Msg("Global database lock acquired")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release the global lock
|
||||
func (conn *SafeConn) Resume() {
|
||||
conn.releaseGlobalLock()
|
||||
conn.logger.Info().Msg("Global database lock released")
|
||||
}
|
||||
|
||||
// Close the database connection
|
||||
func (conn *SafeConn) Close() error {
|
||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||
conn.acquireGlobalLock()
|
||||
defer conn.releaseGlobalLock()
|
||||
conn.logger.Debug().Msg("Closing database connection")
|
||||
return conn.db.Close()
|
||||
}
|
||||
143
db/safeconntx_test.go
Normal file
143
db/safeconntx_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/tests"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeConn(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
sconn.Resume()
|
||||
})
|
||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
sconn.Pause(250 * time.Millisecond)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
})
|
||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
sconn.Resume()
|
||||
tx, err = sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
})
|
||||
}
|
||||
func TestSafeTX(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Commit releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Commit()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Rollback()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
|
||||
tx1, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx2, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx3, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx1.Commit()
|
||||
tx2.Commit()
|
||||
tx3.Commit()
|
||||
})
|
||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
defer sconn.releaseGlobalLock()
|
||||
_, err := sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
tx, err := sconn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
wg.Done()
|
||||
}()
|
||||
sconn.releaseGlobalLock()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
61
db/safetx.go
Normal file
61
db/safetx.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
// Query the database inside the transaction
|
||||
func (stx *SafeTX) Query(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Rows, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
return stx.tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Exec a statement on the database inside the transaction
|
||||
func (stx *SafeTX) Exec(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (sql.Result, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot exec without a transaction")
|
||||
}
|
||||
return stx.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Commit the current transaction and release the read lock
|
||||
func (stx *SafeTX) Commit() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot commit without a transaction")
|
||||
}
|
||||
err := stx.tx.Commit()
|
||||
stx.tx = nil
|
||||
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Abort the current transaction, releasing the read lock
|
||||
func (stx *SafeTX) Rollback() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot rollback without a transaction")
|
||||
}
|
||||
err := stx.tx.Rollback()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
20
db/user.go
20
db/user.go
@@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -16,16 +16,16 @@ type User struct {
|
||||
}
|
||||
|
||||
// Uses bcrypt to set the users Password_hash from the given password
|
||||
func (user *User) SetPassword(conn *sql.DB, password string) error {
|
||||
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
user.Password_hash = string(hashedPassword)
|
||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||
_, err = conn.Exec(query, user.Password_hash, user.ID)
|
||||
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error {
|
||||
}
|
||||
|
||||
// Change the user's username
|
||||
func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error {
|
||||
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
|
||||
query := `UPDATE users SET username = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newUsername, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newUsername, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's bio
|
||||
func (user *User) ChangeBio(conn *sql.DB, newBio string) error {
|
||||
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
|
||||
query := `UPDATE users SET bio = ? WHERE id = ?`
|
||||
_, err := conn.Exec(query, newBio, user.ID)
|
||||
_, err := tx.Exec(ctx, query, newBio, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
@@ -8,17 +9,22 @@ import (
|
||||
)
|
||||
|
||||
// Creates a new user in the database and returns a pointer
|
||||
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
|
||||
func CreateNewUser(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
username string,
|
||||
password string,
|
||||
) (*User, error) {
|
||||
query := `INSERT INTO users (username) VALUES (?)`
|
||||
_, err := conn.Exec(query, username)
|
||||
_, err := tx.Exec(ctx, query, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
user, err := GetUserFromUsername(conn, username)
|
||||
user, err := GetUserFromUsername(ctx, tx, username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "GetUserFromUsername")
|
||||
}
|
||||
err = user.SetPassword(conn, password)
|
||||
err = user.SetPassword(ctx, tx, password)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "user.SetPassword")
|
||||
}
|
||||
@@ -26,7 +32,12 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
|
||||
}
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
|
||||
func fetchUserData(
|
||||
ctx context.Context,
|
||||
tx *SafeTX,
|
||||
column string,
|
||||
value interface{},
|
||||
) (*sql.Rows, error) {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT
|
||||
id,
|
||||
@@ -38,36 +49,36 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
|
||||
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
|
||||
column,
|
||||
)
|
||||
rows, err := conn.Query(query, value)
|
||||
rows, err := tx.Query(ctx, query, value)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Query")
|
||||
return nil, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Scan the next row into the provided user pointer. Calls rows.Next() and
|
||||
// assumes only row in the result. Providing a rows object with more than 1
|
||||
// row may result in undefined behaviour.
|
||||
// Calls rows.Next() and scans the row into the provided user pointer.
|
||||
// Will error if no row available
|
||||
func scanUserRow(user *User, rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Password_hash,
|
||||
&user.Created_at,
|
||||
&user.Bio,
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if !rows.Next() {
|
||||
return errors.New("User not found")
|
||||
}
|
||||
err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Password_hash,
|
||||
&user.Created_at,
|
||||
&user.Bio,
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Queries the database for a user matching the given username.
|
||||
// Query is case insensitive
|
||||
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "username", username)
|
||||
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "username", username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -81,8 +92,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
|
||||
}
|
||||
|
||||
// Queries the database for a user matching the given ID.
|
||||
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
|
||||
rows, err := fetchUserData(conn, "id", id)
|
||||
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
|
||||
rows, err := fetchUserData(ctx, tx, "id", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
}
|
||||
@@ -96,11 +107,11 @@ func GetUserFromID(conn *sql.DB, id int) (*User, error) {
|
||||
}
|
||||
|
||||
// Checks if the given username is unique. Returns true if not taken
|
||||
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
|
||||
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
|
||||
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
|
||||
rows, err := conn.Query(query, username)
|
||||
rows, err := tx.Query(ctx, query, username)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Query")
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
taken := rows.Next()
|
||||
|
||||
@@ -1,12 +1,58 @@
|
||||
projectreshoot.com {
|
||||
reverse_proxy localhost:3000 localhost:3001 localhost:3002 {
|
||||
health_uri /healthz
|
||||
fail_duration 30s
|
||||
}
|
||||
rate_limit {
|
||||
zone auth {
|
||||
match {
|
||||
method POST
|
||||
path /login /register
|
||||
}
|
||||
key {remote_host}
|
||||
events 4
|
||||
window 1m
|
||||
}
|
||||
zone client {
|
||||
key {remote_host}
|
||||
events 100
|
||||
window 1m
|
||||
}
|
||||
}
|
||||
reverse_proxy localhost:3000 localhost:3001 localhost:3002 {
|
||||
transport http {
|
||||
max_conns_per_host 10
|
||||
}
|
||||
health_uri /healthz
|
||||
fail_duration 30s
|
||||
}
|
||||
log {
|
||||
output file /var/log/caddy/access.log
|
||||
}
|
||||
}
|
||||
|
||||
staging.projectreshoot.com {
|
||||
reverse_proxy localhost:3005 localhost:3006 localhost:3007 {
|
||||
health_uri /healthz
|
||||
fail_duration 30s
|
||||
}
|
||||
rate_limit {
|
||||
zone auth {
|
||||
match {
|
||||
method POST
|
||||
path /login /register
|
||||
}
|
||||
key {remote_host}
|
||||
events 4
|
||||
window 1m
|
||||
}
|
||||
zone client {
|
||||
key {remote_host}
|
||||
events 100
|
||||
window 1m
|
||||
}
|
||||
}
|
||||
reverse_proxy localhost:3005 localhost:3006 localhost:3007 {
|
||||
transport http {
|
||||
max_conns_per_host 10
|
||||
}
|
||||
health_uri /healthz
|
||||
fail_duration 30s
|
||||
}
|
||||
log {
|
||||
output file /var/log/caddy/access-staging.log
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
112
deploy/db/backup.sh
Executable file
112
deploy/db/backup.sh
Executable file
@@ -0,0 +1,112 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
DATA_DIR="/home/deploy/data/$ENVR"
|
||||
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||
if [[ "$ENVR" == "production" ]]; then
|
||||
SERVICE_NAME="projectreshoot"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
else
|
||||
SERVICE_NAME="$ENVR.projectreshoot"
|
||||
declare -a PORTS=("3005" "3006" "3007")
|
||||
fi
|
||||
|
||||
# Send SIGUSR2 to release maintenance mode
|
||||
release_maintenance() {
|
||||
echo "Releasing maintenance mode..."
|
||||
for PORT in "${PORTS[@]}"; do
|
||||
sudo systemctl kill -s SIGUSR2 "$SERVICE_NAME@$PORT.service"
|
||||
done
|
||||
}
|
||||
|
||||
shopt -s nullglob
|
||||
DB_FILES=("$ACTIVE_DIR"/*.db)
|
||||
DB_COUNT=${#DB_FILES[@]}
|
||||
|
||||
if [[ $DB_COUNT -gt 1 ]]; then
|
||||
echo "Error: More than one .db file found in $ACTIVE_DIR. Manual intervention required."
|
||||
exit 1
|
||||
elif [[ $DB_COUNT -eq 0 ]]; then
|
||||
echo "Error: No .db file found in $ACTIVE_DIR."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract the filename without extension
|
||||
DB_FILE="${DB_FILES[0]}"
|
||||
DB_VER=$(basename "$DB_FILE" .db)
|
||||
|
||||
# Send SIGUSR1 to trigger maintenance mode only for active services
|
||||
declare -a ACTIVE_PORTS=()
|
||||
for PORT in "${PORTS[@]}"; do
|
||||
if systemctl is-active --quiet "$SERVICE_NAME@$PORT.service"; then
|
||||
sudo systemctl kill -s SIGUSR1 "$SERVICE_NAME@$PORT.service"
|
||||
ACTIVE_PORTS+=("$PORT")
|
||||
fi
|
||||
done
|
||||
trap release_maintenance EXIT
|
||||
|
||||
# Function to check logs for success or failure
|
||||
check_logs() {
|
||||
local port="$1"
|
||||
local service="$SERVICE_NAME@$port.service"
|
||||
|
||||
echo "Waiting for $service to enter maintenance mode..."
|
||||
|
||||
# Check the last few lines first in case the message already appeared
|
||||
if sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Global database lock acquired"; then
|
||||
echo "$service successfully entered maintenance mode."
|
||||
return 0
|
||||
elif sudo journalctl -u "$service" -n 20 --no-pager | grep -q "Timeout: Global database lock abandoned"; then
|
||||
echo "Error: $service failed to enter maintenance mode."
|
||||
return 1
|
||||
fi
|
||||
|
||||
# If not found, continuously watch logs until we get a success or failure message
|
||||
sudo journalctl -u "$service" -f --no-pager | while read -r line; do
|
||||
if echo "$line" | grep -q "Global database lock acquired"; then
|
||||
echo "$service successfully entered maintenance mode."
|
||||
pkill -P $$ journalctl # Kill journalctl process once we have success
|
||||
return 0
|
||||
elif echo "$line" | grep -q "Timeout: Global database lock abandoned"; then
|
||||
echo "Error: $service failed to enter maintenance mode."
|
||||
pkill -P $$ journalctl # Kill journalctl process on failure
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
# Check logs for each service
|
||||
for PORT in "${ACTIVE_PORTS[@]}"; do
|
||||
check_logs "$PORT"
|
||||
done
|
||||
|
||||
# Get current datetime in YYYY-MM-DD-HHMM format
|
||||
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||
|
||||
# Define source and destination paths
|
||||
SOURCE_DB="$DATA_DIR/$DB_VER.db"
|
||||
BACKUP_DB="$BACKUP_DIR/${DB_VER}-${TIMESTAMP}.db"
|
||||
|
||||
# Copy the database file
|
||||
if [[ -f "$SOURCE_DB" ]]; then
|
||||
cp "$SOURCE_DB" "$BACKUP_DB"
|
||||
echo "Backup created: $BACKUP_DB"
|
||||
else
|
||||
echo "Error: Source database file $SOURCE_DB not found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
77
deploy/db/migrate.sh
Executable file
77
deploy/db/migrate.sh
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$2" ]]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
TGT_VER="$2"
|
||||
re='^[0-9]+$'
|
||||
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||
echo "Error: version not a number" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ -z "$3" ]; then
|
||||
echo "Usage: $0 <environment> <version> <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
COMMIT_HASH=$3
|
||||
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||
BACKUP_OUTPUT=$(/bin/bash ${MIGRATION_BIN}/backup.sh "$ENVR" 2>&1)
|
||||
echo "$BACKUP_OUTPUT"
|
||||
if [[ $? -ne 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
BACKUP_FILE=$(echo "$BACKUP_OUTPUT" | grep -oP '(?<=Backup created: ).*')
|
||||
if [[ -z "$BACKUP_FILE" ]]; then
|
||||
echo "Error: backup failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
FILE_NAME=${BACKUP_FILE##*/}
|
||||
CUR_VER=${FILE_NAME%%-*}
|
||||
if [[ $((+$TGT_VER)) == $((+$CUR_VER)) ]]; then
|
||||
echo "Version same, skipping migration"
|
||||
exit 0
|
||||
fi
|
||||
if [[ $((+$TGT_VER)) > $((+$CUR_VER)) ]]; then
|
||||
CMD="up-to"
|
||||
fi
|
||||
if [[ $((+$TGT_VER)) < $((+$CUR_VER)) ]]; then
|
||||
CMD="down-to"
|
||||
fi
|
||||
TIMESTAMP=$(date +"%Y-%m-%d-%H%M")
|
||||
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
DATA_DIR="/home/deploy/data/$ENVR"
|
||||
BACKUP_DIR="/home/deploy/data/backups/$ENVR"
|
||||
UPDATED_BACKUP="$BACKUP_DIR/${TGT_VER}-${TIMESTAMP}.db"
|
||||
UPDATED_COPY="$DATA_DIR/${TGT_VER}.db"
|
||||
UPDATED_LINK="$ACTIVE_DIR/${TGT_VER}.db"
|
||||
|
||||
cp $BACKUP_FILE $UPDATED_BACKUP
|
||||
failed_cleanup() {
|
||||
rm $UPDATED_BACKUP
|
||||
}
|
||||
trap 'if [ $? -ne 0 ]; then failed_cleanup; fi' EXIT
|
||||
|
||||
echo "Migration in progress from $CUR_VER to $TGT_VER"
|
||||
${MIGRATION_BIN}/prmigrate-${ENVR}-${COMMIT_HASH} $UPDATED_BACKUP $CMD $TGT_VER
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Migration failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "Migration completed"
|
||||
|
||||
cp $UPDATED_BACKUP $UPDATED_COPY
|
||||
ln -s $UPDATED_COPY $UPDATED_LINK
|
||||
echo "Upgraded database linked and ready for deploy"
|
||||
exit 0
|
||||
27
deploy/db/migrationcleanup.sh
Executable file
27
deploy/db/migrationcleanup.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
if [[ -z "$1" ]]; then
|
||||
echo "Usage: $0 <environment> <version>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ENVR="$1"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$2" ]]; then
|
||||
echo "Usage: $0 <environment> <version>"
|
||||
exit 1
|
||||
fi
|
||||
TGT_VER="$2"
|
||||
re='^[0-9]+$'
|
||||
if ! [[ $TGT_VER =~ $re ]] ; then
|
||||
echo "Error: version not a number" >&2
|
||||
exit 1
|
||||
fi
|
||||
ACTIVE_DIR="/home/deploy/$ENVR"
|
||||
find "$ACTIVE_DIR" -type l -name "*.db" ! -name "${TGT_VER}.db" -exec rm -v {} +
|
||||
@@ -8,19 +8,37 @@ if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
COMMIT_HASH=$1
|
||||
RELEASES_DIR="/home/deploy/releases/production"
|
||||
DEPLOY_BIN="/home/deploy/production/projectreshoot"
|
||||
SERVICE_NAME="projectreshoot"
|
||||
BINARY_NAME="projectreshoot-production-${COMMIT_HASH}"
|
||||
ENVR="$2"
|
||||
if [[ "$ENVR" != "production" && "$ENVR" != "staging" ]]; then
|
||||
echo "Error: environment must be 'production' or 'staging'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RELEASES_DIR="/home/deploy/releases/$ENVR"
|
||||
DEPLOY_BIN="/home/deploy/$ENVR/projectreshoot"
|
||||
MIGRATION_BIN="/home/deploy/migration-bin"
|
||||
BINARY_NAME="projectreshoot-$ENVR-${COMMIT_HASH}"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
if [[ "$ENVR" == "production" ]]; then
|
||||
SERVICE_NAME="projectreshoot"
|
||||
declare -a PORTS=("3000" "3001" "3002")
|
||||
else
|
||||
SERVICE_NAME="$ENVR.projectreshoot"
|
||||
declare -a PORTS=("3005" "3006" "3007")
|
||||
fi
|
||||
|
||||
# Check if the binary exists
|
||||
if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then
|
||||
echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}"
|
||||
exit 1
|
||||
fi
|
||||
DB_VER=$(${RELEASES_DIR}/${BINARY_NAME} --dbver | grep -oP '(?<=Database version: ).*')
|
||||
${MIGRATION_BIN}/migrate.sh $ENVR $DB_VER $COMMIT_HASH
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "Migration failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Keep a reference to the previous binary from the symlink
|
||||
if [ -L "${DEPLOY_BIN}" ]; then
|
||||
@@ -92,3 +110,4 @@ for port in "${PORTS[@]}"; do
|
||||
done
|
||||
|
||||
echo "Deployment completed successfully."
|
||||
${MIGRATION_BIN}/migrationcleanup.sh $ENVR $DB_VER
|
||||
@@ -1,94 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Exit on error
|
||||
set -e
|
||||
|
||||
# Check if commit hash is passed as an argument
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <commit-hash>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
COMMIT_HASH=$1
|
||||
RELEASES_DIR="/home/deploy/releases/staging"
|
||||
DEPLOY_BIN="/home/deploy/staging/projectreshoot"
|
||||
SERVICE_NAME="staging.projectreshoot"
|
||||
BINARY_NAME="projectreshoot-staging-${COMMIT_HASH}"
|
||||
declare -a PORTS=("3005" "3006" "3007")
|
||||
|
||||
# Check if the binary exists
|
||||
if [ ! -f "${RELEASES_DIR}/${BINARY_NAME}" ]; then
|
||||
echo "Binary ${BINARY_NAME} not found in ${RELEASES_DIR}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Keep a reference to the previous binary from the symlink
|
||||
if [ -L "${DEPLOY_BIN}" ]; then
|
||||
PREVIOUS=$(readlink -f $DEPLOY_BIN)
|
||||
echo "Current binary is ${PREVIOUS}, saved for rollback."
|
||||
else
|
||||
echo "No symbolic link found, no previous binary to backup."
|
||||
PREVIOUS=""
|
||||
fi
|
||||
|
||||
rollback_deployment() {
|
||||
if [ -n "$PREVIOUS" ]; then
|
||||
echo "Rolling back to previous binary: ${PREVIOUS}"
|
||||
ln -sfn "${PREVIOUS}" "${DEPLOY_BIN}"
|
||||
else
|
||||
echo "No previous binary to roll back to."
|
||||
fi
|
||||
|
||||
# wait to restart the services
|
||||
sleep 10
|
||||
|
||||
# Restart all services with the previous binary
|
||||
for port in "${PORTS[@]}"; do
|
||||
SERVICE="${SERVICE_NAME}@${port}.service"
|
||||
echo "Restarting $SERVICE..."
|
||||
sudo systemctl restart $SERVICE
|
||||
done
|
||||
|
||||
echo "Rollback completed."
|
||||
}
|
||||
|
||||
# Copy the binary to the deployment directory
|
||||
echo "Promoting ${BINARY_NAME} to ${DEPLOY_BIN}..."
|
||||
ln -sf "${RELEASES_DIR}/${BINARY_NAME}" "${DEPLOY_BIN}"
|
||||
|
||||
WAIT_TIME=5
|
||||
restart_service() {
|
||||
local port=$1
|
||||
local SERVICE="${SERVICE_NAME}@${port}.service"
|
||||
echo "Restarting ${SERVICE}..."
|
||||
|
||||
# Restart the service
|
||||
if ! sudo systemctl restart "$SERVICE"; then
|
||||
echo "Error: Failed to restart ${SERVICE}. Rolling back deployment."
|
||||
|
||||
# Call the rollback function
|
||||
rollback_deployment
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Wait a few seconds to allow the service to fully start
|
||||
echo "Waiting for ${SERVICE} to fully start..."
|
||||
sleep $WAIT_TIME
|
||||
|
||||
# Check the status of the service
|
||||
if ! systemctl is-active --quiet "${SERVICE}"; then
|
||||
echo "Error: ${SERVICE} failed to start correctly. Rolling back deployment."
|
||||
|
||||
# Call the rollback function
|
||||
rollback_deployment
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${SERVICE}.service restarted successfully."
|
||||
}
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
restart_service $port
|
||||
done
|
||||
|
||||
echo "Deployment completed successfully."
|
||||
17
go.mod
17
go.mod
@@ -1,24 +1,37 @@
|
||||
module projectreshoot
|
||||
|
||||
go 1.23.5
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/a-h/templ v0.3.833
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.24.1
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
modernc.org/sqlite v1.35.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.61.13 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.8.2 // indirect
|
||||
)
|
||||
|
||||
58
go.sum
58
go.sum
@@ -3,15 +3,24 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
@@ -19,25 +28,68 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pressly/goose/v3 v3.24.1 h1:bZmxRco2uy5uu5Ng1MMVEfYsFlrMJI+e/VMXHQ3C4LY=
|
||||
github.com/pressly/goose/v3 v3.24.1/go.mod h1:rEWreU9uVtt0DHCyLzF9gRcWiiTF/V+528DV+4DORug=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
|
||||
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
|
||||
modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo=
|
||||
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
|
||||
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
|
||||
modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw=
|
||||
modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8=
|
||||
modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI=
|
||||
modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw=
|
||||
modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/cookies"
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// Renders the account page on the 'General' subpage
|
||||
func HandleAccountPage() http.Handler {
|
||||
func AccountPage() http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("subpage")
|
||||
@@ -29,7 +30,7 @@ func HandleAccountPage() http.Handler {
|
||||
}
|
||||
|
||||
// Handles a request to change the subpage for the Accou/accountnt page
|
||||
func HandleAccountSubpage() http.Handler {
|
||||
func AccountSubpage() http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
@@ -41,65 +42,95 @@ func HandleAccountSubpage() http.Handler {
|
||||
}
|
||||
|
||||
// Handles a request to change the users username
|
||||
func HandleChangeUsername(
|
||||
func ChangeUsername(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
newUsername := r.FormValue("username")
|
||||
|
||||
unique, err := db.CheckUsernameUnique(conn, newUsername)
|
||||
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !unique {
|
||||
tx.Rollback()
|
||||
account.ChangeUsername("Username is taken", newUsername).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.ChangeUsername(conn, newUsername)
|
||||
err = user.ChangeUsername(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating username")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Handles a request to change the users bio
|
||||
func HandleChangeBio(
|
||||
func ChangeBio(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Error updating bio")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
newBio := r.FormValue("bio")
|
||||
leng := len([]rune(newBio))
|
||||
if leng > 128 {
|
||||
tx.Rollback()
|
||||
account.ChangeBio("Bio limited to 128 characters", newBio).
|
||||
Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err := user.ChangeBio(conn, newBio)
|
||||
err = user.ChangeBio(ctx, tx, newBio)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating bio")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
}
|
||||
func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
|
||||
func validateChangePassword(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (string, error) {
|
||||
r.ParseForm()
|
||||
formPassword := r.FormValue("password")
|
||||
formConfirmPassword := r.FormValue("confirm-password")
|
||||
@@ -113,24 +144,37 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
|
||||
}
|
||||
|
||||
// Handles a request to change the users password
|
||||
func HandleChangePassword(
|
||||
func ChangePassword(
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
newPass, err := validateChangePassword(conn, r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Error updating password")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
newPass, err := validateChangePassword(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
account.ChangePassword(err.Error()).Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
user := contexts.GetUser(r.Context())
|
||||
err = user.SetPassword(conn, newPass)
|
||||
err = user.SetPassword(ctx, tx, newPass)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error updating password")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.Header().Set("HX-Refresh", "true")
|
||||
},
|
||||
)
|
||||
24
handler/errorpage.go
Normal file
24
handler/errorpage.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"projectreshoot/view/page"
|
||||
)
|
||||
|
||||
func ErrorPage(
|
||||
errorCode int,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) {
|
||||
message := map[int]string{
|
||||
401: "You need to login to view this page.",
|
||||
403: "You do not have permission to view this page.",
|
||||
404: "The page or resource you have requested does not exist.",
|
||||
500: `An error occured on the server. Please try again, and if this
|
||||
continues to happen contact an administrator.`,
|
||||
503: "The server is currently down for maintenance and should be back soon. =)",
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
|
||||
Render(r.Context(), w)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -8,15 +8,11 @@ import (
|
||||
|
||||
// Handles responses to the / path. Also serves a 404 Page for paths that
|
||||
// don't have explicit handlers
|
||||
func HandleRoot() http.Handler {
|
||||
func Root() http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
page.Error(
|
||||
"404",
|
||||
"Page not found",
|
||||
"The page or resource you have requested does not exist",
|
||||
).Render(r.Context(), w)
|
||||
ErrorPage(http.StatusNotFound, w, r)
|
||||
return
|
||||
}
|
||||
page.Index().Render(r.Context(), w)
|
||||
@@ -1,8 +1,9 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/cookies"
|
||||
@@ -16,10 +17,14 @@ import (
|
||||
|
||||
// Validates the username matches a user in the database and the password
|
||||
// is correct. Returns the corresponding user
|
||||
func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
func validateLogin(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*db.User, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
user, err := db.GetUserFromUsername(conn, formUsername)
|
||||
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
||||
}
|
||||
@@ -44,16 +49,27 @@ func checkRememberMe(r *http.Request) bool {
|
||||
// Handles an attempted login request. On success will return a HTMX redirect
|
||||
// and on fail will return the login form again, passing the error to the
|
||||
// template for user feedback
|
||||
func HandleLoginRequest(
|
||||
func LoginRequest(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateLogin(conn, r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
user, err := validateLogin(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username or password incorrect" {
|
||||
logger.Warn().Caller().Err(err).Msg("Login request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -66,10 +82,13 @@ func HandleLoginRequest(
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
},
|
||||
@@ -78,7 +97,7 @@ func HandleLoginRequest(
|
||||
|
||||
// Handles a request to view the login page. Will attempt to set "pagefrom"
|
||||
// cookie so a successful login can redirect the user to the page they came
|
||||
func HandleLoginPage(trustedHost string) http.Handler {
|
||||
func LoginPage(trustedHost string) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies.SetPageFrom(w, r, trustedHost)
|
||||
113
handler/logout.go
Normal file
113
handler/logout.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func revokeAccess(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
atStr string,
|
||||
) error {
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "Token is expired") ||
|
||||
strings.Contains(err.Error(), "Token has been revoked") {
|
||||
return nil // Token is expired, dont need to revoke it
|
||||
}
|
||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
err = jwt.RevokeToken(ctx, tx, aT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func revokeRefresh(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
rtStr string,
|
||||
) error {
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "Token is expired") ||
|
||||
strings.Contains(err.Error(), "Token has been revoked") {
|
||||
return nil // Token is expired, dont need to revoke it
|
||||
}
|
||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
err = jwt.RevokeToken(ctx, tx, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve and revoke the user's tokens
|
||||
func revokeTokens(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) error {
|
||||
// get the tokens from the cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
// revoke the refresh token first as the access token expires quicker
|
||||
// only matters if there is an error revoking the tokens
|
||||
err := revokeRefresh(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "revokeRefresh")
|
||||
}
|
||||
err = revokeAccess(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "revokeAccess")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle a logout request
|
||||
func Logout(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Error occured on user logout")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
err = revokeTokens(config, ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
w.Header().Set("HX-Redirect", "/login")
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -1,11 +1,11 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"projectreshoot/view/page"
|
||||
)
|
||||
|
||||
func HandleProfilePage() http.Handler {
|
||||
func ProfilePage() http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
page.Profile().Render(r.Context(), w)
|
||||
@@ -1,12 +1,14 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/jwt"
|
||||
"projectreshoot/view/component/form"
|
||||
|
||||
@@ -17,16 +19,17 @@ import (
|
||||
// Get the tokens from the request
|
||||
func getTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
|
||||
// get the existing tokens from the cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
@@ -35,15 +38,16 @@ func getTokens(
|
||||
|
||||
// Revoke the given token pair
|
||||
func revokeTokenPair(
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
aT *jwt.AccessToken,
|
||||
rT *jwt.RefreshToken,
|
||||
) error {
|
||||
err := jwt.RevokeToken(conn, aT)
|
||||
err := jwt.RevokeToken(ctx, tx, aT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
err = jwt.RevokeToken(conn, rT)
|
||||
err = jwt.RevokeToken(ctx, tx, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
@@ -53,11 +57,12 @@ func revokeTokenPair(
|
||||
// Issue new tokens for the user, invalidating the old ones
|
||||
func refreshTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) error {
|
||||
aT, rT, err := getTokens(config, conn, r)
|
||||
aT, rT, err := getTokens(config, ctx, tx, r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getTokens")
|
||||
}
|
||||
@@ -71,7 +76,7 @@ func refreshTokens(
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cookies.SetTokenCookies")
|
||||
}
|
||||
err = revokeTokenPair(conn, aT, rT)
|
||||
err = revokeTokenPair(ctx, tx, aT, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "revokeTokenPair")
|
||||
}
|
||||
@@ -94,25 +99,38 @@ func validatePassword(
|
||||
}
|
||||
|
||||
// Handle request to reauthenticate (i.e. make token fresh again)
|
||||
func HandleReauthenticate(
|
||||
func Reauthenticate(
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
err := validatePassword(r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to refresh user tokens")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
err = validatePassword(r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(445)
|
||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||
return
|
||||
}
|
||||
err = refreshTokens(config, conn, w, r)
|
||||
err = refreshTokens(config, ctx, tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logger.Error().Err(err).Msg("Failed to refresh user tokens")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
)
|
||||
@@ -1,8 +1,9 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/cookies"
|
||||
@@ -14,11 +15,15 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
func validateRegistration(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
r *http.Request,
|
||||
) (*db.User, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
formConfirmPassword := r.FormValue("confirm-password")
|
||||
unique, err := db.CheckUsernameUnique(conn, formUsername)
|
||||
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
|
||||
}
|
||||
@@ -31,7 +36,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
if len(formPassword) > 72 {
|
||||
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
||||
}
|
||||
user, err := db.CreateNewUser(conn, formUsername, formPassword)
|
||||
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.CreateNewUser")
|
||||
}
|
||||
@@ -39,16 +44,27 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func HandleRegisterRequest(
|
||||
func RegisterRequest(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
user, err := validateRegistration(conn, r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
user, err := validateRegistration(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username is taken" &&
|
||||
err.Error() != "Passwords do not match" &&
|
||||
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
||||
@@ -63,10 +79,12 @@ func HandleRegisterRequest(
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
pageFrom := cookies.CheckPageFrom(w, r)
|
||||
w.Header().Set("HX-Redirect", pageFrom)
|
||||
},
|
||||
@@ -75,7 +93,7 @@ func HandleRegisterRequest(
|
||||
|
||||
// Handles a request to view the login page. Will attempt to set "pagefrom"
|
||||
// cookie so a successful login can redirect the user to the page they came
|
||||
func HandleRegisterPage(trustedHost string) http.Handler {
|
||||
func RegisterPage(trustedHost string) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies.SetPageFrom(w, r, trustedHost)
|
||||
@@ -1,4 +1,4 @@
|
||||
package handlers
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -42,7 +42,7 @@ func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||
|
||||
// Handles requests for static files, without allowing access to the
|
||||
// directory viewer and returning 404 if an exact file is not found
|
||||
func HandleStatic(staticFS *http.FileSystem) http.Handler {
|
||||
func StaticFS(staticFS *http.FileSystem) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
nfs := justFilesFilesystem{*staticFS}
|
||||
37
handler/withtransaction.go
Normal file
37
handler/withtransaction.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func removeme(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
logger *zerolog.Logger,
|
||||
conn *db.SafeConn,
|
||||
handler func(
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
),
|
||||
onfail func(err error),
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
onfail(err)
|
||||
return
|
||||
}
|
||||
|
||||
handler(ctx, tx, w, r)
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Retrieve and revoke the user's tokens
|
||||
func revokeTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
r *http.Request,
|
||||
) error {
|
||||
// get the tokens from the cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.ParseAccessToken")
|
||||
}
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
// revoke the refresh token first as the access token expires quicker
|
||||
// only matters if there is an error revoking the tokens
|
||||
err = jwt.RevokeToken(conn, rT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
err = jwt.RevokeToken(conn, aT)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle a logout request
|
||||
func HandleLogout(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
err := revokeTokens(config, conn, r)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Error occured on user logout")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
w.Header().Set("HX-Redirect", "/login")
|
||||
},
|
||||
)
|
||||
}
|
||||
13
jwt/parse.go
13
jwt/parse.go
@@ -1,11 +1,12 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
@@ -17,7 +18,8 @@ import (
|
||||
// has the correct scope.
|
||||
func ParseAccessToken(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -74,7 +76,7 @@ func ParseAccessToken(
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(conn, token)
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
@@ -89,7 +91,8 @@ func ParseAccessToken(
|
||||
// has the correct scope.
|
||||
func ParseRefreshToken(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
@@ -141,7 +144,7 @@ func ParseRefreshToken(
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(conn, token)
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
|
||||
@@ -1,32 +1,33 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func RevokeToken(conn *sql.DB, t Token) error {
|
||||
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := conn.Exec(query, jti, exp)
|
||||
_, err := tx.Exec(ctx, query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "conn.Exec")
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
|
||||
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := conn.Query(query, jti)
|
||||
defer rows.Close()
|
||||
rows, err := tx.Query(ctx, query, jti)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Exec")
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -12,7 +12,7 @@ type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
GetUser(conn *sql.DB) (*db.User, error)
|
||||
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
@@ -38,15 +38,15 @@ type RefreshToken struct {
|
||||
Scope string // Should be "refresh"
|
||||
}
|
||||
|
||||
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(conn, a.SUB)
|
||||
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(ctx, tx, a.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(conn, r.SUB)
|
||||
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
|
||||
user, err := db.GetUserFromID(ctx, tx, r.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
|
||||
105
main.go
105
main.go
@@ -13,28 +13,32 @@ import (
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/logging"
|
||||
"projectreshoot/server"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
var embeddedStatic embed.FS
|
||||
|
||||
// Gets the static files
|
||||
func getStaticFiles() (http.FileSystem, error) {
|
||||
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
|
||||
if _, err := os.Stat("static"); err == nil {
|
||||
// Use actual filesystem in development
|
||||
fmt.Println("Using filesystem for static files")
|
||||
logger.Debug().Msg("Using filesystem for static files")
|
||||
return http.Dir("static"), nil
|
||||
} else {
|
||||
// Use embedded filesystem in production
|
||||
fmt.Println("Using embedded static files")
|
||||
logger.Debug().Msg("Using embedded static files")
|
||||
subFS, err := fs.Sub(embeddedStatic, "static")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fs.Sub")
|
||||
@@ -43,6 +47,44 @@ func getStaticFiles() (http.FileSystem, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var maint uint32 // atomic: 1 if in maintenance mode
|
||||
|
||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
||||
func handleMaintSignals(
|
||||
conn *db.SafeConn,
|
||||
srv *http.Server,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
) {
|
||||
logger.Debug().Msg("Starting signal listener")
|
||||
ch := make(chan os.Signal, 1)
|
||||
srv.RegisterOnShutdown(func() {
|
||||
logger.Debug().Msg("Shutting down signal listener")
|
||||
close(ch)
|
||||
})
|
||||
go func() {
|
||||
for sig := range ch {
|
||||
switch sig {
|
||||
case syscall.SIGUSR1:
|
||||
if atomic.LoadUint32(&maint) != 1 {
|
||||
atomic.StoreUint32(&maint, 1)
|
||||
logger.Info().Msg("Signal received: Starting maintenance")
|
||||
logger.Info().Msg("Attempting to acquire database lock")
|
||||
conn.Pause(config.DBLockTimeout * time.Second)
|
||||
}
|
||||
case syscall.SIGUSR2:
|
||||
if atomic.LoadUint32(&maint) != 0 {
|
||||
logger.Info().Msg("Signal received: Maintenance over")
|
||||
logger.Info().Msg("Releasing database lock")
|
||||
conn.Resume()
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
||||
}
|
||||
|
||||
// Initializes and runs the server
|
||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||
@@ -53,6 +95,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "server.GetConfig")
|
||||
}
|
||||
|
||||
// Return the version of the database required
|
||||
if args["dbver"] == "true" {
|
||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
||||
return nil
|
||||
}
|
||||
|
||||
var logfile *os.File = nil
|
||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
||||
logfile, err = logging.GetLogFile(config.LogDir)
|
||||
@@ -77,18 +125,36 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "logging.GetLogger")
|
||||
}
|
||||
|
||||
conn, err := db.ConnectToDatabase(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
logger.Debug().Msg("Config loaded and logger started")
|
||||
logger.Debug().Msg("Connecting to database")
|
||||
var conn *db.SafeConn
|
||||
if args["test"] == "true" {
|
||||
logger.Debug().Msg("Server in test mode, using test database")
|
||||
ver, err := strconv.ParseInt(config.DBName, 10, 0)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "strconv.ParseInt")
|
||||
}
|
||||
testconn, err := tests.SetupTestDB(ver)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tests.SetupTestDB")
|
||||
}
|
||||
conn = db.MakeSafe(testconn, logger)
|
||||
} else {
|
||||
conn, err = db.ConnectToDatabase(config.DBName, logger)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
staticFS, err := getStaticFiles()
|
||||
logger.Debug().Msg("Getting static files")
|
||||
staticFS, err := getStaticFiles(logger)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
srv := server.NewServer(config, logger, conn, &staticFS)
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
|
||||
httpServer := &http.Server{
|
||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||
Handler: srv,
|
||||
@@ -98,18 +164,25 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
}
|
||||
|
||||
// Runs function for testing in dev if --test flag true
|
||||
if args["test"] == "true" {
|
||||
if args["tester"] == "true" {
|
||||
logger.Debug().Msg("Running tester function")
|
||||
test(config, logger, conn, httpServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setups a channel to listen for os.Signal
|
||||
handleMaintSignals(conn, httpServer, logger, config)
|
||||
|
||||
// Runs the http server
|
||||
logger.Debug().Msg("Starting up the HTTP server")
|
||||
go func() {
|
||||
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
|
||||
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err)
|
||||
logger.Error().Err(err).Msg("Error listening and serving")
|
||||
}
|
||||
}()
|
||||
|
||||
// Handles graceful shutdown
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -119,11 +192,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err)
|
||||
logger.Error().Err(err).Msg("Error shutting down server")
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
fmt.Fprintln(w, "Shutting down")
|
||||
logger.Info().Msg("Shutting down")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,7 +206,9 @@ func main() {
|
||||
// Parse commandline args
|
||||
host := flag.String("host", "", "Override host to listen on")
|
||||
port := flag.String("port", "", "Override port to listen on")
|
||||
test := flag.Bool("test", false, "Run test function instead of main program")
|
||||
test := flag.Bool("test", false, "Run server in test mode")
|
||||
tester := flag.Bool("tester", false, "Run tester function instead of main program")
|
||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||
loglevel := flag.String("loglevel", "", "Set log level")
|
||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||
flag.Parse()
|
||||
@@ -143,6 +218,8 @@ func main() {
|
||||
"host": *host,
|
||||
"port": *port,
|
||||
"test": strconv.FormatBool(*test),
|
||||
"tester": strconv.FormatBool(*tester),
|
||||
"dbver": strconv.FormatBool(*dbver),
|
||||
"loglevel": *loglevel,
|
||||
"logoutput": *logoutput,
|
||||
}
|
||||
|
||||
89
main_test.go
89
main_test.go
@@ -1,26 +1,102 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_main(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
args := map[string]string{}
|
||||
go run(ctx, os.Stdout, args)
|
||||
args := map[string]string{"test": "true"}
|
||||
var stdout bytes.Buffer
|
||||
os.Setenv("SECRET_KEY", ".")
|
||||
os.Setenv("HOST", "127.0.0.1")
|
||||
os.Setenv("PORT", "3232")
|
||||
runSrvErr := make(chan error)
|
||||
go func() {
|
||||
if err := run(ctx, &stdout, args); err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for the server to become available
|
||||
waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz")
|
||||
go func() {
|
||||
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
|
||||
if err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
runSrvErr <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-runSrvErr:
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting test server: %s", err)
|
||||
return
|
||||
}
|
||||
t.Log("Test server started")
|
||||
}
|
||||
|
||||
// do tests
|
||||
fmt.Println("Tests starting")
|
||||
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock acquired"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock released"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR2)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func waitForReady(
|
||||
@@ -44,6 +120,7 @@ func waitForReady(
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("Error making request: %s\n", err.Error())
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/cookies"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/handler"
|
||||
"projectreshoot/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -18,14 +20,15 @@ import (
|
||||
// Attempt to use a valid refresh token to generate a new token pair
|
||||
func refreshAuthTokens(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
req *http.Request,
|
||||
ref *jwt.RefreshToken,
|
||||
) (*db.User, error) {
|
||||
user, err := ref.GetUser(conn)
|
||||
user, err := ref.GetUser(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "rT.GetUser")
|
||||
return nil, errors.Wrap(err, "ref.GetUser")
|
||||
}
|
||||
|
||||
rememberMe := map[string]bool{
|
||||
@@ -39,7 +42,7 @@ func refreshAuthTokens(
|
||||
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
|
||||
}
|
||||
// New tokens sent, revoke the used refresh token
|
||||
err = jwt.RevokeToken(conn, ref)
|
||||
err = jwt.RevokeToken(ctx, tx, ref)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.RevokeToken")
|
||||
}
|
||||
@@ -50,22 +53,23 @@ func refreshAuthTokens(
|
||||
// Check the cookies for token strings and attempt to authenticate them
|
||||
func getAuthenticatedUser(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
ctx context.Context,
|
||||
tx *db.SafeTX,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) (*contexts.AuthenticatedUser, error) {
|
||||
// Get token strings from cookies
|
||||
atStr, rtStr := cookies.GetTokenStrings(r)
|
||||
// Attempt to parse the access token
|
||||
aT, err := jwt.ParseAccessToken(config, conn, atStr)
|
||||
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
|
||||
if err != nil {
|
||||
// Access token invalid, attempt to parse refresh token
|
||||
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
|
||||
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
|
||||
}
|
||||
// Refresh token valid, attempt to get a new token pair
|
||||
user, err := refreshAuthTokens(config, conn, w, r, rT)
|
||||
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "refreshAuthTokens")
|
||||
}
|
||||
@@ -77,9 +81,9 @@ func getAuthenticatedUser(
|
||||
return &authUser, nil
|
||||
}
|
||||
// Access token valid
|
||||
user, err := aT.GetUser(conn)
|
||||
user, err := aT.GetUser(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "rT.GetUser")
|
||||
return nil, errors.Wrap(err, "aT.GetUser")
|
||||
}
|
||||
authUser := contexts.AuthenticatedUser{
|
||||
User: user,
|
||||
@@ -93,12 +97,34 @@ func getAuthenticatedUser(
|
||||
func Authentication(
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
next http.Handler,
|
||||
maint *uint32,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := getAuthenticatedUser(config, conn, w, r)
|
||||
if r.URL.Path == "/static/css/output.css" ||
|
||||
r.URL.Path == "/static/favicon.ico" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
if atomic.LoadUint32(maint) == 1 {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
// Failed to start transaction, skip auth
|
||||
logger.Warn().Err(err).
|
||||
Msg("Skipping Auth - unable to start a transaction")
|
||||
handler.ErrorPage(http.StatusServiceUnavailable, w, r)
|
||||
return
|
||||
}
|
||||
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
// User auth failed, delete the cookies to avoid repeat requests
|
||||
cookies.DeleteCookie(w, "access", "/")
|
||||
cookies.DeleteCookie(w, "refresh", "/")
|
||||
@@ -106,9 +132,12 @@ func Authentication(
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Err(err).
|
||||
Msg("Failed to authenticate user")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
ctx := contexts.SetUser(r.Context(), user)
|
||||
newReq := r.WithContext(ctx)
|
||||
tx.Commit()
|
||||
uctx := contexts.SetUser(r.Context(), user)
|
||||
newReq := r.WithContext(uctx)
|
||||
next.ServeHTTP(w, newReq)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -15,14 +17,15 @@ import (
|
||||
)
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
// Basic setup
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -36,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) {
|
||||
w.Write([]byte(strconv.Itoa(user.ID)))
|
||||
}
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
authHandler := Authentication(logger, cfg, conn, testHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/handler"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
@@ -23,9 +24,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
|
||||
// Middleware to add logs to console with details of the request
|
||||
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/static/css/output.css" ||
|
||||
r.URL.Path == "/static/favicon.ico" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
start, err := contexts.GetStartTime(r.Context())
|
||||
if err != nil {
|
||||
// Handle failure here. internal server error maybe
|
||||
handler.ErrorPage(http.StatusInternalServerError, w, r)
|
||||
return
|
||||
}
|
||||
wrapped := &wrappedWriter{
|
||||
@@ -38,7 +44,7 @@ func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
|
||||
Str("method", r.Method).
|
||||
Str("resource", r.URL.Path).
|
||||
Dur("time_elapsed", time.Since(start)).
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
|
||||
Msg("Served")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,20 +3,15 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/view/page"
|
||||
"projectreshoot/handler"
|
||||
)
|
||||
|
||||
// Checks if the user is set in the context and shows 401 page if not logged in
|
||||
func RequiresLogin(next http.Handler) http.Handler {
|
||||
func LoginReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := contexts.GetUser(r.Context())
|
||||
if user == nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
page.Error(
|
||||
"401",
|
||||
"Unauthorized",
|
||||
"Please login to view this page",
|
||||
).Render(r.Context(), w)
|
||||
handler.ErrorPage(http.StatusUnauthorized, w, r)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -25,7 +20,7 @@ func RequiresLogin(next http.Handler) http.Handler {
|
||||
|
||||
// Checks if the user is set in the context and redirects them to profile if
|
||||
// they are logged in
|
||||
func RequiresLogout(next http.Handler) http.Handler {
|
||||
func LogoutReq(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := contexts.GetUser(r.Context())
|
||||
if user != nil {
|
||||
|
||||
@@ -3,8 +3,11 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,23 +15,26 @@ import (
|
||||
)
|
||||
|
||||
func TestPageLoginRequired(t *testing.T) {
|
||||
// Basic setup
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
loginRequiredHandler := RequiresLogin(testHandler)
|
||||
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
|
||||
loginRequiredHandler := LoginReq(testHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func RequiresFresh(
|
||||
func FreshReq(
|
||||
next http.Handler,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -3,33 +3,39 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestActionReauthRequired(t *testing.T) {
|
||||
// Basic setup
|
||||
func TestReauthRequired(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
conn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := db.MakeSafe(conn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
var maint uint32
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
// Add the middleware and create the server
|
||||
reauthRequiredHandler := RequiresFresh(testHandler)
|
||||
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
|
||||
reauthRequiredHandler := FreshReq(testHandler)
|
||||
loginRequiredHandler := LoginReq(reauthRequiredHandler)
|
||||
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
|
||||
69
migrate/migrate.go
Normal file
69
migrate/migrate.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migrations
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 4 {
|
||||
fmt.Println("Usage: prmigrate <file_path> up-to|down-to <version>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
filePath := os.Args[1]
|
||||
direction := os.Args[2]
|
||||
versionStr := os.Args[3]
|
||||
|
||||
version, err := strconv.Atoi(versionStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid version number: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
log.Fatalf("Database file does not exist: %v", filePath)
|
||||
}
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create migration provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch direction {
|
||||
case "up-to":
|
||||
_, err = provider.UpTo(ctx, int64(version))
|
||||
case "down-to":
|
||||
_, err = provider.DownTo(ctx, int64(version))
|
||||
default:
|
||||
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Migration successful!")
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
PRAGMA foreign_keys=ON;
|
||||
BEGIN TRANSACTION;
|
||||
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
||||
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
||||
exp INTEGER NOT NULL
|
||||
@@ -11,9 +12,16 @@ CREATE TABLE IF NOT EXISTS "users" (
|
||||
created_at INTEGER DEFAULT (unixepoch()),
|
||||
bio TEXT DEFAULT ""
|
||||
) STRICT;
|
||||
CREATE TRIGGER cleanup_expired_tokens
|
||||
CREATE TRIGGER IF NOT EXISTS cleanup_expired_tokens
|
||||
AFTER INSERT ON jwtblacklist
|
||||
BEGIN
|
||||
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
|
||||
END;
|
||||
COMMIT;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||
DROP TABLE IF EXISTS jwtblacklist;
|
||||
DROP TABLE IF EXISTS users;
|
||||
-- +goose StatementEnd
|
||||
@@ -1,11 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/handlers"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/handler"
|
||||
"projectreshoot/middleware"
|
||||
"projectreshoot/view/page"
|
||||
|
||||
@@ -17,85 +17,47 @@ func addRoutes(
|
||||
mux *http.ServeMux,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
) {
|
||||
route := mux.Handle
|
||||
loggedIn := middleware.LoginReq
|
||||
loggedOut := middleware.LogoutReq
|
||||
fresh := middleware.FreshReq
|
||||
|
||||
// Health check
|
||||
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
|
||||
|
||||
// Static files
|
||||
mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic(staticFS)))
|
||||
route("GET /static/", http.StripPrefix("/static/", handler.StaticFS(staticFS)))
|
||||
|
||||
// Index page and unhandled catchall (404)
|
||||
mux.Handle("GET /", handlers.HandleRoot())
|
||||
route("GET /", handler.Root())
|
||||
|
||||
// Static content, unprotected pages
|
||||
mux.Handle("GET /about", handlers.HandlePage(page.About()))
|
||||
route("GET /about", handler.HandlePage(page.About()))
|
||||
|
||||
// Login page and handlers
|
||||
mux.Handle("GET /login",
|
||||
middleware.RequiresLogout(
|
||||
handlers.HandleLoginPage(config.TrustedHost),
|
||||
))
|
||||
mux.Handle("POST /login",
|
||||
middleware.RequiresLogout(
|
||||
handlers.HandleLoginRequest(
|
||||
config,
|
||||
logger,
|
||||
conn,
|
||||
)))
|
||||
route("GET /login", loggedOut(handler.LoginPage(config.TrustedHost)))
|
||||
route("POST /login", loggedOut(handler.LoginRequest(config, logger, conn)))
|
||||
|
||||
// Register page and handlers
|
||||
mux.Handle("GET /register",
|
||||
middleware.RequiresLogout(
|
||||
handlers.HandleRegisterPage(config.TrustedHost),
|
||||
))
|
||||
mux.Handle("POST /register",
|
||||
middleware.RequiresLogout(
|
||||
handlers.HandleRegisterRequest(
|
||||
config,
|
||||
logger,
|
||||
conn,
|
||||
)))
|
||||
route("GET /register", loggedOut(handler.RegisterPage(config.TrustedHost)))
|
||||
route("POST /register", loggedOut(handler.RegisterRequest(config, logger, conn)))
|
||||
|
||||
// Logout
|
||||
mux.Handle("POST /logout", handlers.HandleLogout(config, logger, conn))
|
||||
route("POST /logout", handler.Logout(config, logger, conn))
|
||||
|
||||
// Reauthentication request
|
||||
mux.Handle("POST /reauthenticate",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleReauthenticate(logger, config, conn),
|
||||
))
|
||||
route("POST /reauthenticate", loggedIn(handler.Reauthenticate(logger, config, conn)))
|
||||
|
||||
// Profile page
|
||||
mux.Handle("GET /profile",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleProfilePage(),
|
||||
))
|
||||
route("GET /profile", loggedIn(handler.ProfilePage()))
|
||||
|
||||
// Account page
|
||||
mux.Handle("GET /account",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleAccountPage(),
|
||||
))
|
||||
mux.Handle("POST /account-select-page",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleAccountSubpage(),
|
||||
))
|
||||
mux.Handle("POST /change-username",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangeUsername(logger, conn),
|
||||
),
|
||||
))
|
||||
mux.Handle("POST /change-bio",
|
||||
middleware.RequiresLogin(
|
||||
handlers.HandleChangeBio(logger, conn),
|
||||
))
|
||||
mux.Handle("POST /change-password",
|
||||
middleware.RequiresLogin(
|
||||
middleware.RequiresFresh(
|
||||
handlers.HandleChangePassword(logger, conn),
|
||||
),
|
||||
))
|
||||
route("GET /account", loggedIn(handler.AccountPage()))
|
||||
route("POST /account-select-page", loggedIn(handler.AccountSubpage()))
|
||||
route("POST /change-username", loggedIn(fresh(handler.ChangeUsername(logger, conn))))
|
||||
route("POST /change-bio", loggedIn(handler.ChangeBio(logger, conn)))
|
||||
route("POST /change-password", loggedIn(fresh(handler.ChangePassword(logger, conn))))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/middleware"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
@@ -14,8 +14,9 @@ import (
|
||||
func NewServer(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
staticFS *http.FileSystem,
|
||||
maint *uint32,
|
||||
) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
addRoutes(
|
||||
@@ -29,7 +30,7 @@ func NewServer(
|
||||
// Add middleware here, must be added in reverse order of execution
|
||||
// i.e. First in list will get executed last during the request handling
|
||||
handler = middleware.Logging(logger, handler)
|
||||
handler = middleware.Authentication(logger, config, conn, handler)
|
||||
handler = middleware.Authentication(logger, config, conn, handler, maint)
|
||||
|
||||
// Gzip
|
||||
handler = middleware.Gzip(handler, config.GZIP)
|
||||
|
||||
14
setup-hooks.sh
Normal file
14
setup-hooks.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
#!/bin/sh
|
||||
HOOKS_DIR=".githooks"
|
||||
GIT_HOOKS_DIR=".git/hooks"
|
||||
|
||||
mkdir -p "$GIT_HOOKS_DIR"
|
||||
|
||||
for hook in "$HOOKS_DIR"/*; do
|
||||
hook_name=$(basename "$hook")
|
||||
cp "$hook" "$GIT_HOOKS_DIR/$hook_name"
|
||||
chmod +x "$GIT_HOOKS_DIR/$hook_name"
|
||||
done
|
||||
|
||||
echo "Git hooks installed!"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/config"
|
||||
"projectreshoot/db"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func test(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *sql.DB,
|
||||
conn *db.SafeConn,
|
||||
srv *http.Server,
|
||||
) {
|
||||
}
|
||||
|
||||
@@ -1,64 +1,83 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func findSQLFile(filename string) (string, error) {
|
||||
func findMigrations() (*fs.FS, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "migrate", "migrations"))
|
||||
return &migrationsdir, nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return nil, errors.New("Unable to locate migrations directory")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
func findTestData() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, filename)); err == nil {
|
||||
return filepath.Join(dir, filename), nil
|
||||
if _, err := os.Stat(filepath.Join(dir, "main.go")); err == nil {
|
||||
return filepath.Join(dir, "tests", "testdata.sql"), nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return "", errors.New(fmt.Sprintf("Unable to locate %s", filename))
|
||||
return "", errors.New("Unable to locate test data")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
// Make sure to call DeleteTestDB when finished to cleanup
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||
func SetupTestDB(version int64) (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
// Setup the test database
|
||||
schemaPath, err := findSQLFile("schema.sql")
|
||||
|
||||
migrations, err := findMigrations()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
return nil, errors.Wrap(err, "findMigrations")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, conn, *migrations)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "goose.NewProvider")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if _, err := provider.UpTo(ctx, version); err != nil {
|
||||
return nil, errors.Wrap(err, "provider.UpTo")
|
||||
}
|
||||
|
||||
sqlBytes, err := os.ReadFile(schemaPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
schemaSQL := string(sqlBytes)
|
||||
|
||||
_, err = conn.Exec(schemaSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
}
|
||||
// NOTE: ==================================================
|
||||
// Load the test data
|
||||
dataPath, err := findSQLFile("testdata.sql")
|
||||
dataPath, err := findTestData()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "findSchema")
|
||||
}
|
||||
sqlBytes, err = os.ReadFile(dataPath)
|
||||
sqlBytes, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
@@ -66,20 +85,7 @@ func SetupTestDB() (*sql.DB, error) {
|
||||
|
||||
_, err = conn.Exec(dataSQL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
return nil, errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Deletes the test database from disk
|
||||
func DeleteTestDB() error {
|
||||
fileName := ".projectreshoot-test-database.db"
|
||||
|
||||
// Attempt to remove the file
|
||||
err := os.Remove(fileName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "os.Remove")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ func NilLogger() *zerolog.Logger {
|
||||
|
||||
// Return a logger that makes use of the T.Log method to enable debugging tests
|
||||
func DebugLogger(t *testing.T) *zerolog.Logger {
|
||||
logger := zerolog.New(&TLogWriter{t: t})
|
||||
logger := zerolog.New(GetTLogWriter(t))
|
||||
return &logger
|
||||
}
|
||||
|
||||
func GetTLogWriter(t *testing.T) *TLogWriter {
|
||||
return &TLogWriter{t: t}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ templ LoginForm(loginError string) {
|
||||
<!-- Form Group -->
|
||||
<div>
|
||||
<label
|
||||
for="email"
|
||||
for="username"
|
||||
class="block text-sm mb-2"
|
||||
>Username</label>
|
||||
<div class="relative">
|
||||
|
||||
@@ -38,7 +38,7 @@ templ RegisterForm(registerError string) {
|
||||
>
|
||||
<div>
|
||||
<label
|
||||
for="email"
|
||||
for="username"
|
||||
class="block text-sm mb-2"
|
||||
>Username</label>
|
||||
<div class="relative">
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package popup
|
||||
|
||||
templ ErrorPopup() {
|
||||
templ Error500Popup() {
|
||||
<div
|
||||
x-cloak
|
||||
x-show="showError"
|
||||
x-show="showError500"
|
||||
class="absolute w-82 left-0 right-0 mt-20 mr-5 ml-auto"
|
||||
x-transition:enter="transform translate-x-[100%] opacity-0 duration-200"
|
||||
x-transition:enter-start="opacity-0 translate-x-[100%]"
|
||||
@@ -44,7 +44,7 @@ templ ErrorPopup() {
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="size-6 text-subtext0 hover:cursor-pointer"
|
||||
@click="showError=false"
|
||||
@click="showError500=false"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
63
view/component/popup/error503Popup.templ
Normal file
63
view/component/popup/error503Popup.templ
Normal file
@@ -0,0 +1,63 @@
|
||||
package popup
|
||||
|
||||
templ Error503Popup() {
|
||||
<div
|
||||
x-cloak
|
||||
x-show="showError503"
|
||||
class="absolute w-82 left-0 right-0 mt-20 mr-5 ml-auto"
|
||||
x-transition:enter="transform translate-x-[100%] opacity-0 duration-200"
|
||||
x-transition:enter-start="opacity-0 translate-x-[100%]"
|
||||
x-transition:enter-end="opacity-100 translate-x-0"
|
||||
x-transition:leave="opacity-0 duration-200"
|
||||
x-transition:leave-start="opacity-100 translate-x-0"
|
||||
x-transition:leave-end="opacity-0 translate-x-[100%]"
|
||||
>
|
||||
<div
|
||||
role="alert"
|
||||
class="rounded-sm bg-dark-red p-4"
|
||||
>
|
||||
<div class="flex justify-between">
|
||||
<div class="flex items-center gap-2 text-red w-fit">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
class="size-5"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M9.401 3.003c1.155-2 4.043-2 5.197 0l7.355
|
||||
12.748c1.154 2-.29 4.5-2.599 4.5H4.645c-2.309
|
||||
0-3.752-2.5-2.598-4.5L9.4 3.003zM12 8.25a.75.75
|
||||
0 01.75.75v3.75a.75.75 0 01-1.5 0V9a.75.75 0
|
||||
01.75-.75zm0 8.25a.75.75 0 100-1.5.75.75 0 000 1.5z"
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<strong class="block font-medium">Service Unavailable</strong>
|
||||
</div>
|
||||
<div class="flex">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="size-6 text-subtext0 hover:cursor-pointer"
|
||||
@click="showError503=false"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
></path>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<p class="mt-2 text-sm text-red">
|
||||
The service is currently available. It could be down for maintenance.
|
||||
Please try again later.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
@@ -41,11 +41,12 @@ templ Global() {
|
||||
<script src="https://unpkg.com/alpinejs" defer></script>
|
||||
<script>
|
||||
// uncomment this line to enable logging of htmx events
|
||||
// htmx.logAll();
|
||||
htmx.logAll();
|
||||
</script>
|
||||
<script>
|
||||
const bodyData = {
|
||||
showError: false,
|
||||
showError500: false,
|
||||
showError503: false,
|
||||
showConfirmPasswordModal: false,
|
||||
handleHtmxBeforeOnLoad(event) {
|
||||
const requestPath = event.detail.pathInfo.requestPath;
|
||||
@@ -65,8 +66,13 @@ templ Global() {
|
||||
|
||||
// internal server error
|
||||
if (errorCode.includes('Code 500')) {
|
||||
this.showError = true;
|
||||
setTimeout(() => this.showError = false, 6000);
|
||||
this.showError500 = true;
|
||||
setTimeout(() => this.showError500 = false, 6000);
|
||||
}
|
||||
// service not available error
|
||||
if (errorCode.includes('Code 503')) {
|
||||
this.showError503 = true;
|
||||
setTimeout(() => this.showError503 = false, 6000);
|
||||
}
|
||||
|
||||
// user is authorized but needs to refresh their login
|
||||
@@ -83,7 +89,8 @@ templ Global() {
|
||||
x-on:htmx:error="handleHtmxError($event)"
|
||||
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
|
||||
>
|
||||
@popup.ErrorPopup()
|
||||
@popup.Error500Popup()
|
||||
@popup.Error503Popup()
|
||||
@popup.ConfirmPasswordModal()
|
||||
<div
|
||||
id="main-content"
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package page
|
||||
|
||||
import "projectreshoot/view/layout"
|
||||
import "strconv"
|
||||
|
||||
// Page template for Error pages. Error code should be a HTTP status code as
|
||||
// a string, and err should be the corresponding response title.
|
||||
// Message is a custom error message displayed below the code and error.
|
||||
templ Error(code string, err string, message string) {
|
||||
templ Error(code int, err string, message string) {
|
||||
@layout.Global() {
|
||||
<div
|
||||
class="grid mt-24 left-0 right-0 top-0 bottom-0
|
||||
@@ -14,7 +15,7 @@ templ Error(code string, err string, message string) {
|
||||
<div class="text-center">
|
||||
<h1
|
||||
class="text-9xl text-text"
|
||||
>{ code }</h1>
|
||||
>{ strconv.Itoa(code) }</h1>
|
||||
<p
|
||||
class="text-2xl font-bold tracking-tight text-subtext1
|
||||
sm:text-4xl"
|
||||
|
||||
@@ -8,7 +8,6 @@ templ Index() {
|
||||
<div class="text-center mt-24">
|
||||
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
|
||||
<div>A better way to discover and rate films</div>
|
||||
<div>If you're seeing this text, you're my favourite :)</div>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user