updated hwsauth: uses new hws version
This commit is contained in:
440
hwsauth/hwsauth_test.go
Normal file
440
hwsauth/hwsauth_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package hwsauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type TestModel struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (tm TestModel) GetID() int {
|
||||
return tm.ID
|
||||
}
|
||||
|
||||
type TestTransaction struct {
|
||||
}
|
||||
|
||||
func (tt *TestTransaction) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (tt *TestTransaction) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (tt *TestTransaction) Commit() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt *TestTransaction) Rollback() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type TestErrorPage struct{}
|
||||
|
||||
func (tep TestErrorPage) Render(ctx context.Context, w io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetNil(t *testing.T) {
|
||||
var zero TestModel
|
||||
result := getNil[TestModel]()
|
||||
assert.Equal(t, zero, result)
|
||||
}
|
||||
|
||||
func TestSetAndGetAuthenticatedModel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
model := TestModel{ID: 123}
|
||||
authModel := authenticatedModel[TestModel]{
|
||||
model: model,
|
||||
fresh: 1234567890,
|
||||
}
|
||||
|
||||
newCtx := setAuthenticatedModel(ctx, authModel)
|
||||
|
||||
retrieved, ok := getAuthorizedModel[TestModel](newCtx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, model, retrieved.model)
|
||||
assert.Equal(t, int64(1234567890), retrieved.fresh)
|
||||
}
|
||||
|
||||
func TestGetAuthorizedModel_NotSet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
retrieved, ok := getAuthorizedModel[TestModel](ctx)
|
||||
assert.False(t, ok)
|
||||
var zero TestModel
|
||||
assert.Equal(t, zero, retrieved.model)
|
||||
assert.Equal(t, int64(0), retrieved.fresh)
|
||||
}
|
||||
|
||||
func TestCurrentModel(t *testing.T) {
|
||||
auth := &Authenticator[TestModel, DBTransaction]{}
|
||||
|
||||
t.Run("nil context", func(t *testing.T) {
|
||||
var nilContext context.Context = nil
|
||||
result := auth.CurrentModel(nilContext)
|
||||
var zero TestModel
|
||||
assert.Equal(t, zero, result)
|
||||
})
|
||||
|
||||
t.Run("context without authenticated model", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
result := auth.CurrentModel(ctx)
|
||||
var zero TestModel
|
||||
assert.Equal(t, zero, result)
|
||||
})
|
||||
|
||||
t.Run("context with authenticated model", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
model := TestModel{ID: 456}
|
||||
authModel := authenticatedModel[TestModel]{
|
||||
model: model,
|
||||
fresh: 1234567890,
|
||||
}
|
||||
ctx = setAuthenticatedModel(ctx, authModel)
|
||||
|
||||
result := auth.CurrentModel(ctx)
|
||||
assert.Equal(t, model, result)
|
||||
assert.Equal(t, 456, result.GetID())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigFromEnv_MissingSecretKey(t *testing.T) {
|
||||
// Clear environment variables
|
||||
originalSecret := os.Getenv("HWSAUTH_SECRET_KEY")
|
||||
os.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
defer os.Setenv("HWSAUTH_SECRET_KEY", originalSecret)
|
||||
|
||||
_, err := ConfigFromEnv()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Envar not set: HWSAUTH_SECRET_KEY")
|
||||
}
|
||||
|
||||
func TestConfigFromEnv_SSLWithoutTrustedHost(t *testing.T) {
|
||||
// Clear environment variables
|
||||
t.Setenv("HWSAUTH_SECRET_KEY", "test-secret")
|
||||
t.Setenv("HWSAUTH_SSL", "true")
|
||||
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
||||
defer func() {
|
||||
t.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
t.Setenv("HWSAUTH_SSL", "")
|
||||
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
||||
}()
|
||||
|
||||
_, err := ConfigFromEnv()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SSL is enabled and no HWS_TRUSTED_HOST set")
|
||||
}
|
||||
|
||||
func TestConfigFromEnv_ValidMinimalConfig(t *testing.T) {
|
||||
// Set environment variables
|
||||
t.Setenv("HWSAUTH_SECRET_KEY", "test-secret-key")
|
||||
defer t.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
|
||||
cfg, err := ConfigFromEnv()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-secret-key", cfg.SecretKey)
|
||||
assert.Equal(t, false, cfg.SSL)
|
||||
assert.Equal(t, int64(5), cfg.AccessTokenExpiry)
|
||||
assert.Equal(t, int64(1440), cfg.RefreshTokenExpiry)
|
||||
assert.Equal(t, int64(5), cfg.TokenFreshTime)
|
||||
assert.Equal(t, "/profile", cfg.LandingPage)
|
||||
assert.Equal(t, "postgres", cfg.DatabaseType)
|
||||
assert.Equal(t, "15", cfg.DatabaseVersion)
|
||||
assert.Equal(t, "jwtblacklist", cfg.JWTTableName)
|
||||
}
|
||||
|
||||
func TestConfigFromEnv_ValidFullConfig(t *testing.T) {
|
||||
// Set environment variables
|
||||
t.Setenv("HWSAUTH_SECRET_KEY", "custom-secret")
|
||||
t.Setenv("HWSAUTH_SSL", "true")
|
||||
t.Setenv("HWSAUTH_TRUSTED_HOST", "example.com")
|
||||
t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "15")
|
||||
t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "2880")
|
||||
t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "10")
|
||||
t.Setenv("HWSAUTH_LANDING_PAGE", "/dashboard")
|
||||
t.Setenv("HWSAUTH_DATABASE_TYPE", "mysql")
|
||||
t.Setenv("HWSAUTH_DATABASE_VERSION", "8.0")
|
||||
t.Setenv("HWSAUTH_JWT_TABLE_NAME", "custom_tokens")
|
||||
defer func() {
|
||||
t.Setenv("HWSAUTH_SECRET_KEY", "")
|
||||
t.Setenv("HWSAUTH_SSL", "")
|
||||
t.Setenv("HWSAUTH_TRUSTED_HOST", "")
|
||||
t.Setenv("HWSAUTH_ACCESS_TOKEN_EXPIRY", "")
|
||||
t.Setenv("HWSAUTH_REFRESH_TOKEN_EXPIRY", "")
|
||||
t.Setenv("HWSAUTH_TOKEN_FRESH_TIME", "")
|
||||
t.Setenv("HWSAUTH_LANDING_PAGE", "")
|
||||
t.Setenv("HWSAUTH_DATABASE_TYPE", "")
|
||||
t.Setenv("HWSAUTH_DATABASE_VERSION", "")
|
||||
t.Setenv("HWSAUTH_JWT_TABLE_NAME", "")
|
||||
}()
|
||||
|
||||
cfg, err := ConfigFromEnv()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "custom-secret", cfg.SecretKey)
|
||||
assert.Equal(t, true, cfg.SSL)
|
||||
assert.Equal(t, "example.com", cfg.TrustedHost)
|
||||
assert.Equal(t, int64(15), cfg.AccessTokenExpiry)
|
||||
assert.Equal(t, int64(2880), cfg.RefreshTokenExpiry)
|
||||
assert.Equal(t, int64(10), cfg.TokenFreshTime)
|
||||
assert.Equal(t, "/dashboard", cfg.LandingPage)
|
||||
assert.Equal(t, "mysql", cfg.DatabaseType)
|
||||
assert.Equal(t, "8.0", cfg.DatabaseVersion)
|
||||
assert.Equal(t, "custom_tokens", cfg.JWTTableName)
|
||||
}
|
||||
|
||||
func TestNewAuthenticator_NilConfig(t *testing.T) {
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
nil,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, auth)
|
||||
assert.Contains(t, err.Error(), "Config is required")
|
||||
}
|
||||
|
||||
func TestNewAuthenticator_MissingSecretKey(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "",
|
||||
}
|
||||
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, auth)
|
||||
assert.Contains(t, err.Error(), "SecretKey is required")
|
||||
}
|
||||
|
||||
func TestNewAuthenticator_NilLoadFunction(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "test-secret",
|
||||
}
|
||||
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator[TestModel, DBTransaction](
|
||||
cfg,
|
||||
nil,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, auth)
|
||||
assert.Contains(t, err.Error(), "No function to load model supplied")
|
||||
}
|
||||
|
||||
func TestNewAuthenticator_SSLWithoutTrustedHost(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "test-secret",
|
||||
SSL: true,
|
||||
}
|
||||
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, auth)
|
||||
assert.Contains(t, err.Error(), "TrustedHost is required when SSL is enabled")
|
||||
}
|
||||
|
||||
func TestNewAuthenticator_ValidMinimalConfig(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "test-secret",
|
||||
TrustedHost: "example.com",
|
||||
}
|
||||
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, auth)
|
||||
|
||||
assert.Equal(t, false, auth.SSL)
|
||||
assert.Equal(t, "/profile", auth.LandingPage)
|
||||
}
|
||||
|
||||
func TestModelInterface(t *testing.T) {
|
||||
t.Run("TestModel implements Model interface", func(t *testing.T) {
|
||||
var _ Model = TestModel{}
|
||||
})
|
||||
|
||||
t.Run("GetID method", func(t *testing.T) {
|
||||
model := TestModel{ID: 789}
|
||||
assert.Equal(t, 789, model.GetID())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthenticatedUser_NoTokens(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "test-secret",
|
||||
TrustedHost: "example.com",
|
||||
}
|
||||
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := &TestTransaction{}
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
model, err := auth.getAuthenticatedUser(tx, w, r)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "No token strings provided")
|
||||
var zero TestModel
|
||||
assert.Equal(t, zero, model.model)
|
||||
}
|
||||
|
||||
func TestLogin_BasicFunctionality(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretKey: "test-secret",
|
||||
TrustedHost: "example.com",
|
||||
}
|
||||
|
||||
load := func(ctx context.Context, tx DBTransaction, id int) (TestModel, error) {
|
||||
return TestModel{ID: id}, nil
|
||||
}
|
||||
server := &hws.Server{}
|
||||
beginTx := func(ctx context.Context) (DBTransaction, error) {
|
||||
return &TestTransaction{}, nil
|
||||
}
|
||||
logger := &hlog.Logger{}
|
||||
errorPage := func(error hws.HWSError) (hws.ErrorPage, error) {
|
||||
return TestErrorPage{}, nil
|
||||
}
|
||||
|
||||
auth, err := NewAuthenticator(
|
||||
cfg,
|
||||
load,
|
||||
server,
|
||||
beginTx,
|
||||
logger,
|
||||
errorPage,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
user := TestModel{ID: 123}
|
||||
rememberMe := true
|
||||
|
||||
// This test mainly checks that the function doesn't panic and has right call signature
|
||||
// The actual JWT functionality is tested in jwt package itself
|
||||
assert.NotPanics(t, func() {
|
||||
auth.Login(w, r, user, rememberMe)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user