Added test capability and tests for authentication middleware

This commit is contained in:
2025-02-12 21:23:13 +11:00
parent ca92d573ba
commit 2d52084fa7
12 changed files with 334 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"time"
"projectreshoot/logging" "projectreshoot/logging"
@@ -16,9 +17,9 @@ type Config struct {
Port string // Port to listen on Port string // Port to listen on
TrustedHost string // Domain/Hostname to accept as trusted TrustedHost string // Domain/Hostname to accept as trusted
SSL bool // Flag for SSL Mode SSL bool // Flag for SSL Mode
ReadHeaderTimeout int // Timeout for reading request headers in seconds ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
WriteTimeout int // Timeout for writing requests in seconds WriteTimeout time.Duration // Timeout for writing requests in seconds
IdleTimeout int // Timeout for idle connections in seconds IdleTimeout time.Duration // Timeout for idle connections in seconds
TursoDBName string // DB Name for Turso DB/Branch TursoDBName string // DB Name for Turso DB/Branch
TursoToken string // Bearer token for Turso DB/Branch TursoToken string // Bearer token for Turso DB/Branch
SecretKey string // Secret key for signing tokens SecretKey string // Secret key for signing tokens
@@ -84,9 +85,9 @@ func GetConfig(args map[string]string) (*Config, error) {
Port: port, Port: port,
TrustedHost: os.Getenv("TRUSTED_HOST"), TrustedHost: os.Getenv("TRUSTED_HOST"),
SSL: GetEnvBool("SSL_MODE", false), SSL: GetEnvBool("SSL_MODE", false),
ReadHeaderTimeout: GetEnvInt("READ_HEADER_TIMEOUT", 2), ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
WriteTimeout: GetEnvInt("WRITE_TIMEOUT", 10), WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
IdleTimeout: GetEnvInt("IDLE_TIMEOUT", 120), IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
TursoDBName: os.Getenv("TURSO_DB_NAME"), TursoDBName: os.Getenv("TURSO_DB_NAME"),
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
SecretKey: os.Getenv("SECRET_KEY"), SecretKey: os.Getenv("SECRET_KEY"),

View File

@@ -21,12 +21,12 @@ func GetEnvDefault(key string, defaultValue string) string {
func GetEnvDur(key string, defaultValue time.Duration) time.Duration { func GetEnvDur(key string, defaultValue time.Duration) time.Duration {
val, exists := os.LookupEnv(key) val, exists := os.LookupEnv(key)
if !exists { if !exists {
return defaultValue return time.Duration(defaultValue)
} }
intVal, err := strconv.Atoi(val) intVal, err := strconv.Atoi(val)
if err != nil { if err != nil {
return defaultValue return time.Duration(defaultValue)
} }
return time.Duration(intVal) return time.Duration(intVal)

5
go.mod
View File

@@ -7,8 +7,10 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.24
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.10.0
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
) )
@@ -16,8 +18,11 @@ require (
require ( require (
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/coder/websocket v1.8.12 // indirect github.com/coder/websocket v1.8.12 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

12
go.sum
View File

@@ -5,6 +5,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
@@ -21,11 +23,17 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d h1:dOMI4+zEbDI37KGb0TI44GUAwxHF9cMsIoDTJ7UmgfU= github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d h1:dOMI4+zEbDI37KGb0TI44GUAwxHF9cMsIoDTJ7UmgfU=
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d/go.mod h1:l8xTsYB90uaVdMHXMCxKKLSgw5wLYBwBKKefNIUnm9s= github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d/go.mod h1:l8xTsYB90uaVdMHXMCxKKLSgw5wLYBwBKKefNIUnm9s=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
@@ -37,3 +45,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -75,9 +75,12 @@ func ParseAccessToken(
} }
valid, err := CheckTokenNotRevoked(conn, token) valid, err := CheckTokenNotRevoked(conn, token)
if err != nil || !valid { if err != nil {
return nil, errors.Wrap(err, "CheckTokenNotRevoked") return nil, errors.Wrap(err, "CheckTokenNotRevoked")
} }
if !valid {
return nil, errors.New("Token has been revoked")
}
return token, nil return token, nil
} }

View File

@@ -23,6 +23,7 @@ func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
jti := t.GetJTI() jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := conn.Query(query, jti) rows, err := conn.Query(query, jti)
defer rows.Close()
if err != nil { if err != nil {
return false, errors.Wrap(err, "conn.Exec") return false, errors.Wrap(err, "conn.Exec")
} }

67
main_test.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import (
"context"
"fmt"
"net/http"
"os"
"testing"
"time"
)
func Test_main(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
args := map[string]string{}
go run(ctx, os.Stdout, args)
// wait for the server to become available
waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz")
// do tests
fmt.Println("Tests starting")
}
func waitForReady(
ctx context.Context,
timeout time.Duration,
endpoint string,
) error {
client := http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
endpoint,
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error making request: %s\n", err.Error())
continue
}
if resp.StatusCode == http.StatusOK {
fmt.Println("Endpoint is ready!")
resp.Body.Close()
return nil
}
resp.Body.Close()
select {
case <-ctx.Done():
return ctx.Err()
default:
if time.Since(startTime) >= timeout {
return fmt.Errorf("timeout reached while waiting for endpoint")
}
// wait a little while between checks
time.Sleep(250 * time.Millisecond)
}
}
}

View File

@@ -0,0 +1,132 @@
package middleware
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"projectreshoot/contexts"
"projectreshoot/db"
"projectreshoot/jwt"
"projectreshoot/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthenticationMiddleware(t *testing.T) {
// Basic setup
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
conn, err := tests.SetupTestDB()
require.NoError(t, err)
require.NotNil(t, conn)
defer tests.DeleteTestDB()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(strconv.Itoa(0)))
return
} else {
w.WriteHeader(http.StatusOK)
w.Write([]byte(strconv.Itoa(user.ID)))
}
})
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, conn, testHandler)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()
// Setup the user and tokens to test with
user, err := db.GetUserFromID(conn, 1)
require.NoError(t, err)
// Good tokens
atStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
require.NoError(t, err)
rtStr, _, err := jwt.GenerateRefreshToken(cfg, &user, false)
require.NoError(t, err)
// Create a token and revoke it for testing
expStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
require.NoError(t, err)
expT, err := jwt.ParseAccessToken(cfg, conn, expStr)
require.NoError(t, err)
err = jwt.RevokeToken(conn, expT)
require.NoError(t, err)
// Make sure it actually got revoked
expT, err = jwt.ParseAccessToken(cfg, conn, expStr)
require.Error(t, err)
tests := []struct {
name string
id int
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Valid Access Token",
id: 1,
accessToken: atStr,
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Valid Refresh Token (Triggers Refresh)",
id: 1,
accessToken: expStr,
refreshToken: rtStr,
expectedCode: http.StatusOK,
},
{
name: "Refresh token revoked (after refresh)",
accessToken: expStr,
refreshToken: rtStr,
expectedCode: http.StatusUnauthorized,
},
{
name: "Invalid Tokens",
accessToken: expStr,
refreshToken: expStr,
expectedCode: http.StatusUnauthorized,
},
{
name: "No Tokens",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, strconv.Itoa(tt.id), string(body))
})
}
}

View File

@@ -18,6 +18,9 @@ func addRoutes(
config *config.Config, config *config.Config,
conn *sql.DB, conn *sql.DB,
) { ) {
// Health check
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
// Static files // Static files
mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic())) mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic()))

20
tests/config.go Normal file
View File

@@ -0,0 +1,20 @@
package tests
import (
"os"
"projectreshoot/config"
"github.com/pkg/errors"
)
func TestConfig() (*config.Config, error) {
os.Setenv("TRUSTED_HOST", "127.0.0.1")
os.Setenv("TURSO_DB_NAME", ".")
os.Setenv("TURSO_AUTH_TOKEN", ".")
os.Setenv("SECRET_KEY", ".")
cfg, err := config.GetConfig(map[string]string{})
if err != nil {
return nil, errors.Wrap(err, "config.GetConfig")
}
return cfg, nil
}

52
tests/database.go Normal file
View File

@@ -0,0 +1,52 @@
package tests
import (
"database/sql"
"os"
"github.com/pkg/errors"
_ "github.com/mattn/go-sqlite3"
)
// SetupTestDB initializes a test SQLite database with mock data
// Make sure to call DeleteTestDB when finished to cleanup
func SetupTestDB() (*sql.DB, error) {
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
// Create the test database
_, err = conn.Exec(`
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
password_hash TEXT,
created_at INTEGER DEFAULT (unixepoch())
);
INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274);
CREATE TABLE IF NOT EXISTS jwtblacklist (
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
exp INTEGER NOT NULL
) STRICT;
`)
if err != nil {
return nil, errors.Wrap(err, "conn.Exec")
}
return conn, nil
}
// Deletes the test database from disk
func DeleteTestDB() error {
fileName := ".projectreshoot-test-database.db"
// Attempt to remove the file
err := os.Remove(fileName)
if err != nil {
return errors.Wrap(err, "os.Remove")
}
return nil
}

29
tests/logger.go Normal file
View File

@@ -0,0 +1,29 @@
package tests
import (
"testing"
"github.com/rs/zerolog"
)
type TLogWriter struct {
t *testing.T
}
// Write implements the io.Writer interface for TLogWriter.
func (w *TLogWriter) Write(p []byte) (n int, err error) {
w.t.Logf("%s", p)
return len(p), nil
}
// Return a fake logger to satisfy functions that expect one
func NilLogger() *zerolog.Logger {
logger := zerolog.New(nil)
return &logger
}
// Return a logger that makes use of the T.Log method to enable debugging tests
func DebugLogger(t *testing.T) *zerolog.Logger {
logger := zerolog.New(&TLogWriter{t: t})
return &logger
}