diff --git a/Makefile b/Makefile index a2f26e2..68d12a7 100644 --- a/Makefile +++ b/Makefile @@ -21,8 +21,8 @@ tester: test: rm -f **/.projectreshoot-test-database.db go mod tidy && \ - go test . -v - go test ./middleware -v + go test . + go test ./middleware clean: go clean diff --git a/middleware/pageprotection.go b/middleware/pageprotection.go index 64ef4da..f5537b2 100644 --- a/middleware/pageprotection.go +++ b/middleware/pageprotection.go @@ -11,6 +11,7 @@ func RequiresLogin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := contexts.GetUser(r.Context()) if user == nil { + w.WriteHeader(http.StatusUnauthorized) page.Error( "401", "Unauthorized", diff --git a/middleware/pageprotection_test.go b/middleware/pageprotection_test.go new file mode 100644 index 0000000..de79975 --- /dev/null +++ b/middleware/pageprotection_test.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "projectreshoot/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPageLoginRequired(t *testing.T) { + // Basic setup + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + require.NotNil(t, conn) + defer tests.DeleteTestDB() + + // Handler to check outcome of Authentication middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Add the middleware and create the server + loginRequiredHandler := RequiresLogin(testHandler) + authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + server := httptest.NewServer(authHandler) + defer server.Close() + + tokens := getTokens() + + tests := []struct { + name string + accessToken string + refreshToken string + expectedCode int + }{ + { + name: "Valid Login", + accessToken: tokens["accessFresh"], + refreshToken: "", + expectedCode: http.StatusOK, + }, + { + name: "Expired login", + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusUnauthorized, + }, + { + name: "No login", + accessToken: "", + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + // Add cookies if provided + if tt.accessToken != "" { + req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken}) + } + if tt.refreshToken != "" { + req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken}) + } + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, resp.StatusCode) + }) + } +} diff --git a/middleware/reauthentication_test.go b/middleware/reauthentication_test.go new file mode 100644 index 0000000..63017cb --- /dev/null +++ b/middleware/reauthentication_test.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "projectreshoot/tests" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestActionReauthRequired(t *testing.T) { + // Basic setup + cfg, err := tests.TestConfig() + require.NoError(t, err) + logger := tests.NilLogger() + conn, err := tests.SetupTestDB() + require.NoError(t, err) + require.NotNil(t, conn) + defer tests.DeleteTestDB() + + // Handler to check outcome of Authentication middleware + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Add the middleware and create the server + reauthRequiredHandler := RequiresFresh(testHandler) + loginRequiredHandler := RequiresLogin(reauthRequiredHandler) + authHandler := Authentication(logger, cfg, conn, loginRequiredHandler) + server := httptest.NewServer(authHandler) + defer server.Close() + + tokens := getTokens() + + tests := []struct { + name string + accessToken string + refreshToken string + expectedCode int + }{ + { + name: "Fresh Login", + accessToken: tokens["accessFresh"], + refreshToken: "", + expectedCode: http.StatusOK, + }, + { + name: "Unfresh Login", + accessToken: tokens["accessUnfresh"], + refreshToken: "", + expectedCode: 444, + }, + { + name: "Expired login", + accessToken: tokens["accessExpired"], + refreshToken: tokens["refreshExpired"], + expectedCode: http.StatusUnauthorized, + }, + { + name: "No login", + accessToken: "", + refreshToken: "", + expectedCode: http.StatusUnauthorized, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + // Add cookies if provided + if tt.accessToken != "" { + req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken}) + } + if tt.refreshToken != "" { + req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken}) + } + + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, resp.StatusCode) + }) + } +}