Added test capability and tests for authentication middleware
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"projectreshoot/logging"
|
||||
|
||||
@@ -16,9 +17,9 @@ type Config struct {
|
||||
Port string // Port to listen on
|
||||
TrustedHost string // Domain/Hostname to accept as trusted
|
||||
SSL bool // Flag for SSL Mode
|
||||
ReadHeaderTimeout int // Timeout for reading request headers in seconds
|
||||
WriteTimeout int // Timeout for writing requests in seconds
|
||||
IdleTimeout int // Timeout for idle connections in seconds
|
||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||
TursoDBName string // DB Name for Turso DB/Branch
|
||||
TursoToken string // Bearer token for Turso DB/Branch
|
||||
SecretKey string // Secret key for signing tokens
|
||||
@@ -84,9 +85,9 @@ func GetConfig(args map[string]string) (*Config, error) {
|
||||
Port: port,
|
||||
TrustedHost: os.Getenv("TRUSTED_HOST"),
|
||||
SSL: GetEnvBool("SSL_MODE", false),
|
||||
ReadHeaderTimeout: GetEnvInt("READ_HEADER_TIMEOUT", 2),
|
||||
WriteTimeout: GetEnvInt("WRITE_TIMEOUT", 10),
|
||||
IdleTimeout: GetEnvInt("IDLE_TIMEOUT", 120),
|
||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
||||
TursoDBName: os.Getenv("TURSO_DB_NAME"),
|
||||
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
|
||||
SecretKey: os.Getenv("SECRET_KEY"),
|
||||
|
||||
@@ -21,12 +21,12 @@ func GetEnvDefault(key string, defaultValue string) string {
|
||||
func GetEnvDur(key string, defaultValue time.Duration) time.Duration {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
|
||||
intVal, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
return time.Duration(intVal)
|
||||
|
||||
|
||||
5
go.mod
5
go.mod
@@ -7,8 +7,10 @@ require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d
|
||||
golang.org/x/crypto v0.33.0
|
||||
)
|
||||
@@ -16,8 +18,11 @@ require (
|
||||
require (
|
||||
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
|
||||
github.com/coder/websocket v1.8.12 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
12
go.sum
12
go.sum
@@ -5,6 +5,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
@@ -21,11 +23,17 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
|
||||
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d h1:dOMI4+zEbDI37KGb0TI44GUAwxHF9cMsIoDTJ7UmgfU=
|
||||
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d/go.mod h1:l8xTsYB90uaVdMHXMCxKKLSgw5wLYBwBKKefNIUnm9s=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
@@ -37,3 +45,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -75,9 +75,12 @@ func ParseAccessToken(
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(conn, token)
|
||||
if err != nil || !valid {
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := conn.Query(query, jti)
|
||||
defer rows.Close()
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "conn.Exec")
|
||||
}
|
||||
|
||||
67
main_test.go
Normal file
67
main_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_main(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
args := map[string]string{}
|
||||
go run(ctx, os.Stdout, args)
|
||||
|
||||
// wait for the server to become available
|
||||
waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz")
|
||||
|
||||
// do tests
|
||||
fmt.Println("Tests starting")
|
||||
}
|
||||
|
||||
func waitForReady(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
endpoint string,
|
||||
) error {
|
||||
client := http.Client{}
|
||||
startTime := time.Now()
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
endpoint,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("Error making request: %s\n", err.Error())
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
fmt.Println("Endpoint is ready!")
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if time.Since(startTime) >= timeout {
|
||||
return fmt.Errorf("timeout reached while waiting for endpoint")
|
||||
}
|
||||
// wait a little while between checks
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
132
middleware/authentication_test.go
Normal file
132
middleware/authentication_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"projectreshoot/contexts"
|
||||
"projectreshoot/db"
|
||||
"projectreshoot/jwt"
|
||||
"projectreshoot/tests"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuthenticationMiddleware(t *testing.T) {
|
||||
// Basic setup
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
conn, err := tests.SetupTestDB()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer tests.DeleteTestDB()
|
||||
|
||||
// Handler to check outcome of Authentication middleware
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := contexts.GetUser(r.Context())
|
||||
if user == nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(strconv.Itoa(0)))
|
||||
return
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(strconv.Itoa(user.ID)))
|
||||
}
|
||||
})
|
||||
|
||||
// Add the middleware and create the server
|
||||
authHandler := Authentication(logger, cfg, conn, testHandler)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(authHandler)
|
||||
defer server.Close()
|
||||
|
||||
// Setup the user and tokens to test with
|
||||
user, err := db.GetUserFromID(conn, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Good tokens
|
||||
atStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
|
||||
require.NoError(t, err)
|
||||
rtStr, _, err := jwt.GenerateRefreshToken(cfg, &user, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a token and revoke it for testing
|
||||
expStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false)
|
||||
require.NoError(t, err)
|
||||
expT, err := jwt.ParseAccessToken(cfg, conn, expStr)
|
||||
require.NoError(t, err)
|
||||
err = jwt.RevokeToken(conn, expT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure it actually got revoked
|
||||
expT, err = jwt.ParseAccessToken(cfg, conn, expStr)
|
||||
require.Error(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "Valid Access Token",
|
||||
id: 1,
|
||||
accessToken: atStr,
|
||||
refreshToken: "",
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Valid Refresh Token (Triggers Refresh)",
|
||||
id: 1,
|
||||
accessToken: expStr,
|
||||
refreshToken: rtStr,
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Refresh token revoked (after refresh)",
|
||||
accessToken: expStr,
|
||||
refreshToken: rtStr,
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Invalid Tokens",
|
||||
accessToken: expStr,
|
||||
refreshToken: expStr,
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "No Tokens",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := &http.Client{}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
|
||||
|
||||
// Add cookies if provided
|
||||
if tt.accessToken != "" {
|
||||
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
|
||||
}
|
||||
if tt.refreshToken != "" {
|
||||
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedCode, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, strconv.Itoa(tt.id), string(body))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,9 @@ func addRoutes(
|
||||
config *config.Config,
|
||||
conn *sql.DB,
|
||||
) {
|
||||
// Health check
|
||||
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
|
||||
|
||||
// Static files
|
||||
mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic()))
|
||||
|
||||
|
||||
20
tests/config.go
Normal file
20
tests/config.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"projectreshoot/config"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestConfig() (*config.Config, error) {
|
||||
os.Setenv("TRUSTED_HOST", "127.0.0.1")
|
||||
os.Setenv("TURSO_DB_NAME", ".")
|
||||
os.Setenv("TURSO_AUTH_TOKEN", ".")
|
||||
os.Setenv("SECRET_KEY", ".")
|
||||
cfg, err := config.GetConfig(map[string]string{})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "config.GetConfig")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
52
tests/database.go
Normal file
52
tests/database.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// SetupTestDB initializes a test SQLite database with mock data
|
||||
// Make sure to call DeleteTestDB when finished to cleanup
|
||||
func SetupTestDB() (*sql.DB, error) {
|
||||
conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
// Create the test database
|
||||
_, err = conn.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL,
|
||||
password_hash TEXT,
|
||||
created_at INTEGER DEFAULT (unixepoch())
|
||||
);
|
||||
INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
||||
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
||||
exp INTEGER NOT NULL
|
||||
) STRICT;
|
||||
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "conn.Exec")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Deletes the test database from disk
|
||||
func DeleteTestDB() error {
|
||||
fileName := ".projectreshoot-test-database.db"
|
||||
|
||||
// Attempt to remove the file
|
||||
err := os.Remove(fileName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "os.Remove")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
29
tests/logger.go
Normal file
29
tests/logger.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type TLogWriter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface for TLogWriter.
|
||||
func (w *TLogWriter) Write(p []byte) (n int, err error) {
|
||||
w.t.Logf("%s", p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Return a fake logger to satisfy functions that expect one
|
||||
func NilLogger() *zerolog.Logger {
|
||||
logger := zerolog.New(nil)
|
||||
return &logger
|
||||
}
|
||||
|
||||
// Return a logger that makes use of the T.Log method to enable debugging tests
|
||||
func DebugLogger(t *testing.T) *zerolog.Logger {
|
||||
logger := zerolog.New(&TLogWriter{t: t})
|
||||
return &logger
|
||||
}
|
||||
Reference in New Issue
Block a user