diff --git a/config/config.go b/config/config.go index e7a0a8e..4d38be7 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "time" "projectreshoot/logging" @@ -16,9 +17,9 @@ type Config struct { Port string // Port to listen on TrustedHost string // Domain/Hostname to accept as trusted SSL bool // Flag for SSL Mode - ReadHeaderTimeout int // Timeout for reading request headers in seconds - WriteTimeout int // Timeout for writing requests in seconds - IdleTimeout int // Timeout for idle connections in seconds + ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds + WriteTimeout time.Duration // Timeout for writing requests in seconds + IdleTimeout time.Duration // Timeout for idle connections in seconds TursoDBName string // DB Name for Turso DB/Branch TursoToken string // Bearer token for Turso DB/Branch SecretKey string // Secret key for signing tokens @@ -84,9 +85,9 @@ func GetConfig(args map[string]string) (*Config, error) { Port: port, TrustedHost: os.Getenv("TRUSTED_HOST"), SSL: GetEnvBool("SSL_MODE", false), - ReadHeaderTimeout: GetEnvInt("READ_HEADER_TIMEOUT", 2), - WriteTimeout: GetEnvInt("WRITE_TIMEOUT", 10), - IdleTimeout: GetEnvInt("IDLE_TIMEOUT", 120), + ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2), + WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10), + IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120), TursoDBName: os.Getenv("TURSO_DB_NAME"), TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), SecretKey: os.Getenv("SECRET_KEY"), diff --git a/config/environment.go b/config/environment.go index 2875b7b..2c4c815 100644 --- a/config/environment.go +++ b/config/environment.go @@ -21,12 +21,12 @@ func GetEnvDefault(key string, defaultValue string) string { func GetEnvDur(key string, defaultValue time.Duration) time.Duration { val, exists := os.LookupEnv(key) if !exists { - return defaultValue + return time.Duration(defaultValue) } intVal, err := strconv.Atoi(val) if err != nil { - return defaultValue + return time.Duration(defaultValue) } return time.Duration(intVal) diff --git a/go.mod b/go.mod index 380a9c2..11426db 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 + github.com/mattn/go-sqlite3 v1.14.24 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.33.0 + github.com/stretchr/testify v1.10.0 github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d golang.org/x/crypto v0.33.0 ) @@ -16,8 +18,11 @@ require ( require ( github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/coder/websocket v1.8.12 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/sys v0.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1ee3159..ffca76c 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9 github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= @@ -21,11 +23,17 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d h1:dOMI4+zEbDI37KGb0TI44GUAwxHF9cMsIoDTJ7UmgfU= github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d/go.mod h1:l8xTsYB90uaVdMHXMCxKKLSgw5wLYBwBKKefNIUnm9s= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= @@ -37,3 +45,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwt/parse.go b/jwt/parse.go index 8a49127..741cc59 100644 --- a/jwt/parse.go +++ b/jwt/parse.go @@ -75,9 +75,12 @@ func ParseAccessToken( } valid, err := CheckTokenNotRevoked(conn, token) - if err != nil || !valid { + if err != nil { return nil, errors.Wrap(err, "CheckTokenNotRevoked") } + if !valid { + return nil, errors.New("Token has been revoked") + } return token, nil } diff --git a/jwt/revoke.go b/jwt/revoke.go index 66465a7..ed2ec63 100644 --- a/jwt/revoke.go +++ b/jwt/revoke.go @@ -23,6 +23,7 @@ func CheckTokenNotRevoked(conn *sql.DB, t Token) (bool, error) { jti := t.GetJTI() query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1` rows, err := conn.Query(query, jti) + defer rows.Close() if err != nil { return false, errors.Wrap(err, "conn.Exec") } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..ccb1d48 --- /dev/null +++ b/main_test.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "testing" + "time" +) + +func Test_main(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + args := map[string]string{} + go run(ctx, os.Stdout, args) + + // wait for the server to become available + waitForReady(ctx, 10*time.Second, "http://localhost:3333/healthz") + + // do tests + fmt.Println("Tests starting") +} + +func waitForReady( + ctx context.Context, + timeout time.Duration, + endpoint string, +) error { + client := http.Client{} + startTime := time.Now() + for { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + endpoint, + nil, + ) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + fmt.Printf("Error making request: %s\n", err.Error()) + continue + } + if resp.StatusCode == http.StatusOK { + fmt.Println("Endpoint is ready!") + resp.Body.Close() + return nil + } + resp.Body.Close() + + select { + case <-ctx.Done(): + return ctx.Err() + default: + if time.Since(startTime) >= timeout { + return fmt.Errorf("timeout reached while waiting for endpoint") + } + // wait a little while between checks + time.Sleep(250 * time.Millisecond) + } + } +} diff --git a/middleware/authentication_test.go b/middleware/authentication_test.go new file mode 100644 index 0000000..c608127 --- /dev/null +++ b/middleware/authentication_test.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "projectreshoot/contexts" + "projectreshoot/db" + "projectreshoot/jwt" + "projectreshoot/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthenticationMiddleware(t *testing.T) { + // Basic setup + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + require.NotNil(t, conn) + defer tests.DeleteTestDB() + + // Handler to check outcome of Authentication middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := contexts.GetUser(r.Context()) + if user == nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(strconv.Itoa(0))) + return + } else { + w.WriteHeader(http.StatusOK) + w.Write([]byte(strconv.Itoa(user.ID))) + } + }) + + // Add the middleware and create the server + authHandler := Authentication(logger, cfg, conn, testHandler) + require.NoError(t, err) + server := httptest.NewServer(authHandler) + defer server.Close() + + // Setup the user and tokens to test with + user, err := db.GetUserFromID(conn, 1) + require.NoError(t, err) + + // Good tokens + atStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false) + require.NoError(t, err) + rtStr, _, err := jwt.GenerateRefreshToken(cfg, &user, false) + require.NoError(t, err) + + // Create a token and revoke it for testing + expStr, _, err := jwt.GenerateAccessToken(cfg, &user, false, false) + require.NoError(t, err) + expT, err := jwt.ParseAccessToken(cfg, conn, expStr) + require.NoError(t, err) + err = jwt.RevokeToken(conn, expT) + require.NoError(t, err) + + // Make sure it actually got revoked + expT, err = jwt.ParseAccessToken(cfg, conn, expStr) + require.Error(t, err) + + tests := []struct { + name string + id int + accessToken string + refreshToken string + expectedCode int + }{ + { + name: "Valid Access Token", + id: 1, + accessToken: atStr, + refreshToken: "", + expectedCode: http.StatusOK, + }, + { + name: "Valid Refresh Token (Triggers Refresh)", + id: 1, + accessToken: expStr, + refreshToken: rtStr, + expectedCode: http.StatusOK, + }, + { + name: "Refresh token revoked (after refresh)", + accessToken: expStr, + refreshToken: rtStr, + expectedCode: http.StatusUnauthorized, + }, + { + name: "Invalid Tokens", + accessToken: expStr, + refreshToken: expStr, + expectedCode: http.StatusUnauthorized, + }, + { + name: "No Tokens", + accessToken: "", + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + // Add cookies if provided + if tt.accessToken != "" { + req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken}) + } + if tt.refreshToken != "" { + req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken}) + } + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, strconv.Itoa(tt.id), string(body)) + }) + } +} diff --git a/server/routes.go b/server/routes.go index 7261b68..866b0e7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,6 +18,9 @@ func addRoutes( config *config.Config, conn *sql.DB, ) { + // Health check + mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {}) + // Static files mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic())) diff --git a/tests/config.go b/tests/config.go new file mode 100644 index 0000000..f436b4a --- /dev/null +++ b/tests/config.go @@ -0,0 +1,20 @@ +package tests + +import ( + "os" + "projectreshoot/config" + + "github.com/pkg/errors" +) + +func TestConfig() (*config.Config, error) { + os.Setenv("TRUSTED_HOST", "127.0.0.1") + os.Setenv("TURSO_DB_NAME", ".") + os.Setenv("TURSO_AUTH_TOKEN", ".") + os.Setenv("SECRET_KEY", ".") + cfg, err := config.GetConfig(map[string]string{}) + if err != nil { + return nil, errors.Wrap(err, "config.GetConfig") + } + return cfg, nil +} diff --git a/tests/database.go b/tests/database.go new file mode 100644 index 0000000..49b1b98 --- /dev/null +++ b/tests/database.go @@ -0,0 +1,52 @@ +package tests + +import ( + "database/sql" + "os" + + "github.com/pkg/errors" + + _ "github.com/mattn/go-sqlite3" +) + +// SetupTestDB initializes a test SQLite database with mock data +// Make sure to call DeleteTestDB when finished to cleanup +func SetupTestDB() (*sql.DB, error) { + conn, err := sql.Open("sqlite3", "file:.projectreshoot-test-database.db") + if err != nil { + return nil, errors.Wrap(err, "sql.Open") + } + // Create the test database + _, err = conn.Exec(` +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL, + password_hash TEXT, + created_at INTEGER DEFAULT (unixepoch()) +); +INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274); + +CREATE TABLE IF NOT EXISTS jwtblacklist ( + jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'), + exp INTEGER NOT NULL +) STRICT; + + `) + if err != nil { + return nil, errors.Wrap(err, "conn.Exec") + } + return conn, nil +} + +// Deletes the test database from disk +func DeleteTestDB() error { + fileName := ".projectreshoot-test-database.db" + + // Attempt to remove the file + err := os.Remove(fileName) + if err != nil { + return errors.Wrap(err, "os.Remove") + } + + return nil +} diff --git a/tests/logger.go b/tests/logger.go new file mode 100644 index 0000000..d8a0dd9 --- /dev/null +++ b/tests/logger.go @@ -0,0 +1,29 @@ +package tests + +import ( + "testing" + + "github.com/rs/zerolog" +) + +type TLogWriter struct { + t *testing.T +} + +// Write implements the io.Writer interface for TLogWriter. +func (w *TLogWriter) Write(p []byte) (n int, err error) { + w.t.Logf("%s", p) + return len(p), nil +} + +// Return a fake logger to satisfy functions that expect one +func NilLogger() *zerolog.Logger { + logger := zerolog.New(nil) + return &logger +} + +// Return a logger that makes use of the T.Log method to enable debugging tests +func DebugLogger(t *testing.T) *zerolog.Logger { + logger := zerolog.New(&TLogWriter{t: t}) + return &logger +}