refactor: changed file structure

This commit is contained in:
2025-03-05 20:18:28 +11:00
parent 5c1089e0ce
commit 1d9af44d0a
137 changed files with 4986 additions and 581 deletions

View File

@@ -0,0 +1,144 @@
package middleware
import (
"context"
"net/http"
"sync/atomic"
"time"
"projectreshoot/internal/handler"
"projectreshoot/internal/models"
"projectreshoot/pkg/config"
"projectreshoot/pkg/contexts"
"projectreshoot/pkg/cookies"
"projectreshoot/pkg/db"
"projectreshoot/pkg/jwt"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
// Attempt to use a valid refresh token to generate a new token pair
func refreshAuthTokens(
config *config.Config,
ctx context.Context,
tx *db.SafeWTX,
w http.ResponseWriter,
req *http.Request,
ref *jwt.RefreshToken,
) (*models.User, error) {
user, err := ref.GetUser(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "ref.GetUser")
}
rememberMe := map[string]bool{
"session": false,
"exp": true,
}[ref.TTL]
// Set fresh to true because new tokens coming from refresh request
err = cookies.SetTokenCookies(w, req, config, user, false, rememberMe)
if err != nil {
return nil, errors.Wrap(err, "cookies.SetTokenCookies")
}
// New tokens sent, revoke the used refresh token
err = jwt.RevokeToken(ctx, tx, ref)
if err != nil {
return nil, errors.Wrap(err, "jwt.RevokeToken")
}
// Return the authorized user
return user, nil
}
// Check the cookies for token strings and attempt to authenticate them
func getAuthenticatedUser(
config *config.Config,
ctx context.Context,
tx *db.SafeWTX,
w http.ResponseWriter,
r *http.Request,
) (*contexts.AuthenticatedUser, error) {
// Get token strings from cookies
atStr, rtStr := cookies.GetTokenStrings(r)
// Attempt to parse the access token
aT, err := jwt.ParseAccessToken(config, ctx, tx, atStr)
if err != nil {
// Access token invalid, attempt to parse refresh token
rT, err := jwt.ParseRefreshToken(config, ctx, tx, rtStr)
if err != nil {
return nil, errors.Wrap(err, "jwt.ParseRefreshToken")
}
// Refresh token valid, attempt to get a new token pair
user, err := refreshAuthTokens(config, ctx, tx, w, r, rT)
if err != nil {
return nil, errors.Wrap(err, "refreshAuthTokens")
}
// New token pair sent, return the authorized user
authUser := contexts.AuthenticatedUser{
User: user,
Fresh: time.Now().Unix(),
}
return &authUser, nil
}
// Access token valid
user, err := aT.GetUser(ctx, tx)
if err != nil {
return nil, errors.Wrap(err, "aT.GetUser")
}
authUser := contexts.AuthenticatedUser{
User: user,
Fresh: aT.Fresh,
}
return &authUser, nil
}
// Attempt to authenticate the user and add their account details
// to the request context
func Authentication(
logger *zerolog.Logger,
config *config.Config,
conn *db.SafeConn,
next http.Handler,
maint *uint32,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" {
next.ServeHTTP(w, r)
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
if atomic.LoadUint32(maint) == 1 {
cancel()
}
// Start the transaction
tx, err := conn.Begin(ctx)
if err != nil {
// Failed to start transaction, skip auth
logger.Warn().Err(err).
Msg("Skipping Auth - unable to start a transaction")
handler.ErrorPage(http.StatusServiceUnavailable, w, r)
return
}
user, err := getAuthenticatedUser(config, ctx, tx, w, r)
if err != nil {
tx.Rollback()
// User auth failed, delete the cookies to avoid repeat requests
cookies.DeleteCookie(w, "access", "/")
cookies.DeleteCookie(w, "refresh", "/")
logger.Debug().
Str("remote_addr", r.RemoteAddr).
Err(err).
Msg("Failed to authenticate user")
next.ServeHTTP(w, r)
return
}
tx.Commit()
uctx := contexts.SetUser(r.Context(), user)
newReq := r.WithContext(uctx)
next.ServeHTTP(w, newReq)
})
}

View File

@@ -0,0 +1,148 @@
package middleware
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/pkg/contexts"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthenticationMiddleware(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
wconn, rconn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(wconn, rconn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user == nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(strconv.Itoa(0)))
return
} else {
w.WriteHeader(http.StatusOK)
w.Write([]byte(strconv.Itoa(user.ID)))
}
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
authHandler := Authentication(logger, cfg, sconn, testHandler, &maint)
require.NoError(t, err)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
id int
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Valid Access Token (Fresh)",
id: 1,
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Valid Access Token (Unfresh)",
id: 1,
accessToken: tokens["accessUnfresh"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusOK,
},
{
name: "Valid Refresh Token (Triggers Refresh)",
id: 1,
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshValid"],
expectedCode: http.StatusOK,
},
{
name: "Both tokens expired",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "Access token revoked",
accessToken: tokens["accessRevoked"],
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
{
name: "Refresh token revoked",
accessToken: "",
refreshToken: tokens["refreshRevoked"],
expectedCode: http.StatusUnauthorized,
},
{
name: "Invalid Tokens",
accessToken: tokens["invalid"],
refreshToken: tokens["invalid"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No Tokens",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, strconv.Itoa(tt.id), string(body))
})
}
}
// get the tokens to test with
func getTokens() map[string]string {
tokens := map[string]string{
"accessFresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzIyMTAsImZyZXNoIjo0ODk1NjcyMjEwLCJpYXQiOjE3Mzk2NzIyMTAsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImE4Njk2YWM4LTg3OWMtNDdkNC1iZWM2LTRlY2Y4MTRiZThiZiIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.6nAquDY0JBLPdaJ9q_sMpKj1ISG4Vt2U05J57aoPue8",
"accessUnfresh": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJmcmVzaCI6MTczOTY3NTY3MSwiaWF0IjoxNzM5Njc1NjcxLCJpc3MiOiIxMjcuMC4wLjEiLCJqdGkiOiJjOGNhZmFjNy0yODkzLTQzNzMtOTI4ZS03MGUwODJkYmM2MGIiLCJzY29wZSI6ImFjY2VzcyIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.plWQVFwHlhXUYI5utS7ny1JfXjJSFrigkq-PnTHD5VY",
"accessExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImZyZXNoIjoxNzM5NjcyMjQ4LCJpYXQiOjE3Mzk2NzIyNDgsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjgxYzA1YzBjLTJhOGItNGQ2MC04Yzc4LWY2ZTQxODYxZDFmNCIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.iI1f17kKTuFDEMEYltJRIwRYgYQ-_nF9Wsn0KR6x77Q",
"refreshValid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImlhdCI6MTczOTY3MTkyMiwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTUxMTY3ZWEtNDA3OS00ZTczLTkzZDQtNTgwZDMzODRjZDU4Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.tvtqQ8Z4WrYWHHb0MaEPdsU2FT2KLRE1zHOv3ipoFyc",
"refreshExpired": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3Mzk2NzIyNDgsImlhdCI6MTczOTY3MjI0OCwiaXNzIjoiMTI3LjAuMC4xIiwianRpIjoiZTg5YTc5MTYtZGEzYi00YmJhLWI3ZDMtOWI1N2ViNjRhMmU0Iiwic2NvcGUiOiJyZWZyZXNoIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.rH_fytC7Duxo598xacu820pQKF9ELbG8674h_bK_c4I",
"accessRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4OTU2NzE5MjIsImZyZXNoIjoxNzM5NjcxOTIyLCJpYXQiOjE3Mzk2NzE5MjIsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6IjBhNmIzMzhlLTkzMGEtNDNmZS04ZjcwLTFhNmRhZWQyNTZmYSIsInNjb3BlIjoiYWNjZXNzIiwic3ViIjoxLCJ0dGwiOiJzZXNzaW9uIn0.mZLuCp9amcm2_CqYvbHPlk86nfiuy_Or8TlntUCw4Qs",
"refreshRevoked": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjMzMjk5Njc1NjcxLCJpYXQiOjE3Mzk2NzU2NzEsImlzcyI6IjEyNy4wLjAuMSIsImp0aSI6ImI3ZmE1MWRjLTg1MzItNDJlMS04NzU2LTVkMjViZmIyMDAzYSIsInNjb3BlIjoicmVmcmVzaCIsInN1YiI6MSwidHRsIjoic2Vzc2lvbiJ9.5Q9yDZN5FubfCWHclUUZEkJPOUHcOEpVpgcUK-ameHo",
"invalid": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE0ODUxNDA5ODQsImlhdCI6MTQ4NTEzNzM4NCwiaXNzIjoiYWNtZS5jb20iLCJzdWIiOiIyOWFjMGMxOC0wYjRhLTQyY2YtODJmYy0wM2Q1NzAzMThhMWQiLCJhcHBsaWNhdGlvbklkIjoiNzkxMDM3MzQtOTdhYi00ZDFhLWFmMzctZTAwNmQwNWQyOTUyIiwicm9sZXMiOltdfQ.Mp0Pcwsz5VECK11Kf2ZZNF_SMKu5CgBeLN9ZOP04kZo",
}
return tokens
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
func Gzip(next http.Handler, useGzip bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") ||
!useGzip {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
next.ServeHTTP(gzw, r)
})
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

View File

@@ -0,0 +1,50 @@
package middleware
import (
"net/http"
"projectreshoot/internal/handler"
"projectreshoot/pkg/contexts"
"time"
"github.com/rs/zerolog"
)
// Wraps the http.ResponseWriter, adding a statusCode field
type wrappedWriter struct {
http.ResponseWriter
statusCode int
}
// Extends WriteHeader to the ResponseWriter to add the status code
func (w *wrappedWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
w.statusCode = statusCode
}
// Middleware to add logs to console with details of the request
func Logging(logger *zerolog.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" ||
r.URL.Path == "/static/favicon.ico" {
next.ServeHTTP(w, r)
return
}
start, err := contexts.GetStartTime(r.Context())
if err != nil {
handler.ErrorPage(http.StatusInternalServerError, w, r)
return
}
wrapped := &wrappedWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
logger.Info().
Int("status", wrapped.statusCode).
Str("method", r.Method).
Str("resource", r.URL.Path).
Dur("time_elapsed", time.Since(start)).
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
Msg("Served")
})
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"net/http"
"projectreshoot/internal/handler"
"projectreshoot/pkg/contexts"
)
// Checks if the user is set in the context and shows 401 page if not logged in
func LoginReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user == nil {
handler.ErrorPage(http.StatusUnauthorized, w, r)
return
}
next.ServeHTTP(w, r)
})
}
// Checks if the user is set in the context and redirects them to profile if
// they are logged in
func LogoutReq(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
if user != nil {
http.Redirect(w, r, "/profile", http.StatusFound)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,87 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPageLoginRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
wconn, rconn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(wconn, rconn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
loginRequiredHandler := LoginReq(testHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Valid Login",
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Expired login",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No login",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
})
}
}

View File

@@ -0,0 +1,21 @@
package middleware
import (
"net/http"
"projectreshoot/pkg/contexts"
"time"
)
func FreshReq(
next http.Handler,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := contexts.GetUser(r.Context())
isFresh := time.Now().Before(time.Unix(user.Fresh, 0))
if !isFresh {
w.WriteHeader(444)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,94 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"sync/atomic"
"testing"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReauthRequired(t *testing.T) {
cfg, err := tests.TestConfig()
require.NoError(t, err)
logger := tests.NilLogger()
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
require.NoError(t, err)
wconn, rconn, err := tests.SetupTestDB(ver)
require.NoError(t, err)
sconn := db.MakeSafe(wconn, rconn, logger)
defer sconn.Close()
// Handler to check outcome of Authentication middleware
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var maint uint32
atomic.StoreUint32(&maint, 0)
// Add the middleware and create the server
reauthRequiredHandler := FreshReq(testHandler)
loginRequiredHandler := LoginReq(reauthRequiredHandler)
authHandler := Authentication(logger, cfg, sconn, loginRequiredHandler, &maint)
server := httptest.NewServer(authHandler)
defer server.Close()
tokens := getTokens()
tests := []struct {
name string
accessToken string
refreshToken string
expectedCode int
}{
{
name: "Fresh Login",
accessToken: tokens["accessFresh"],
refreshToken: "",
expectedCode: http.StatusOK,
},
{
name: "Unfresh Login",
accessToken: tokens["accessUnfresh"],
refreshToken: "",
expectedCode: 444,
},
{
name: "Expired login",
accessToken: tokens["accessExpired"],
refreshToken: tokens["refreshExpired"],
expectedCode: http.StatusUnauthorized,
},
{
name: "No login",
accessToken: "",
refreshToken: "",
expectedCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
// Add cookies if provided
if tt.accessToken != "" {
req.AddCookie(&http.Cookie{Name: "access", Value: tt.accessToken})
}
if tt.refreshToken != "" {
req.AddCookie(&http.Cookie{Name: "refresh", Value: tt.refreshToken})
}
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, resp.StatusCode)
})
}
}

View File

@@ -0,0 +1,18 @@
package middleware
import (
"net/http"
"projectreshoot/pkg/contexts"
"time"
)
func StartTimer(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ctx := contexts.SetStart(r.Context(), start)
newReq := r.WithContext(ctx)
next.ServeHTTP(w, newReq)
},
)
}