diff --git a/docs/conf.py b/docs/conf.py index 28ac1c71..55f6c781 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -151,7 +151,7 @@ Middleware asyncio.events.AbstractEventLoop graphql.execution.collect_fields.FieldsAndPatches -graphql.execution.map_async_iterable.MapAsyncIterable +graphql.execution.map_async_iterable.map_async_iterable graphql.execution.Middleware graphql.execution.execute.DeferredFragmentRecord graphql.execution.execute.ExperimentalExecuteMultipleResults diff --git a/docs/modules/execution.rst b/docs/modules/execution.rst index 82147930..5b33b03d 100644 --- a/docs/modules/execution.rst +++ b/docs/modules/execution.rst @@ -57,8 +57,6 @@ Execution .. autofunction:: create_source_event_stream -.. autoclass:: MapAsyncIterable - .. autoclass:: Middleware .. autoclass:: MiddlewareManager diff --git a/pyproject.toml b/pyproject.toml index c8b6ba18..f9e7e86c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "graphql-core" version = "3.3.0a2" -description = """ +description = """\ GraphQL-core is a Python port of GraphQL.js,\ the JavaScript reference implementation for GraphQL.""" license = "MIT" diff --git a/src/graphql/__init__.py b/src/graphql/__init__.py index cb946aba..714d4215 100644 --- a/src/graphql/__init__.py +++ b/src/graphql/__init__.py @@ -439,7 +439,7 @@ # Subscription subscribe, create_source_event_stream, - MapAsyncIterable, + map_async_iterable, # Middleware Middleware, MiddlewareManager, @@ -707,7 +707,7 @@ "MiddlewareManager", "subscribe", "create_source_event_stream", - "MapAsyncIterable", + "map_async_iterable", "validate", "ValidationContext", "ValidationRule", diff --git a/src/graphql/execution/__init__.py b/src/graphql/execution/__init__.py index 6487c33d..9bb89131 100644 --- a/src/graphql/execution/__init__.py +++ b/src/graphql/execution/__init__.py @@ -32,7 +32,7 @@ FormattedIncrementalResult, Middleware, ) -from .map_async_iterable import MapAsyncIterable +from .iterators import map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -62,7 +62,7 @@ "FormattedIncrementalDeferResult", "FormattedIncrementalStreamResult", "FormattedIncrementalResult", - "MapAsyncIterable", + "map_async_iterable", "Middleware", "MiddlewareManager", "get_argument_values", diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index dd69658f..a8b52e4a 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -69,8 +69,7 @@ is_object_type, ) from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields -from .flatten_async_iterable import flatten_async_iterable -from .map_async_iterable import MapAsyncIterable +from .iterators import flatten_async_iterable, map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -1654,7 +1653,7 @@ async def callback(payload: Any) -> AsyncGenerator: await result if isawaitable(result) else result # type: ignore ) - return flatten_async_iterable(MapAsyncIterable(result_or_stream, callback)) + return flatten_async_iterable(map_async_iterable(result_or_stream, callback)) def execute_deferred_fragment( self, @@ -2319,18 +2318,20 @@ def subscribe( if isinstance(result, ExecutionResult): return result if isinstance(result, AsyncIterable): - return MapAsyncIterable(result, ensure_single_execution_result) + return map_async_iterable(result, ensure_single_execution_result) async def await_result() -> Union[AsyncIterator[ExecutionResult], ExecutionResult]: result_or_iterable = await result # type: ignore if isinstance(result_or_iterable, AsyncIterable): - return MapAsyncIterable(result_or_iterable, ensure_single_execution_result) + return map_async_iterable( + result_or_iterable, ensure_single_execution_result + ) return result_or_iterable return await_result() -def ensure_single_execution_result( +async def ensure_single_execution_result( result: Union[ ExecutionResult, InitialIncrementalExecutionResult, diff --git a/src/graphql/execution/flatten_async_iterable.py b/src/graphql/execution/iterators.py similarity index 50% rename from src/graphql/execution/flatten_async_iterable.py rename to src/graphql/execution/iterators.py index 7c0e0721..c3479175 100644 --- a/src/graphql/execution/flatten_async_iterable.py +++ b/src/graphql/execution/iterators.py @@ -1,4 +1,14 @@ -from typing import AsyncGenerator, AsyncIterable, TypeVar, Union +from __future__ import annotations # Python < 3.10 + +from typing import ( + Any, + AsyncGenerator, + AsyncIterable, + Awaitable, + Callable, + TypeVar, + Union, +) try: @@ -15,10 +25,11 @@ async def aclosing(thing): T = TypeVar("T") +V = TypeVar("V") AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]] -__all__ = ["flatten_async_iterable"] +__all__ = ["flatten_async_iterable", "map_async_iterable"] async def flatten_async_iterable( @@ -34,3 +45,23 @@ async def flatten_async_iterable( async with aclosing(sub_iterator) as items: # type: ignore async for item in items: yield item + + +async def map_async_iterable( + iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[V]] +) -> AsyncGenerator[V, None]: + """Map an AsyncIterable over a callback function. + + Given an AsyncIterable and an async callback callable, return an AsyncGenerator + which produces values mapped via calling the callback. + If the inner iterator supports an `aclose()` method, it will be called when + the generator finishes or closes. + """ + + aiter = iterable.__aiter__() + try: + async for element in aiter: + yield await callback(element) + finally: + if hasattr(aiter, "aclose"): + await aiter.aclose() diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py deleted file mode 100644 index 84bd3f4a..00000000 --- a/src/graphql/execution/map_async_iterable.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations # Python < 3.10 - -from asyncio import CancelledError, Event, Task, ensure_future, wait -from concurrent.futures import FIRST_COMPLETED -from inspect import isasyncgen, isawaitable -from types import TracebackType -from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast - - -__all__ = ["MapAsyncIterable"] - - -# noinspection PyAttributeOutsideInit -class MapAsyncIterable: - """Map an AsyncIterable over a callback function. - - Given an AsyncIterable and a callback function, return an AsyncIterator which - produces values mapped via calling the callback function. - - When the resulting AsyncIterator is closed, the underlying AsyncIterable will also - be closed. - """ - - def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: - self.iterator = iterable.__aiter__() - self.callback = callback - self._close_event = Event() - - def __aiter__(self) -> MapAsyncIterable: - """Get the iterator object.""" - return self - - async def __anext__(self) -> Any: - """Get the next value of the iterator.""" - if self.is_closed: - if not isasyncgen(self.iterator): - raise StopAsyncIteration - value = await self.iterator.__anext__() - else: - aclose = ensure_future(self._close_event.wait()) - anext = ensure_future(self.iterator.__anext__()) - - try: - pending: Set[Task] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] - except CancelledError: - # cancel underlying tasks and close - aclose.cancel() - anext.cancel() - await self.aclose() - raise # re-raise the cancellation - - for task in pending: - task.cancel() - - if aclose.done(): - raise StopAsyncIteration - - error = anext.exception() - if error: - raise error - - value = anext.result() - - result = self.callback(value) - - return await result if isawaitable(result) else result - - async def athrow( - self, - type_: Union[BaseException, Type[BaseException]], - value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, - ) -> None: - """Throw an exception into the asynchronous iterator.""" - if self.is_closed: - return - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value - - async def aclose(self) -> None: - """Close the iterator.""" - if not self.is_closed: - aclose = getattr(self.iterator, "aclose", None) - if aclose: - try: - await aclose() - except RuntimeError: - pass - self.is_closed = True - - @property - def is_closed(self) -> bool: - """Check whether the iterator is closed.""" - return self._close_event.is_set() - - @is_closed.setter - def is_closed(self, value: bool) -> None: - """Mark the iterator as closed.""" - if value: - self._close_event.set() - else: - self._close_event.clear() diff --git a/tests/execution/test_customize.py b/tests/execution/test_customize.py index 3dbc6d00..5b839fc8 100644 --- a/tests/execution/test_customize.py +++ b/tests/execution/test_customize.py @@ -1,6 +1,8 @@ +from inspect import isasyncgen + from pytest import mark -from graphql.execution import ExecutionContext, MapAsyncIterable, execute, subscribe +from graphql.execution import ExecutionContext, execute, subscribe from graphql.language import parse from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString @@ -77,7 +79,7 @@ async def custom_foo(): root_value=Root(), subscribe_field_resolver=lambda root, _info: root.custom_foo(), ) - assert isinstance(subscription, MapAsyncIterable) + assert isasyncgen(subscription) assert await anext(subscription) == ( {"foo": "FooValue"}, @@ -121,6 +123,6 @@ def resolve_foo(message, _info): context_value={}, execution_context_class=TestExecutionContext, ) - assert isinstance(subscription, MapAsyncIterable) + assert isasyncgen(subscription) assert await anext(subscription) == ({"foo": "bar"}, None) diff --git a/tests/execution/test_flatten_async_iterable.py b/tests/execution/test_flatten_async_iterable.py index de9c5499..49ead410 100644 --- a/tests/execution/test_flatten_async_iterable.py +++ b/tests/execution/test_flatten_async_iterable.py @@ -2,7 +2,7 @@ from pytest import mark, raises -from graphql.execution.flatten_async_iterable import flatten_async_iterable +from graphql.execution.iterators import flatten_async_iterable try: # pragma: no cover diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index 6406f7dd..1462645a 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -1,500 +1,93 @@ -import platform -import sys -from asyncio import CancelledError, Event, ensure_future, sleep - from pytest import mark, raises -from graphql.execution import MapAsyncIterable - +from graphql.execution import map_async_iterable -is_pypy = platform.python_implementation() == "PyPy" -try: # pragma: no cover - anext -except NameError: # pragma: no cover (Python < 3.10) - # noinspection PyShadowingBuiltins - async def anext(iterator): - """Return the next item from an async iterator.""" - return await iterator.__anext__() +async def map_doubles(x): + return x + x def describe_map_async_iterable(): @mark.asyncio - async def maps_over_async_generator(): - async def source(): - yield 1 - yield 2 - yield 3 - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert await anext(doubles) == 6 - with raises(StopAsyncIteration): - assert await anext(doubles) - - @mark.asyncio - async def maps_over_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) - - values = [value async for value in doubles] - - assert not items - assert values == [2, 4, 6] - - @mark.asyncio - async def compatible_with_async_for(): - async def source(): - yield 1 - yield 2 - yield 3 - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - values = [value async for value in doubles] - - assert values == [2, 4, 6] - - @mark.asyncio - async def maps_over_async_values_with_async_function(): - async def source(): - yield 1 - yield 2 - yield 3 - - async def double(x): - return x + x - - doubles = MapAsyncIterable(source(), double) - - values = [value async for value in doubles] - - assert values == [2, 4, 6] - - @mark.asyncio - async def allows_returning_early_from_mapped_async_generator(): - async def source(): - yield 1 - yield 2 - yield 3 # pragma: no cover - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Early return - await doubles.aclose() - - # Subsequent next calls - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def allows_returning_early_from_mapped_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: # pragma: no cover - raise StopAsyncIteration - - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Early return - await doubles.aclose() - - # Subsequent next calls - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def passes_through_early_return_from_async_values(): - async def source(): - try: - yield 1 - yield 2 - yield 3 # pragma: no cover - finally: - yield "Done" - yield "Last" - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Early return - await doubles.aclose() - - # Subsequent next calls may yield from finally block - assert await anext(doubles) == "LastLast" - with raises(GeneratorExit): - assert await anext(doubles) - - @mark.asyncio - async def allows_throwing_errors_through_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: # pragma: no cover - raise StopAsyncIteration + async def test_inner_close_called(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when outer is closed + """ - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Throw error - message = "allows throwing errors when mapping async iterable" - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError(message)) - - assert str(exc_info.value) == message - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def allows_throwing_errors_with_values_through_async_iterables(): - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - return 1 - - one = MapAsyncIterable(Iterable(), lambda x: x) - - assert await anext(one) == 1 - - # Throw error with value passed separately - try: - raise RuntimeError("Ouch") - except RuntimeError as error: - with raises(RuntimeError, match="Ouch") as exc_info: - await one.athrow(error.__class__, error) - - assert exc_info.value is error - assert exc_info.tb is error.__traceback__ + class Inner: + def __init__(self): + self.closed = False - with raises(StopAsyncIteration): - await anext(one) + async def aclose(self): + self.closed = True - @mark.asyncio - async def allows_throwing_errors_with_traceback_through_async_iterables(): - class Iterable: def __aiter__(self): return self async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) - - assert await anext(one) == 1 - - # Throw error with traceback passed separately - try: - raise RuntimeError("Ouch") - except RuntimeError as error: - with raises(RuntimeError) as exc_info: - await one.athrow(error.__class__, None, error.__traceback__) - - assert exc_info.tb and error.__traceback__ - assert exc_info.tb.tb_frame is error.__traceback__.tb_frame - - with raises(StopAsyncIteration): - await anext(one) + inner = Inner() + outer = map_async_iterable(inner, map_doubles) + it = outer.__aiter__() + assert await it.__anext__() == 2 + assert not inner.closed + await outer.aclose() + assert inner.closed @mark.asyncio - async def passes_through_caught_errors_through_async_generators(): - async def source(): - try: - yield 1 - yield 2 - yield 3 # pragma: no cover - except Exception as e: - yield e - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Throw error - await doubles.athrow(RuntimeError("ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def does_not_normally_map_over_thrown_errors(): - async def source(): - yield "Hello" - raise RuntimeError("Goodbye") - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == "HelloHello" - - with raises(RuntimeError) as exc_info: - await anext(doubles) - - assert str(exc_info.value) == "Goodbye" - - @mark.asyncio - async def does_not_normally_map_over_externally_thrown_errors(): - async def source(): - yield "Hello" - - doubles = MapAsyncIterable(source(), lambda x: x + x) - - assert await anext(doubles) == "HelloHello" + async def test_inner_close_called_on_callback_err(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when the callback errors and the outer iterator aborts. + """ - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError("Goodbye")) - - assert str(exc_info.value) == "Goodbye" - - @mark.asyncio - async def can_use_simple_iterable_instead_of_generator(): - async def source(): - yield 1 - yield 2 - yield 3 - - class Source: + class Inner: def __init__(self): - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - self.counter += 1 - if self.counter > 3: - raise StopAsyncIteration - return self.counter - - def double(x): - return x + x - - for iterable in source, Source: - doubles = MapAsyncIterable(iterable(), double) - - await doubles.aclose() - - with raises(StopAsyncIteration): - await anext(doubles) - - doubles = MapAsyncIterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert await anext(doubles) == 6 - - with raises(StopAsyncIteration): - await anext(doubles) - - doubles = MapAsyncIterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 + self.closed = False - # Throw error - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError("ouch")) - - assert str(exc_info.value) == "ouch" - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - # no more exceptions should be thrown - if is_pypy: - # need to investigate why this is needed with PyPy - await doubles.aclose() # pragma: no cover - await doubles.athrow(RuntimeError("no more ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - - await doubles.aclose() - - doubles = MapAsyncIterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - try: - raise ValueError("bad") - except ValueError: - tb = sys.exc_info()[2] - - # Throw error - with raises(ValueError): - await doubles.athrow(ValueError, None, tb) - - await sleep(0) - - @mark.asyncio - async def stops_async_iteration_on_close(): - async def source(): - yield 1 - await Event().wait() # Block forever - yield 2 # pragma: no cover - yield 3 # pragma: no cover - - singles = source() - doubles = MapAsyncIterable(singles, lambda x: x * 2) - - result = await anext(doubles) - assert result == 2 - - # Make sure it is blocked - doubles_future = ensure_future(anext(doubles)) - await sleep(0.05) - assert not doubles_future.done() - - # Unblock and watch StopAsyncIteration propagate - await doubles.aclose() - await sleep(0.05) - assert doubles_future.done() - assert isinstance(doubles_future.exception(), StopAsyncIteration) - - with raises(StopAsyncIteration): - await anext(singles) - - @mark.asyncio - async def can_unset_closed_state_of_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __init__(self): - self.is_closed = False + async def aclose(self): + self.closed = True def __aiter__(self): return self async def __anext__(self): - if self.is_closed: - raise StopAsyncIteration - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert not iterable.is_closed - await doubles.aclose() - assert iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert doubles.is_closed + return 1 - iterable.is_closed = False - doubles.is_closed = False - assert not doubles.is_closed + async def callback(v): + raise RuntimeError() - assert await anext(doubles) == 6 - assert not doubles.is_closed - assert not iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert not doubles.is_closed - assert not iterable.is_closed + inner = Inner() + outer = map_async_iterable(inner, callback) + with raises(RuntimeError): + async for _ in outer: + pass + assert inner.closed @mark.asyncio - async def can_cancel_async_iterable_while_waiting(): - class Iterable: - def __init__(self): - self.is_closed = False - self.value = 1 + async def test_inner_exit_on_callback_err(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when the callback errors and the outer iterator aborts. + """ - def __aiter__(self): - return self - - async def __anext__(self): - try: - await sleep(0.5) - return self.value # pragma: no cover - except CancelledError: - self.value = -1 - raise + inner_exit = False - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) # pragma: no cover exit - cancelled = False - - async def iterator_task(): - nonlocal cancelled + async def inner(): + nonlocal inner_exit try: - async for _ in doubles: - assert False # pragma: no cover - except CancelledError: - cancelled = True - - task = ensure_future(iterator_task()) - await sleep(0.05) - assert not cancelled - assert not doubles.is_closed - assert iterable.value == 1 - assert not iterable.is_closed - task.cancel() - await sleep(0.05) - assert cancelled - assert iterable.value == -1 - assert doubles.is_closed - assert iterable.is_closed + while True: + yield 1 + except GeneratorExit: + inner_exit = True + + async def callback(v): + raise RuntimeError + + outer = map_async_iterable(inner(), callback) + with raises(RuntimeError): + async for _ in outer: + pass + assert inner_exit diff --git a/tests/execution/test_subscribe.py b/tests/execution/test_subscribe.py index d10edce6..b3364c9a 100644 --- a/tests/execution/test_subscribe.py +++ b/tests/execution/test_subscribe.py @@ -1187,3 +1187,52 @@ async def resolve_message(message, _info): assert isinstance(subscription, AsyncIterator) assert await anext(subscription) == ({"newMessage": "Hello"}, None) + + @mark.asyncio + async def custom_async_iterator(): + class CustomAsyncIterator: + def __init__(self, events): + self.events = events + + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(0) + if not self.events: + raise StopAsyncIteration + return self.events.pop(0) + + def generate_messages(_obj, _info): + return CustomAsyncIterator(["Hello", "Dolly"]) + + async def resolve_message(message, _info): + await asyncio.sleep(0) + return message + + schema = GraphQLSchema( + query=QueryType, + subscription=GraphQLObjectType( + "Subscription", + { + "newMessage": GraphQLField( + GraphQLString, + resolve=resolve_message, + subscribe=generate_messages, + ) + }, + ), + ) + + document = parse("subscription { newMessage }") + subscription = subscribe(schema, document) + assert isinstance(subscription, AsyncIterator) + + msgs = [] + async for result in subscription: + assert result.errors is None + assert result.data is not None + msgs.append(result.data["newMessage"]) + assert msgs == ["Hello", "Dolly"] + if hasattr(subscription, "aclose"): + await subscription.aclose()