Merge pull request #10 from Haelnorr/databaselocking

Databaselocking
This commit is contained in:
2025-02-18 23:21:24 +11:00
committed by GitHub
39 changed files with 1020 additions and 211 deletions

11
.githooks/pre-push Normal file
View File

@@ -0,0 +1,11 @@
#!/bin/sh
protected_branches=("master" "staging")
current_branch=$(git rev-parse --abbrev-ref HEAD)
for branch in "${protected_branches[@]}"; do
if [ "$current_branch" = "$branch" ]; then
echo "Direct pushes to '$branch' are not allowed. Use a pull request instead."
exit 1
fi
done
exit 0

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
query.sql
*.db
.logs/
server.log
tmp/
projectreshoot
static/css/output.css

View File

@@ -20,10 +20,11 @@ tester:
go run . --port 3232 --test --loglevel trace
test:
rm -f **/.projectreshoot-test-database.db && \
go mod tidy && \
templ generate && \
go generate && \
go test .
go test ./db
go test ./middleware
clean:

View File

@@ -22,6 +22,7 @@ type Config struct {
WriteTimeout time.Duration // Timeout for writing requests in seconds
IdleTimeout time.Duration // Timeout for idle connections in seconds
DBName string // Filename of the db (doesnt include file extension)
DBLockTimeout time.Duration // Timeout for acquiring database lock
SecretKey string // Secret key for signing tokens
AccessTokenExpiry int64 // Access token expiry in minutes
RefreshTokenExpiry int64 // Refresh token expiry in minutes
@@ -33,10 +34,7 @@ type Config struct {
// Load the application configuration and get a pointer to the Config object
func GetConfig(args map[string]string) (*Config, error) {
err := godotenv.Load(".env")
if err != nil {
fmt.Println(err)
}
godotenv.Load(".env")
var (
host string
port string
@@ -90,6 +88,7 @@ func GetConfig(args map[string]string) (*Config, error) {
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
DBName: GetEnvDefault("DB_NAME", "projectreshoot"),
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
SecretKey: os.Getenv("SECRET_KEY"),
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day

View File

@@ -1,21 +1,190 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
_ "github.com/mattn/go-sqlite3"
_ "modernc.org/sqlite"
)
// Returns a database connection handle for the Turso DB
func ConnectToDatabase(dbName string) (*sql.DB, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite3", file)
type SafeConn struct {
db *sql.DB
readLockCount uint32
globalLockStatus uint32
globalLockRequested uint32
logger *zerolog.Logger
}
func MakeSafe(db *sql.DB, logger *zerolog.Logger) *SafeConn {
return &SafeConn{db: db, logger: logger}
}
// Extends sql.Tx for use with SafeConn
type SafeTX struct {
tx *sql.Tx
sc *SafeConn
}
func (conn *SafeConn) acquireGlobalLock() bool {
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
return false
}
conn.globalLockStatus = 1
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock acquired")
return true
}
func (conn *SafeConn) releaseGlobalLock() {
conn.globalLockStatus = 0
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
Msg("Global lock released")
}
func (conn *SafeConn) acquireReadLock() bool {
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
return false
}
conn.readLockCount += 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock acquired")
return true
}
func (conn *SafeConn) releaseReadLock() {
conn.readLockCount -= 1
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
Msg("Read lock released")
}
// Starts a new transaction based on the current context. Will cancel if
// the context is closed/cancelled/done
func (conn *SafeConn) Begin(ctx context.Context) (*SafeTX, error) {
lockAcquired := make(chan struct{})
lockCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-lockCtx.Done():
return
default:
if conn.acquireReadLock() {
close(lockAcquired)
}
}
}()
select {
case <-lockAcquired:
tx, err := conn.db.BeginTx(ctx, nil)
if err != nil {
conn.releaseReadLock()
return nil, err
}
return &SafeTX{tx: tx, sc: conn}, nil
case <-ctx.Done():
cancel()
return nil, errors.New("Transaction time out due to database lock")
}
}
// Query the database inside the transaction
func (stx *SafeTX) Query(
ctx context.Context,
query string,
args ...interface{},
) (*sql.Rows, error) {
if stx.tx == nil {
return nil, errors.New("Cannot query without a transaction")
}
return stx.tx.QueryContext(ctx, query, args...)
}
// Exec a statement on the database inside the transaction
func (stx *SafeTX) Exec(
ctx context.Context,
query string,
args ...interface{},
) (sql.Result, error) {
if stx.tx == nil {
return nil, errors.New("Cannot exec without a transaction")
}
return stx.tx.ExecContext(ctx, query, args...)
}
// Commit the current transaction and release the read lock
func (stx *SafeTX) Commit() error {
if stx.tx == nil {
return errors.New("Cannot commit without a transaction")
}
err := stx.tx.Commit()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Abort the current transaction, releasing the read lock
func (stx *SafeTX) Rollback() error {
if stx.tx == nil {
return errors.New("Cannot rollback without a transaction")
}
err := stx.tx.Rollback()
stx.tx = nil
stx.sc.releaseReadLock()
return err
}
// Acquire a global lock, preventing all transactions
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
conn.logger.Info().Msg("Attempting to acquire global database lock")
conn.globalLockRequested = 1
defer func() { conn.globalLockRequested = 0 }()
timeout := time.After(timeoutAfter)
attempt := 0
for {
if conn.acquireGlobalLock() {
conn.logger.Info().Msg("Global database lock acquired")
return
}
select {
case <-timeout:
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
return
case <-time.After(100 * time.Millisecond):
attempt++
}
}
}
// Release the global lock
func (conn *SafeConn) Resume() {
conn.releaseGlobalLock()
conn.logger.Info().Msg("Global database lock released")
}
// Close the database connection
func (conn *SafeConn) Close() error {
conn.logger.Debug().Msg("Acquiring global lock for connection close")
conn.acquireGlobalLock()
defer conn.releaseGlobalLock()
conn.logger.Debug().Msg("Closing database connection")
return conn.db.Close()
}
// Returns a database connection handle for the DB
func ConnectToDatabase(dbName string, logger *zerolog.Logger) (*SafeConn, error) {
file := fmt.Sprintf("file:%s.db", dbName)
db, err := sql.Open("sqlite", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
return db, nil
conn := MakeSafe(db, logger)
return conn, nil
}

134
db/connection_test.go Normal file
View File

@@ -0,0 +1,134 @@
package db
import (
"context"
"projectreshoot/tests"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeConn(t *testing.T) {
logger := tests.NilLogger()
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
tx.Commit()
engaged.Wait()
assert.Equal(t, uint32(1), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
sconn.Resume()
})
t.Run("Lock abandons after timeout", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
sconn.Pause(250 * time.Millisecond)
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(0), sconn.globalLockRequested)
tx.Commit()
})
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
var requested sync.WaitGroup
var engaged sync.WaitGroup
requested.Add(1)
engaged.Add(1)
go func() {
requested.Done()
sconn.Pause(5 * time.Second)
engaged.Done()
}()
requested.Wait()
assert.Equal(t, uint32(0), sconn.globalLockStatus)
assert.Equal(t, uint32(1), sconn.globalLockRequested)
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
_, err = sconn.Begin(ctx)
require.Error(t, err)
tx.Commit()
engaged.Wait()
_, err = sconn.Begin(ctx)
require.Error(t, err)
sconn.Resume()
tx, err = sconn.Begin(t.Context())
require.NoError(t, err)
tx.Commit()
})
}
func TestSafeTX(t *testing.T) {
logger := tests.NilLogger()
conn, err := tests.SetupTestDB()
require.NoError(t, err)
sconn := MakeSafe(conn, logger)
defer sconn.Close()
t.Run("Commit releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Commit()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Rollback releases lock", func(t *testing.T) {
tx, err := sconn.Begin(t.Context())
require.NoError(t, err)
assert.Equal(t, uint32(1), sconn.readLockCount)
tx.Rollback()
assert.Equal(t, uint32(0), sconn.readLockCount)
})
t.Run("Multiple TX can gain read lock", func(t *testing.T) {
tx1, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx2, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx3, err := sconn.Begin(t.Context())
require.NoError(t, err)
tx1.Commit()
tx2.Commit()
tx3.Commit()
})
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
defer sconn.releaseGlobalLock()
_, err := sconn.Begin(ctx)
require.Error(t, err)
})
t.Run("Lock acquires if lock released", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
defer cancel()
sconn.acquireGlobalLock()
var wg sync.WaitGroup
wg.Add(1)
go func() {
tx, err := sconn.Begin(ctx)
require.NoError(t, err)
tx.Commit()
wg.Done()
}()
sconn.releaseGlobalLock()
wg.Wait()
})
}

View File

@@ -1,7 +1,7 @@
package db
import (
"database/sql"
"context"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
@@ -16,16 +16,16 @@ type User struct {
}
// Uses bcrypt to set the users Password_hash from the given password
func (user *User) SetPassword(conn *sql.DB, password string) error {
func (user *User) SetPassword(ctx context.Context, tx *SafeTX, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
}
user.Password_hash = string(hashedPassword)
query := `UPDATE users SET password_hash = ? WHERE id = ?`
_, err = conn.Exec(query, user.Password_hash, user.ID)
_, err = tx.Exec(ctx, query, user.Password_hash, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
@@ -40,21 +40,21 @@ func (user *User) CheckPassword(password string) error {
}
// Change the user's username
func (user *User) ChangeUsername(conn *sql.DB, newUsername string) error {
func (user *User) ChangeUsername(ctx context.Context, tx *SafeTX, newUsername string) error {
query := `UPDATE users SET username = ? WHERE id = ?`
_, err := conn.Exec(query, newUsername, user.ID)
_, err := tx.Exec(ctx, query, newUsername, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Change the user's bio
func (user *User) ChangeBio(conn *sql.DB, newBio string) error {
func (user *User) ChangeBio(ctx context.Context, tx *SafeTX, newBio string) error {
query := `UPDATE users SET bio = ? WHERE id = ?`
_, err := conn.Exec(query, newBio, user.ID)
_, err := tx.Exec(ctx, query, newBio, user.ID)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"fmt"
@@ -8,17 +9,22 @@ import (
)
// Creates a new user in the database and returns a pointer
func CreateNewUser(conn *sql.DB, username string, password string) (*User, error) {
func CreateNewUser(
ctx context.Context,
tx *SafeTX,
username string,
password string,
) (*User, error) {
query := `INSERT INTO users (username) VALUES (?)`
_, err := conn.Exec(query, username)
_, err := tx.Exec(ctx, query, username)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
return nil, errors.Wrap(err, "tx.Exec")
}
user, err := GetUserFromUsername(conn, username)
user, err := GetUserFromUsername(ctx, tx, username)
if err != nil {
return nil, errors.Wrap(err, "GetUserFromUsername")
}
err = user.SetPassword(conn, password)
err = user.SetPassword(ctx, tx, password)
if err != nil {
return nil, errors.Wrap(err, "user.SetPassword")
}
@@ -26,7 +32,12 @@ func CreateNewUser(conn *sql.DB, username string, password string) (*User, error
}
// Fetches data from the users table using "WHERE column = 'value'"
func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, error) {
func fetchUserData(
ctx context.Context,
tx *SafeTX,
column string,
value interface{},
) (*sql.Rows, error) {
query := fmt.Sprintf(
`SELECT
id,
@@ -38,9 +49,9 @@ func fetchUserData(conn *sql.DB, column string, value interface{}) (*sql.Rows, e
WHERE %s = ? COLLATE NOCASE LIMIT 1`,
column,
)
rows, err := conn.Query(query, value)
rows, err := tx.Query(ctx, query, value)
if err != nil {
return nil, errors.Wrap(err, "conn.Query")
return nil, errors.Wrap(err, "tx.Query")
}
return rows, nil
}
@@ -66,8 +77,8 @@ func scanUserRow(user *User, rows *sql.Rows) error {
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
rows, err := fetchUserData(conn, "username", username)
func GetUserFromUsername(ctx context.Context, tx *SafeTX, username string) (*User, error) {
rows, err := fetchUserData(ctx, tx, "username", username)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}
@@ -81,8 +92,8 @@ func GetUserFromUsername(conn *sql.DB, username string) (*User, error) {
}
// Queries the database for a user matching the given ID.
func GetUserFromID(conn *sql.DB, id int) (*User, error) {
rows, err := fetchUserData(conn, "id", id)
func GetUserFromID(ctx context.Context, tx *SafeTX, id int) (*User, error) {
rows, err := fetchUserData(ctx, tx, "id", id)
if err != nil {
return nil, errors.Wrap(err, "fetchUserData")
}
@@ -96,11 +107,11 @@ func GetUserFromID(conn *sql.DB, id int) (*User, error) {
}
// Checks if the given username is unique. Returns true if not taken
func CheckUsernameUnique(conn *sql.DB, username string) (bool, error) {
func CheckUsernameUnique(ctx context.Context, tx *SafeTX, username string) (bool, error) {
query := `SELECT 1 FROM users WHERE username = ? COLLATE NOCASE LIMIT 1`
rows, err := conn.Query(query, username)
rows, err := tx.Query(ctx, query, username)
if err != nil {
return false, errors.Wrap(err, "conn.Query")
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
taken := rows.Next()

11
go.mod
View File

@@ -7,18 +7,27 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.24
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.33.0
modernc.org/sqlite v1.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.61.13 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.8.2 // indirect
)

50
go.sum
View File

@@ -3,15 +3,24 @@ github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YY
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -19,12 +28,14 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
@@ -32,12 +43,45 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo=
golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0=
modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo=
modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw=
modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8=
modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI=
modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw=
modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -1,8 +1,9 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"time"
"projectreshoot/contexts"
"projectreshoot/cookies"
@@ -43,31 +44,44 @@ func HandleAccountSubpage() http.Handler {
// Handles a request to change the users username
func HandleChangeUsername(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r.ParseForm()
newUsername := r.FormValue("username")
unique, err := db.CheckUsernameUnique(conn, newUsername)
unique, err := db.CheckUsernameUnique(ctx, tx, newUsername)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusInternalServerError)
return
}
if !unique {
tx.Rollback()
account.ChangeUsername("Username is taken", newUsername).
Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err = user.ChangeUsername(conn, newUsername)
err = user.ChangeUsername(ctx, tx, newUsername)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating username")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)
@@ -76,30 +90,47 @@ func HandleChangeUsername(
// Handles a request to change the users bio
func HandleChangeBio(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Error updating bio")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r.ParseForm()
newBio := r.FormValue("bio")
leng := len([]rune(newBio))
if leng > 128 {
tx.Rollback()
account.ChangeBio("Bio limited to 128 characters", newBio).
Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err := user.ChangeBio(conn, newBio)
err = user.ChangeBio(ctx, tx, newBio)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating bio")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)
}
func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
func validateChangePassword(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (string, error) {
r.ParseForm()
formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password")
@@ -115,22 +146,35 @@ func validateChangePassword(conn *sql.DB, r *http.Request) (string, error) {
// Handles a request to change the users password
func HandleChangePassword(
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
newPass, err := validateChangePassword(conn, r)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Error updating password")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
newPass, err := validateChangePassword(ctx, tx, r)
if err != nil {
tx.Rollback()
account.ChangePassword(err.Error()).Render(r.Context(), w)
return
}
user := contexts.GetUser(r.Context())
err = user.SetPassword(conn, newPass)
err = user.SetPassword(ctx, tx, newPass)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error updating password")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.Header().Set("HX-Refresh", "true")
},
)

24
handlers/errorpage.go Normal file
View File

@@ -0,0 +1,24 @@
package handlers
import (
"net/http"
"projectreshoot/view/page"
)
func ErrorPage(
errorCode int,
w http.ResponseWriter,
r *http.Request,
) {
message := map[int]string{
401: "You need to login to view this page.",
403: "You do not have permission to view this page.",
404: "The page or resource you have requested does not exist.",
500: `An error occured on the server. Please try again, and if this
continues to happen contact an administrator.`,
503: "The server is currently down for maintenance and should be back soon. =)",
}
w.WriteHeader(http.StatusUnauthorized)
page.Error(errorCode, http.StatusText(errorCode), message[errorCode]).
Render(r.Context(), w)
}

View File

@@ -12,11 +12,7 @@ func HandleRoot() http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
page.Error(
"404",
"Page not found",
"The page or resource you have requested does not exist",
).Render(r.Context(), w)
ErrorPage(http.StatusNotFound, w, r)
return
}
page.Index().Render(r.Context(), w)

View File

@@ -1,8 +1,9 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"time"
"projectreshoot/config"
"projectreshoot/cookies"
@@ -16,10 +17,14 @@ import (
// Validates the username matches a user in the database and the password
// is correct. Returns the corresponding user
func validateLogin(conn *sql.DB, r *http.Request) (*db.User, error) {
func validateLogin(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
user, err := db.GetUserFromUsername(conn, formUsername)
user, err := db.GetUserFromUsername(ctx, tx, formUsername)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromUsername")
}
@@ -47,13 +52,24 @@ func checkRememberMe(r *http.Request) bool {
func HandleLoginRequest(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
user, err := validateLogin(conn, r)
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Failed to set token cookies")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r.ParseForm()
user, err := validateLogin(ctx, tx, r)
if err != nil {
tx.Rollback()
if err.Error() != "Username or password incorrect" {
logger.Warn().Caller().Err(err).Msg("Login request failed")
w.WriteHeader(http.StatusInternalServerError)
@@ -66,10 +82,13 @@ func HandleLoginRequest(
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
tx.Rollback()
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
return
}
tx.Commit()
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
},

View File

@@ -1,41 +1,80 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"strings"
"time"
"projectreshoot/config"
"projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/jwt"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
func revokeAccess(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
atStr string,
) error {
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil {
if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") {
return nil // Token is expired, dont need to revoke it
}
return errors.Wrap(err, "jwt.ParseAccessToken")
}
err = jwt.RevokeToken(ctx, tx, aT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
return nil
}
func revokeRefresh(
config *config.Config,
ctx context.Context,
tx *db.SafeTX,
rtStr string,
) error {
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil {
if strings.Contains(err.Error(), "Token is expired") ||
strings.Contains(err.Error(), "Token has been revoked") {
return nil // Token is expired, dont need to revoke it
}
return errors.Wrap(err, "jwt.ParseRefreshToken")
}
err = jwt.RevokeToken(ctx, tx, rT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
return nil
}
// Retrieve and revoke the user's tokens
func revokeTokens(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) error {
// get the tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r)
aT, err := jwt.ParseAccessToken(config, conn, atStr)
if err != nil {
return errors.Wrap(err, "jwt.ParseAccessToken")
}
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
if err != nil {
return errors.Wrap(err, "jwt.ParseRefreshToken")
}
// revoke the refresh token first as the access token expires quicker
// only matters if there is an error revoking the tokens
err = jwt.RevokeToken(conn, rT)
err := revokeRefresh(config, ctx, tx, rtStr)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
return errors.Wrap(err, "revokeRefresh")
}
err = jwt.RevokeToken(conn, aT)
err = revokeAccess(config, ctx, tx, atStr)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
return errors.Wrap(err, "revokeAccess")
}
return nil
}
@@ -44,16 +83,28 @@ func revokeTokens(
func HandleLogout(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
err := revokeTokens(config, conn, r)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Error occured on user logout")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
err = revokeTokens(config, ctx, tx, r)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Error occured on user logout")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
w.Header().Set("HX-Redirect", "/login")

View File

@@ -1,12 +1,14 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"time"
"projectreshoot/config"
"projectreshoot/contexts"
"projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/jwt"
"projectreshoot/view/component/form"
@@ -17,16 +19,17 @@ import (
// Get the tokens from the request
func getTokens(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (*jwt.AccessToken, *jwt.RefreshToken, error) {
// get the existing tokens from the cookies
atStr, rtStr := cookies.GetTokenStrings(r)
aT, err := jwt.ParseAccessToken(config, conn, atStr)
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseAccessToken")
}
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil {
return nil, nil, errors.Wrap(err, "jwt.ParseRefreshToken")
}
@@ -35,15 +38,16 @@ func getTokens(
// Revoke the given token pair
func revokeTokenPair(
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
aT *jwt.AccessToken,
rT *jwt.RefreshToken,
) error {
err := jwt.RevokeToken(conn, aT)
err := jwt.RevokeToken(ctx, tx, aT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
err = jwt.RevokeToken(conn, rT)
err = jwt.RevokeToken(ctx, tx, rT)
if err != nil {
return errors.Wrap(err, "jwt.RevokeToken")
}
@@ -53,11 +57,12 @@ func revokeTokenPair(
// Issue new tokens for the user, invalidating the old ones
func refreshTokens(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter,
r *http.Request,
) error {
aT, rT, err := getTokens(config, conn, r)
aT, rT, err := getTokens(config, ctx, tx, r)
if err != nil {
return errors.Wrap(err, "getTokens")
}
@@ -71,7 +76,7 @@ func refreshTokens(
if err != nil {
return errors.Wrap(err, "cookies.SetTokenCookies")
}
err = revokeTokenPair(conn, aT, rT)
err = revokeTokenPair(ctx, tx, aT, rT)
if err != nil {
return errors.Wrap(err, "revokeTokenPair")
}
@@ -97,22 +102,35 @@ func validatePassword(
func HandleReauthenticate(
logger *zerolog.Logger,
config *config.Config,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
err := validatePassword(r)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Failed to refresh user tokens")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
err = validatePassword(r)
if err != nil {
tx.Rollback()
w.WriteHeader(445)
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
return
}
err = refreshTokens(config, conn, w, r)
err = refreshTokens(config, ctx, tx, w, r)
if err != nil {
tx.Rollback()
logger.Error().Err(err).Msg("Failed to refresh user tokens")
w.WriteHeader(http.StatusInternalServerError)
return
}
tx.Commit()
w.WriteHeader(http.StatusOK)
},
)

View File

@@ -1,8 +1,9 @@
package handlers
import (
"database/sql"
"context"
"net/http"
"time"
"projectreshoot/config"
"projectreshoot/cookies"
@@ -14,11 +15,15 @@ import (
"github.com/rs/zerolog"
)
func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
func validateRegistration(
ctx context.Context,
tx *db.SafeTX,
r *http.Request,
) (*db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
formConfirmPassword := r.FormValue("confirm-password")
unique, err := db.CheckUsernameUnique(conn, formUsername)
unique, err := db.CheckUsernameUnique(ctx, tx, formUsername)
if err != nil {
return nil, errors.Wrap(err, "db.CheckUsernameUnique")
}
@@ -31,7 +36,7 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
if len(formPassword) > 72 {
return nil, errors.New("Password exceeds maximum length of 72 bytes")
}
user, err := db.CreateNewUser(conn, formUsername, formPassword)
user, err := db.CreateNewUser(ctx, tx, formUsername, formPassword)
if err != nil {
return nil, errors.Wrap(err, "db.CreateNewUser")
}
@@ -42,13 +47,24 @@ func validateRegistration(conn *sql.DB, r *http.Request) (*db.User, error) {
func HandleRegisterRequest(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
user, err := validateRegistration(conn, r)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
logger.Warn().Err(err).Msg("Failed to set token cookies")
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r.ParseForm()
user, err := validateRegistration(ctx, tx, r)
if err != nil {
tx.Rollback()
if err.Error() != "Username is taken" &&
err.Error() != "Passwords do not match" &&
err.Error() != "Password exceeds maximum length of 72 bytes" {
@@ -63,10 +79,12 @@ func HandleRegisterRequest(
rememberMe := checkRememberMe(r)
err = cookies.SetTokenCookies(w, r, config, user, true, rememberMe)
if err != nil {
tx.Rollback()
w.WriteHeader(http.StatusInternalServerError)
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
return
}
tx.Commit()
pageFrom := cookies.CheckPageFrom(w, r)
w.Header().Set("HX-Redirect", pageFrom)
},

View File

@@ -0,0 +1,37 @@
package handlers
import (
"context"
"net/http"
"time"
"projectreshoot/db"
"github.com/rs/zerolog"
)
func removeme(
w http.ResponseWriter,
r *http.Request,
logger *zerolog.Logger,
conn *db.SafeConn,
handler func(
ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter,
r *http.Request,
),
onfail func(err error),
) {
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
onfail(err)
return
}
handler(ctx, tx, w, r)
}

View File

@@ -1,11 +1,12 @@
package jwt
import (
"database/sql"
"context"
"fmt"
"time"
"projectreshoot/config"
"projectreshoot/db"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
@@ -17,7 +18,8 @@ import (
// has the correct scope.
func ParseAccessToken(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
@@ -74,7 +76,7 @@ func ParseAccessToken(
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}
@@ -89,7 +91,8 @@ func ParseAccessToken(
// has the correct scope.
func ParseRefreshToken(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
@@ -141,7 +144,7 @@ func ParseRefreshToken(
Scope: scope,
}
valid, err := CheckTokenNotRevoked(conn, token)
valid, err := CheckTokenNotRevoked(ctx, tx, token)
if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
}

View File

@@ -1,32 +1,33 @@
package jwt
import (
"database/sql"
"context"
"projectreshoot/db"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func RevokeToken(conn *sql.DB, t Token) error {
func RevokeToken(ctx context.Context, tx *db.SafeTX, t Token) error {
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err := conn.Exec(query, jti, exp)
_, err := tx.Exec(ctx, query, jti, exp)
if err != nil {
return errors.Wrap(err, "conn.Exec")
return errors.Wrap(err, "tx.Exec")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
func CheckTokenNotRevoked(ctx context.Context, tx *db.SafeTX, t Token) (bool, error) {
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti)
defer rows.Close()
rows, err := tx.Query(ctx, query, jti)
if err != nil {
return false, errors.Wrap(err, "conn.Exec")
return false, errors.Wrap(err, "tx.Query")
}
defer rows.Close()
revoked := rows.Next()
return !revoked, nil
}

View File

@@ -1,7 +1,7 @@
package jwt
import (
"database/sql"
"context"
"projectreshoot/db"
"github.com/google/uuid"
@@ -12,7 +12,7 @@ type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
GetUser(conn *sql.DB) (*db.User, error)
GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error)
}
// Access token
@@ -38,15 +38,15 @@ type RefreshToken struct {
Scope string // Should be "refresh"
}
func (a AccessToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, a.SUB)
func (a AccessToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, a.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}
return user, nil
}
func (r RefreshToken) GetUser(conn *sql.DB) (*db.User, error) {
user, err := db.GetUserFromID(conn, r.SUB)
func (r RefreshToken) GetUser(ctx context.Context, tx *db.SafeTX) (*db.User, error) {
user, err := db.GetUserFromID(ctx, tx, r.SUB)
if err != nil {
return nil, errors.Wrap(err, "db.GetUserFromID")
}

72
main.go
View File

@@ -13,6 +13,8 @@ import (
"os/signal"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
"projectreshoot/config"
@@ -21,20 +23,21 @@ import (
"projectreshoot/server"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
//go:embed static/*
var embeddedStatic embed.FS
// Gets the static files
func getStaticFiles() (http.FileSystem, error) {
func getStaticFiles(logger *zerolog.Logger) (http.FileSystem, error) {
if _, err := os.Stat("static"); err == nil {
// Use actual filesystem in development
fmt.Println("Using filesystem for static files")
logger.Debug().Msg("Using filesystem for static files")
return http.Dir("static"), nil
} else {
// Use embedded filesystem in production
fmt.Println("Using embedded static files")
logger.Debug().Msg("Using embedded static files")
subFS, err := fs.Sub(embeddedStatic, "static")
if err != nil {
return nil, errors.Wrap(err, "fs.Sub")
@@ -43,6 +46,44 @@ func getStaticFiles() (http.FileSystem, error) {
}
}
var maint uint32 // atomic: 1 if in maintenance mode
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
conn *db.SafeConn,
srv *http.Server,
logger *zerolog.Logger,
config *config.Config,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
logger.Debug().Msg("Shutting down signal listener")
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
logger.Info().Msg("Attempting to acquire database lock")
conn.Pause(config.DBLockTimeout * time.Second)
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
logger.Info().Msg("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
@@ -77,18 +118,22 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "logging.GetLogger")
}
conn, err := db.ConnectToDatabase(config.DBName)
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
conn, err := db.ConnectToDatabase(config.DBName, logger)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
defer conn.Close()
staticFS, err := getStaticFiles()
logger.Debug().Msg("Getting static files")
staticFS, err := getStaticFiles(logger)
if err != nil {
return errors.Wrap(err, "getStaticFiles")
}
srv := server.NewServer(config, logger, conn, &staticFS)
logger.Debug().Msg("Setting up HTTP server")
srv := server.NewServer(config, logger, conn, &staticFS, &maint)
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv,
@@ -99,17 +144,24 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Runs function for testing in dev if --test flag true
if args["test"] == "true" {
logger.Debug().Msg("Running tester function")
test(config, logger, conn, httpServer)
return nil
}
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, httpServer, logger, config)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")
go func() {
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "Error listening and serving: %s\n", err)
logger.Error().Err(err).Msg("Error listening and serving")
}
}()
// Handles graceful shutdown
var wg sync.WaitGroup
wg.Add(1)
go func() {
@@ -119,11 +171,11 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down http server: %s\n", err)
logger.Error().Err(err).Msg("Error shutting down server")
}
}()
wg.Wait()
fmt.Fprintln(w, "Shutting down")
logger.Info().Msg("Shutting down")
return nil
}

View File

@@ -1,12 +1,17 @@
package main
import (
"bytes"
"context"
"fmt"
"net/http"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_main(t *testing.T) {
@@ -14,13 +19,60 @@ func Test_main(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
args := map[string]string{}
go run(ctx, os.Stdout, args)
var stdout bytes.Buffer
go run(ctx, &stdout, args)
// wait for the server to become available
waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz")
// do tests
fmt.Println("Tests starting")
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock acquired"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR1)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock released"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR2)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
}
func waitForReady(

View File

@@ -1,14 +1,16 @@
package middleware
import (
"database/sql"
"context"
"net/http"
"sync/atomic"
"time"
"projectreshoot/config"
"projectreshoot/contexts"
"projectreshoot/cookies"
"projectreshoot/db"
"projectreshoot/handlers"
"projectreshoot/jwt"
"github.com/pkg/errors"
@@ -18,14 +20,15 @@ import (
// Attempt to use a valid refresh token to generate a new token pair
func refreshAuthTokens(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter,
req *http.Request,
ref *jwt.RefreshToken,
) (*db.User, error) {
user, err := ref.GetUser(conn)
user, err := ref.GetUser(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "rT.GetUser")
return nil, errors.Wrap(err, "ref.GetUser")
}
rememberMe := map[string]bool{
@@ -39,7 +42,7 @@ func refreshAuthTokens(
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
}
// New tokens sent, revoke the used refresh token
err = jwt.RevokeToken(conn, ref)
err = jwt.RevokeToken(ctx, tx, ref)
if err != nil {
return nil, errors.Wrap(err, "jwt.RevokeToken")
}
@@ -50,22 +53,23 @@ func refreshAuthTokens(
// Check the cookies for token strings and attempt to authenticate them
func getAuthenticatedUser(
config *config.Config,
conn *sql.DB,
ctx context.Context,
tx *db.SafeTX,
w http.ResponseWriter,
r *http.Request,
) (*contexts.AuthenticatedUser, error) {
// Get token strings from cookies
atStr, rtStr := cookies.GetTokenStrings(r)
// Attempt to parse the access token
aT, err := jwt.ParseAccessToken(config, conn, atStr)
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil {
// Access token invalid, attempt to parse refresh token
rT, err := jwt.ParseRefreshToken(config, conn, rtStr)
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil {
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
}
// Refresh token valid, attempt to get a new token pair
user, err := refreshAuthTokens(config, conn, w, r, rT)
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
if err != nil {
return nil, errors.Wrap(err, "refreshAuthTokens")
}
@@ -77,9 +81,9 @@ func getAuthenticatedUser(
return &authUser, nil
}
// Access token valid
user, err := aT.GetUser(conn)
user, err := aT.GetUser(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "rT.GetUser")
return nil, errors.Wrap(err, "aT.GetUser")
}
authUser := contexts.AuthenticatedUser{
User: user,
@@ -93,12 +97,34 @@ func getAuthenticatedUser(
func Authentication(
logger *zerolog.Logger,
config *config.Config,
conn *sql.DB,
conn *db.SafeConn,
next http.Handler,
maint *uint32,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, err := getAuthenticatedUser(config, conn, w, r)
if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" {
next.ServeHTTP(w, r)
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
if atomic.LoadUint32(maint) == 1 {
cancel()
}
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
// Failed to start transaction, skip auth
logger.Warn().Err(err).
Msg("Skipping Auth - unable to start a transaction")
handlers.ErrorPage(http.StatusServiceUnavailable, w, r)
return
}
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
if err != nil {
tx.Rollback()
// User auth failed, delete the cookies to avoid repeat requests
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
@@ -106,9 +132,12 @@ func Authentication(
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
next.ServeHTTP(w, r)
return
}
ctx := contexts.SetUser(r.Context(), user)
newReq := r.WithContext(ctx)
tx.Commit()
uctx := contexts.SetUser(r.Context(), user)
newReq := r.WithContext(uctx)
next.ServeHTTP(w, newReq)
})
}

View File

@@ -5,9 +5,11 @@ import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/contexts"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
@@ -15,14 +17,15 @@ import (
)
func TestAuthenticationMiddleware(t *testing.T) {
// Basic setup
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -36,9 +39,10 @@ func TestAuthenticationMiddleware(t *testing.T) {
w.Write([]byte(strconv.Itoa(user.ID)))
}
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, conn, testHandler)
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"projectreshoot/contexts"
"projectreshoot/handlers"
"time"
"github.com/rs/zerolog"
@@ -23,9 +24,14 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
// Middleware to add logs to console with details of the request
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" {
next.ServeHTTP(w, r)
return
}
start, err := contexts.GetStartTime(r.Context())
if err != nil {
// Handle failure here. internal server error maybe
handlers.ErrorPage(http.StatusInternalServerError, w, r)
return
}
wrapped := &wrappedWriter{

View File

@@ -3,7 +3,7 @@ package middleware
import (
"net/http"
"projectreshoot/contexts"
"projectreshoot/view/page"
"projectreshoot/handlers"
)
// Checks if the user is set in the context and shows 401 page if not logged in
@@ -11,12 +11,7 @@ func RequiresLogin(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
page.Error(
"401",
"Unauthorized",
"Please login to view this page",
).Render(r.Context(), w)
handlers.ErrorPage(http.StatusUnauthorized, w, r)
return
}
next.ServeHTTP(w, r)

View File

@@ -3,8 +3,10 @@ package middleware
import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
@@ -12,23 +14,26 @@ import (
)
func TestPageLoginRequired(t *testing.T) {
// Basic setup
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
loginRequiredHandler := RequiresLogin(testHandler)
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -3,33 +3,38 @@ package middleware
import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"projectreshoot/db"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestActionReauthRequired(t *testing.T) {
// Basic setup
cfg, err := tests.TestConfig()
require.NoError(t, err)
func TestReauthRequired(t *testing.T) {
logger := tests.NilLogger()
// Basic setup
conn, err := tests.SetupTestDB()
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
sconn := db.MakeSafe(conn, logger)
defer sconn.Close()
cfg, err := tests.TestConfig()
require.NoError(t, err)
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
reauthRequiredHandler := RequiresFresh(testHandler)
loginRequiredHandler := RequiresLogin(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, conn, loginRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()

View File

@@ -1,10 +1,10 @@
package server
import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/handlers"
"projectreshoot/middleware"
"projectreshoot/view/page"
@@ -17,7 +17,7 @@ func addRoutes(
mux *http.ServeMux,
logger *zerolog.Logger,
config *config.Config,
conn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem,
) {
// Health check

View File

@@ -1,10 +1,10 @@
package server
import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/db"
"projectreshoot/middleware"
"github.com/rs/zerolog"
@@ -14,8 +14,9 @@ import (
func NewServer(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
staticFS *http.FileSystem,
maint *uint32,
) http.Handler {
mux := http.NewServeMux()
addRoutes(
@@ -29,7 +30,7 @@ func NewServer(
// Add middleware here, must be added in reverse order of execution
// i.e. First in list will get executed last during the request handling
handler = middleware.Logging(logger, handler)
handler = middleware.Authentication(logger, config, conn, handler)
handler = middleware.Authentication(logger, config, conn, handler, maint)
// Gzip
handler = middleware.Gzip(handler, config.GZIP)

14
setup-hooks.sh Normal file
View File

@@ -0,0 +1,14 @@
#!/bin/sh
HOOKS_DIR=".githooks"
GIT_HOOKS_DIR=".git/hooks"
mkdir -p "$GIT_HOOKS_DIR"
for hook in "$HOOKS_DIR"/*; do
hook_name=$(basename "$hook")
cp "$hook" "$GIT_HOOKS_DIR/$hook_name"
chmod +x "$GIT_HOOKS_DIR/$hook_name"
done
echo "Git hooks installed!"

View File

@@ -1,10 +1,10 @@
package main
import (
"database/sql"
"net/http"
"projectreshoot/config"
"projectreshoot/db"
"github.com/rs/zerolog"
)
@@ -18,7 +18,7 @@ import (
func test(
config *config.Config,
logger *zerolog.Logger,
conn *sql.DB,
conn *db.SafeConn,
srv *http.Server,
) {
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/pkg/errors"
_ "github.com/mattn/go-sqlite3"
_ "modernc.org/sqlite"
)
func findSQLFile(filename string) (string, error) {
@@ -31,9 +31,8 @@ func findSQLFile(filename string) (string, error) {
}
// SetupTestDB initializes a test SQLite database with mock data
// Make sure to call DeleteTestDB when finished to cleanup
func SetupTestDB() (*sql.DB, error) {
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
conn, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
@@ -51,7 +50,7 @@ func SetupTestDB() (*sql.DB, error) {
_, err = conn.Exec(schemaSQL)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
return nil, errors.Wrap(err, "tx.Exec")
}
// Load the test data
dataPath, err := findSQLFile("testdata.sql")
@@ -66,20 +65,7 @@ func SetupTestDB() (*sql.DB, error) {
_, err = conn.Exec(dataSQL)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
return nil, errors.Wrap(err, "tx.Exec")
}
return conn, nil
}
// Deletes the test database from disk
func DeleteTestDB() error {
fileName := ".projectreshoot-test-database.db"
// Attempt to remove the file
err := os.Remove(fileName)
if err != nil {
return errors.Wrap(err, "os.Remove")
}
return nil
}

View File

@@ -1,9 +1,9 @@
package popup
templ ErrorPopup() {
templ Error500Popup() {
<div
x-cloak
x-show="showError"
x-show="showError500"
class="absolute w-82 left-0 right-0 mt-20 mr-5 ml-auto"
x-transition:enter="transform translate-x-[100%] opacity-0 duration-200"
x-transition:enter-start="opacity-0 translate-x-[100%]"
@@ -44,7 +44,7 @@ templ ErrorPopup() {
stroke-width="1.5"
stroke="currentColor"
class="size-6 text-subtext0 hover:cursor-pointer"
@click="showError=false"
@click="showError500=false"
>
<path
stroke-linecap="round"

View File

@@ -0,0 +1,63 @@
package popup
templ Error503Popup() {
<div
x-cloak
x-show="showError503"
class="absolute w-82 left-0 right-0 mt-20 mr-5 ml-auto"
x-transition:enter="transform translate-x-[100%] opacity-0 duration-200"
x-transition:enter-start="opacity-0 translate-x-[100%]"
x-transition:enter-end="opacity-100 translate-x-0"
x-transition:leave="opacity-0 duration-200"
x-transition:leave-start="opacity-100 translate-x-0"
x-transition:leave-end="opacity-0 translate-x-[100%]"
>
<div
role="alert"
class="rounded-sm bg-dark-red p-4"
>
<div class="flex justify-between">
<div class="flex items-center gap-2 text-red w-fit">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="size-5"
>
<path
fill-rule="evenodd"
d="M9.401 3.003c1.155-2 4.043-2 5.197 0l7.355
12.748c1.154 2-.29 4.5-2.599 4.5H4.645c-2.309
0-3.752-2.5-2.598-4.5L9.4 3.003zM12 8.25a.75.75
0 01.75.75v3.75a.75.75 0 01-1.5 0V9a.75.75 0
01.75-.75zm0 8.25a.75.75 0 100-1.5.75.75 0 000 1.5z"
clip-rule="evenodd"
></path>
</svg>
<strong class="block font-medium">Service Unavailable</strong>
</div>
<div class="flex">
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="size-6 text-subtext0 hover:cursor-pointer"
@click="showError503=false"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
></path>
</svg>
</div>
</div>
<p class="mt-2 text-sm text-red">
The service is currently available. It could be down for maintenance.
Please try again later.
</p>
</div>
</div>
}

View File

@@ -41,11 +41,12 @@ templ Global() {
<script src="https://unpkg.com/alpinejs" defer></script>
<script>
// uncomment this line to enable logging of htmx events
// htmx.logAll();
htmx.logAll();
</script>
<script>
const bodyData = {
showError: false,
showError500: false,
showError503: false,
showConfirmPasswordModal: false,
handleHtmxBeforeOnLoad(event) {
const requestPath = event.detail.pathInfo.requestPath;
@@ -65,8 +66,13 @@ templ Global() {
// internal server error
if (errorCode.includes('Code 500')) {
this.showError = true;
setTimeout(() => this.showError = false, 6000);
this.showError500 = true;
setTimeout(() => this.showError500 = false, 6000);
}
// service not available error
if (errorCode.includes('Code 503')) {
this.showError503 = true;
setTimeout(() => this.showError503 = false, 6000);
}
// user is authorized but needs to refresh their login
@@ -83,7 +89,8 @@ templ Global() {
x-on:htmx:error="handleHtmxError($event)"
x-on:htmx:before-on-load="handleHtmxBeforeOnLoad($event)"
>
@popup.ErrorPopup()
@popup.Error500Popup()
@popup.Error503Popup()
@popup.ConfirmPasswordModal()
<div
id="main-content"

View File

@@ -1,11 +1,12 @@
package page
import "projectreshoot/view/layout"
import "strconv"
// Page template for Error pages. Error code should be a HTTP status code as
// a string, and err should be the corresponding response title.
// Message is a custom error message displayed below the code and error.
templ Error(code string, err string, message string) {
templ Error(code int, err string, message string) {
@layout.Global() {
<div
class="grid mt-24 left-0 right-0 top-0 bottom-0
@@ -14,7 +15,7 @@ templ Error(code string, err string, message string) {
<div class="text-center">
<h1
class="text-9xl text-text"
>{ code }</h1>
>{ strconv.Itoa(code) }</h1>
<p
class="text-2xl font-bold tracking-tight text-subtext1
sm:text-4xl"

View File

@@ -8,7 +8,6 @@ templ Index() {
<div class="text-center mt-24">
<div class="text-4xl lg:text-6xl">Project Reshoot</div>
<div>A better way to discover and rate films</div>
<div>If you're seeing this text, you're my favourite :)</div>
</div>
}
}