482 lines
12 KiB
Go
482 lines
12 KiB
Go
package hwsauth
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"io"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
|
|
"git.haelnorr.com/h/golib/hlog"
|
|
"git.haelnorr.com/h/golib/hws"
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestModel is a minimal Model implementation used as the authenticated
// user type throughout these tests. It carries nothing but an ID.
type TestModel struct {
	ID int
}

// GetID reports the ID stored on the model.
func (m TestModel) GetID() int { return m.ID }
|
|
|
|
// TestTransaction is a do-nothing stand-in handed out wherever the tests need
// a DBTransaction. Every method ignores its arguments and reports success with
// nil results, so no real database is touched.
type TestTransaction struct{}

// Exec discards the statement and returns a nil result and a nil error.
func (*TestTransaction) Exec(_ string, _ ...any) (sql.Result, error) {
	return nil, nil
}

// Query discards the statement and returns nil rows and a nil error.
func (*TestTransaction) Query(_ string, _ ...any) (*sql.Rows, error) {
	return nil, nil
}

// Commit always succeeds.
func (*TestTransaction) Commit() error { return nil }

// Rollback always succeeds.
func (*TestTransaction) Rollback() error { return nil }
|
|
|
|
// TestErrorPage is the error page returned by the tests' error-page callbacks.
// It renders nothing.
type TestErrorPage struct{}

// Render writes nothing to the writer and always succeeds.
func (TestErrorPage) Render(_ context.Context, _ io.Writer) error {
	return nil
}
|
|
|
|
// createMockDB creates a mock SQL database for testing
|
|
func createMockDB() (*sql.DB, sqlmock.Sqlmock, error) {
|
|
db, mock, err := sqlmock.New()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Expect a ping to succeed for database connectivity test
|
|
mock.ExpectPing()
|
|
|
|
// Expect table existence check (returns a row = table exists)
|
|
mock.ExpectQuery(`SELECT 1 FROM information_schema\.tables WHERE table_schema = 'public' AND table_name = \$1`).
|
|
WithArgs("jwtblacklist").
|
|
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
|
|
|
// Expect cleanup function creation
|
|
mock.ExpectExec(`CREATE OR REPLACE FUNCTION cleanup_jwtblacklist\(\) RETURNS void AS \$\$ BEGIN DELETE FROM jwtblacklist WHERE exp < EXTRACT\(EPOCH FROM NOW\(\)\); END; \$\$ LANGUAGE plpgsql;`).
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
|
|
return db, mock, nil
|
|
}
|
|
|
|
func TestGetNil(t *testing.T) {
|
|
var zero TestModel
|
|
result := getNil[TestModel]()
|
|
assert.Equal(t, zero, result)
|
|
}
|
|
|
|
func TestSetAndGetAuthenticatedModel(t *testing.T) {
|
|
ctx := context.Background()
|
|
model := TestModel{ID: 123}
|
|
authModel := authenticatedModel[TestModel]{
|
|
model: model,
|
|
fresh: 1234567890,
|
|
}
|
|
|
|
newCtx := setAuthenticatedModel(ctx, authModel)
|
|
|
|
retrieved, ok := getAuthorizedModel[TestModel](newCtx)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, model, retrieved.model)
|
|
assert.Equal(t, int64(1234567890), retrieved.fresh)
|
|
}
|
|
|
|
func TestGetAuthorizedModel_NotSet(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
retrieved, ok := getAuthorizedModel[TestModel](ctx)
|
|
assert.False(t, ok)
|
|
var zero TestModel
|
|
assert.Equal(t, zero, retrieved.model)
|
|
assert.Equal(t, int64(0), retrieved.fresh)
|
|
}
|
|
|
|
func TestCurrentModel(t *testing.T) {
|
|
auth := &Authenticator[TestModel, DBTransaction]{}
|
|
|
|
t.Run("nil context", func(t *testing.T) {
|
|
var nilContext context.Context = nil
|
|
result := auth.CurrentModel(nilContext)
|
|
var zero TestModel
|
|
assert.Equal(t, zero, result)
|
|
})
|
|
|
|
t.Run("context without authenticated model", func(t *testing.T) {
|
|
ctx := context.Background()
|
|
result := auth.CurrentModel(ctx)
|
|
var zero TestModel
|
|
assert.Equal(t, zero, result)
|
|
})
|
|
|
|
t.Run("context with authenticated model", func(t *testing.T) {
|
|
ctx := context.Background()
|
|
model := TestModel{ID: 456}
|
|
authModel := authenticatedModel[TestModel]{
|
|
model: model,
|
|
fresh: 1234567890,
|
|
}
|
|
ctx = setAuthenticatedModel(ctx, authModel)
|
|
|
|
result := auth.CurrentModel(ctx)
|
|
assert.Equal(t, model, result)
|
|
assert.Equal(t, 456, result.GetID())
|
|
})
|
|
}
|
|
|
|
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
|
// Clear environment variables
|
|
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
|
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
|
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
|
|
|
_, err := ConfigFromEnv()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "Envar not set: HWSAUTH_SECRET_KEY")
|
|
}
|
|
|
|
func TestConfigFromEnv_SSLWithoutTrustedHost(t *testing.T) {
|
|
// Clear environment variables
|
|
t.Setenv("HWSAUTH_SECRET_KEY", "test-secret")
|
|
t.Setenv("HWSAUTH_SSL", "true")
|
|
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
|
defer func() {
|
|
t.Setenv("HWSAUTH_SECRET_KEY", "")
|
|
t.Setenv("HWSAUTH_SSL", "")
|
|
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
|
}()
|
|
|
|
_, err := ConfigFromEnv()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "SSL is enabled and no HWS_TRUSTED_HOST set")
|
|
}
|
|
|
|
func TestConfigFromEnv_ValidMinimalConfig(t *testing.T) {
|
|
// Set environment variables
|
|
t.Setenv("HWSAUTH_SECRET_KEY", "test-secret-key")
|
|
defer t.Setenv("HWSAUTH_SECRET_KEY", "")
|
|
|
|
cfg, err := ConfigFromEnv()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "test-secret-key", cfg.SecretKey)
|
|
assert.Equal(t, false, cfg.SSL)
|
|
assert.Equal(t, int64(5), cfg.AccessTokenExpiry)
|
|
assert.Equal(t, int64(1440), cfg.RefreshTokenExpiry)
|
|
assert.Equal(t, int64(5), cfg.TokenFreshTime)
|
|
assert.Equal(t, "/profile", cfg.LandingPage)
|
|
assert.Equal(t, "postgres", cfg.DatabaseType)
|
|
assert.Equal(t, "15", cfg.DatabaseVersion)
|
|
assert.Equal(t, "jwtblacklist", cfg.JWTTableName)
|
|
}
|
|
|
|
func TestConfigFromEnv_ValidFullConfig(t *testing.T) {
|
|
// Set environment variables
|
|
t.Setenv("HWSAUTH_SECRET_KEY", "custom-secret")
|
|
t.Setenv("HWSAUTH_SSL", "true")
|
|
t.Setenv("HWSAUTH_TRUSTED_HOST", "example.com")
|
|
t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "15")
|
|
t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "2880")
|
|
t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "10")
|
|
t.Setenv("HWSAUTH_LANDING_PAGE", "/dashboard")
|
|
t.Setenv("HWSAUTH_DATABASE_TYPE", "mysql")
|
|
t.Setenv("HWSAUTH_DATABASE_VERSION", "8.0")
|
|
t.Setenv("HWSAUTH_JWT_TABLE_NAME", "custom_tokens")
|
|
defer func() {
|
|
t.Setenv("HWSAUTH_SECRET_KEY", "")
|
|
t.Setenv("HWSAUTH_SSL", "")
|
|
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
|
t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "")
|
|
t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "")
|
|
t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "")
|
|
t.Setenv("HWSAUTH_LANDING_PAGE", "")
|
|
t.Setenv("HWSAUTH_DATABASE_TYPE", "")
|
|
t.Setenv("HWSAUTH_DATABASE_VERSION", "")
|
|
t.Setenv("HWSAUTH_JWT_TABLE_NAME", "")
|
|
}()
|
|
|
|
cfg, err := ConfigFromEnv()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "custom-secret", cfg.SecretKey)
|
|
assert.Equal(t, true, cfg.SSL)
|
|
assert.Equal(t, "example.com", cfg.TrustedHost)
|
|
assert.Equal(t, int64(15), cfg.AccessTokenExpiry)
|
|
assert.Equal(t, int64(2880), cfg.RefreshTokenExpiry)
|
|
assert.Equal(t, int64(10), cfg.TokenFreshTime)
|
|
assert.Equal(t, "/dashboard", cfg.LandingPage)
|
|
assert.Equal(t, "mysql", cfg.DatabaseType)
|
|
assert.Equal(t, "8.0", cfg.DatabaseVersion)
|
|
assert.Equal(t, "custom_tokens", cfg.JWTTableName)
|
|
}
|
|
|
|
func TestNewAuthenticator_NilConfig(t *testing.T) {
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
auth, err := NewAuthenticator(
|
|
nil, // cfg
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
nil, // db
|
|
)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, auth)
|
|
assert.Contains(t, err.Error(), "Config is required")
|
|
}
|
|
|
|
func TestNewAuthenticator_MissingSecretKey(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "",
|
|
}
|
|
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
auth, err := NewAuthenticator(
|
|
cfg,
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
nil, // db - will fail before db check since SecretKey is missing
|
|
)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, auth)
|
|
assert.Contains(t, err.Error(), "SecretKey is required")
|
|
}
|
|
|
|
func TestNewAuthenticator_NilLoadFunction(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "test-secret",
|
|
}
|
|
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
auth, err := NewAuthenticator[TestModel, DBTransaction](
|
|
cfg,
|
|
nil,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
nil, // db
|
|
)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, auth)
|
|
assert.Contains(t, err.Error(), "No function to load model supplied")
|
|
}
|
|
|
|
func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "test-secret",
|
|
SSL: true,
|
|
}
|
|
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
db, _, err := createMockDB()
|
|
require.NoError(t, err)
|
|
defer db.Close()
|
|
|
|
auth, err := NewAuthenticator(
|
|
cfg,
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
db,
|
|
)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, auth)
|
|
|
|
assert.Equal(t, false, auth.SSL)
|
|
assert.Equal(t, "/profile", auth.LandingPage)
|
|
}
|
|
|
|
func TestNewAuthenticator_NilDatabase(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "test-secret",
|
|
}
|
|
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
auth, err := NewAuthenticator(
|
|
cfg,
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
nil, // db
|
|
)
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, auth)
|
|
assert.Contains(t, err.Error(), "No Database provided")
|
|
}
|
|
|
|
func TestModelInterface(t *testing.T) {
|
|
t.Run("TestModel implements Model interface", func(t *testing.T) {
|
|
var _ Model = TestModel{}
|
|
})
|
|
|
|
t.Run("GetID method", func(t *testing.T) {
|
|
model := TestModel{ID: 789}
|
|
assert.Equal(t, 789, model.GetID())
|
|
})
|
|
}
|
|
|
|
func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "test-secret",
|
|
TrustedHost: "example.com",
|
|
}
|
|
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
db, _, err := createMockDB()
|
|
require.NoError(t, err)
|
|
defer db.Close()
|
|
|
|
auth, err := NewAuthenticator(
|
|
cfg,
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
db,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
tx := &TestTransaction{}
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/", nil)
|
|
|
|
model, err := auth.getAuthenticatedUser(tx, w, r)
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "No token strings provided")
|
|
var zero TestModel
|
|
assert.Equal(t, zero, model.model)
|
|
}
|
|
|
|
func TestLogin_BasicFunctionality(t *testing.T) {
|
|
cfg := &Config{
|
|
SecretKey: "test-secret",
|
|
TrustedHost: "example.com",
|
|
}
|
|
|
|
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
|
return TestModel{ID: id}, nil
|
|
}
|
|
server := &hws.Server{}
|
|
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
|
return &TestTransaction{}, nil
|
|
}
|
|
logger := &hlog.Logger{}
|
|
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
|
return TestErrorPage{}, nil
|
|
}
|
|
|
|
db, _, err := createMockDB()
|
|
require.NoError(t, err)
|
|
defer db.Close()
|
|
|
|
auth, err := NewAuthenticator(
|
|
cfg,
|
|
load,
|
|
server,
|
|
beginTx,
|
|
logger,
|
|
errorPage,
|
|
db,
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/", nil)
|
|
|
|
user := TestModel{ID: 123}
|
|
rememberMe := true
|
|
|
|
// This test mainly checks that the function doesn't panic and has right call signature
|
|
// The actual JWT functionality is tested in jwt package itself
|
|
assert.NotPanics(t, func() {
|
|
auth.Login(w, r, user, rememberMe)
|
|
})
|
|
}
|