package jwt

import (
	"context"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/require"
)

// newGeneratorWithMockDB builds a TokenGenerator on top of a sqlmock database.
//
// It returns the generator, the sqlmock handle used to register expectations,
// and a cleanup function that closes the mock database. The test is failed
// immediately if either the mock or the generator cannot be created.
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
	// Attribute require failures to the calling test, not to this helper.
	t.Helper()

	db, mock, err := sqlmock.New()
	require.NoError(t, err)

	gen, err := CreateGenerator(
		15,
		60,
		5,
		"example.com",
		"supersecret",
		db,
	)
	require.NoError(t, err)

	cleanup := func() {
		// Close's error is discarded; the mock registers no ExpectClose.
		_ = db.Close()
	}
	return gen, mock, cleanup
}

// expectNotRevoked registers, in order, a transaction begin, a
// `SELECT 1 FROM jwtblacklist` query with jti as its single argument that
// yields an empty row set, and a commit.
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
	mock.ExpectBegin()
	mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
		WithArgs(jti).
		WillReturnRows(sqlmock.NewRows([]string{}))
	mock.ExpectCommit()
}

// TestValidateAccess_Success issues an access token and checks that
// ValidateAccess accepts it and reports the expected subject and scope when
// the mock database returns no blacklist rows.
//
// NOTE(review): mock.ExpectationsWereMet is never asserted, so this test does
// not prove that the blacklist lookup was actually executed.
func TestValidateAccess_Success(t *testing.T) {
	gen, mock, cleanup := newGeneratorWithMockDB(t)
	defer cleanup()

	tokenStr, _, err := gen.NewAccess(42, true, false)
	require.NoError(t, err)

	// We don't know the JTI beforehand; match any arg
	expectNotRevoked(mock, sqlmock.AnyArg())

	ctx := context.Background()
	token, err := gen.ValidateAccess(ctx, tokenStr)
	require.NoError(t, err)
	require.Equal(t, 42, token.SUB)
	require.Equal(t, "access", token.Scope)
}

// TestValidateAccess_NoDB issues an access token from the generator returned
// by newGeneratorWithNoDB and checks that ValidateAccess accepts it with the
// expected subject and scope.
func TestValidateAccess_NoDB(t *testing.T) {
	gen := newGeneratorWithNoDB(t)

	tokenStr, _, err := gen.NewAccess(42, true, false)
	require.NoError(t, err)

	ctx := context.Background()
	token, err := gen.ValidateAccess(ctx, tokenStr)
	require.NoError(t, err)
	require.Equal(t, 42, token.SUB)
	require.Equal(t, "access", token.Scope)
}

// TestValidateRefresh_Success issues a refresh token and checks that
// ValidateRefresh accepts it and reports the expected subject and scope when
// the mock database returns no blacklist rows.
//
// NOTE(review): mock.ExpectationsWereMet is never asserted, so this test does
// not prove that the blacklist lookup was actually executed.
func TestValidateRefresh_Success(t *testing.T) {
	gen, mock, cleanup := newGeneratorWithMockDB(t)
	defer cleanup()

	tokenStr, _, err := gen.NewRefresh(42, false)
	require.NoError(t, err)

	// The JTI is generated inside NewRefresh, so match any argument.
	expectNotRevoked(mock, sqlmock.AnyArg())

	ctx := context.Background()
	token, err := gen.ValidateRefresh(ctx, tokenStr)
	require.NoError(t, err)
	require.Equal(t, 42, token.SUB)
	require.Equal(t, "refresh", token.Scope)
}

// TestValidateRefresh_NoDB issues a refresh token from the generator returned
// by newGeneratorWithNoDB and checks that ValidateRefresh accepts it with the
// expected subject and scope.
func TestValidateRefresh_NoDB(t *testing.T) {
	gen := newGeneratorWithNoDB(t)

	tokenStr, _, err := gen.NewRefresh(42, false)
	require.NoError(t, err)

	ctx := context.Background()
	token, err := gen.ValidateRefresh(ctx, tokenStr)
	require.NoError(t, err)
	require.Equal(t, 42, token.SUB)
	require.Equal(t, "refresh", token.Scope)
}

// TestValidateAccess_EmptyToken checks that ValidateAccess returns an error
// for an empty token string.
func TestValidateAccess_EmptyToken(t *testing.T) {
	gen := newTestGenerator(t)

	_, err := gen.ValidateAccess(context.Background(), "")
	require.Error(t, err)
}

// TestValidateRefresh_WrongScope checks that a token issued by NewAccess is
// rejected when passed to ValidateRefresh.
func TestValidateRefresh_WrongScope(t *testing.T) {
	gen := newTestGenerator(t)

	// Create access token but validate as refresh
	tokenStr, _, err := gen.NewAccess(1, false, false)
	require.NoError(t, err)

	_, err = gen.ValidateRefresh(context.Background(), tokenStr)
	require.Error(t, err)
}