From 20126134466c7235dfcff2e5db6a4c285699123f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 21 Apr 2023 14:29:28 +0000 Subject: [PATCH 1/7] Rewrite MapAsyncIterable using async generator semantics iterate over and aclose() the iterator Mapping method must be async turn MapAsyncIterable into an AsyncGenerator --- docs/conf.py | 2 +- docs/modules/execution.rst | 2 - src/graphql/__init__.py | 4 +- src/graphql/execution/__init__.py | 4 +- src/graphql/execution/execute.py | 12 +- src/graphql/execution/map_async_iterable.py | 124 +++------------- tests/execution/test_customize.py | 8 +- tests/execution/test_map_async_iterable.py | 152 ++++++++------------ 8 files changed, 92 insertions(+), 216 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 28ac1c71..55f6c781 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -151,7 +151,7 @@ Middleware asyncio.events.AbstractEventLoop graphql.execution.collect_fields.FieldsAndPatches -graphql.execution.map_async_iterable.MapAsyncIterable +graphql.execution.map_async_iterable.map_async_iterable graphql.execution.Middleware graphql.execution.execute.DeferredFragmentRecord graphql.execution.execute.ExperimentalExecuteMultipleResults diff --git a/docs/modules/execution.rst b/docs/modules/execution.rst index 82147930..5b33b03d 100644 --- a/docs/modules/execution.rst +++ b/docs/modules/execution.rst @@ -57,8 +57,6 @@ Execution .. autofunction:: create_source_event_stream -.. autoclass:: MapAsyncIterable - .. autoclass:: Middleware .. autoclass:: MiddlewareManager diff --git a/src/graphql/__init__.py b/src/graphql/__init__.py index cb946aba..714d4215 100644 --- a/src/graphql/__init__.py +++ b/src/graphql/__init__.py @@ -439,7 +439,7 @@ # Subscription subscribe, create_source_event_stream, - MapAsyncIterable, + map_async_iterable, # Middleware Middleware, MiddlewareManager, @@ -707,7 +707,7 @@ "MiddlewareManager", "subscribe", "create_source_event_stream", - "MapAsyncIterable", + "map_async_iterable", "validate", "ValidationContext", "ValidationRule", diff --git a/src/graphql/execution/__init__.py b/src/graphql/execution/__init__.py index 6487c33d..54fb8b5d 100644 --- a/src/graphql/execution/__init__.py +++ b/src/graphql/execution/__init__.py @@ -32,7 +32,7 @@ FormattedIncrementalResult, Middleware, ) -from .map_async_iterable import MapAsyncIterable +from .map_async_iterable import map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -62,7 +62,7 @@ "FormattedIncrementalDeferResult", "FormattedIncrementalStreamResult", "FormattedIncrementalResult", - "MapAsyncIterable", + "map_async_iterable", "Middleware", "MiddlewareManager", "get_argument_values", diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index dd69658f..4766287c 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -70,7 +70,7 @@ ) from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields from .flatten_async_iterable import flatten_async_iterable -from .map_async_iterable import MapAsyncIterable +from .map_async_iterable import map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -1654,7 +1654,7 @@ async def callback(payload: Any) -> AsyncGenerator: await result if isawaitable(result) else result # type: ignore ) - return flatten_async_iterable(MapAsyncIterable(result_or_stream, callback)) + return flatten_async_iterable(map_async_iterable(result_or_stream, callback)) def execute_deferred_fragment( self, @@ -2319,18 +2319,20 @@ def subscribe( if isinstance(result, ExecutionResult): return result if isinstance(result, AsyncIterable): - return MapAsyncIterable(result, ensure_single_execution_result) + return map_async_iterable(result, ensure_single_execution_result) async def await_result() -> Union[AsyncIterator[ExecutionResult], ExecutionResult]: result_or_iterable = await result # type: ignore if isinstance(result_or_iterable, AsyncIterable): - return MapAsyncIterable(result_or_iterable, ensure_single_execution_result) + return map_async_iterable( + result_or_iterable, ensure_single_execution_result + ) return result_or_iterable return await_result() -def ensure_single_execution_result( +async def ensure_single_execution_result( result: Union[ ExecutionResult, InitialIncrementalExecutionResult, diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py index 84bd3f4a..786154b6 100644 --- a/src/graphql/execution/map_async_iterable.py +++ b/src/graphql/execution/map_async_iterable.py @@ -1,118 +1,26 @@ from __future__ import annotations # Python < 3.10 -from asyncio import CancelledError, Event, Task, ensure_future, wait -from concurrent.futures import FIRST_COMPLETED -from inspect import isasyncgen, isawaitable -from types import TracebackType -from typing import Any, AsyncIterable, Callable, Optional, Set, Type, Union, cast +from typing import Any, AsyncIterable, Awaitable, Callable -__all__ = ["MapAsyncIterable"] +__all__ = ["map_async_iterable"] -# noinspection PyAttributeOutsideInit -class MapAsyncIterable: +async def map_async_iterable( + iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]] +) -> None: """Map an AsyncIterable over a callback function. - Given an AsyncIterable and a callback function, return an AsyncIterator which - produces values mapped via calling the callback function. - - When the resulting AsyncIterator is closed, the underlying AsyncIterable will also - be closed. + Given an AsyncIterable and an async callback callable, return an AsyncGenerator + which produces values mapped via calling the callback. + If the inner iterator supports an `aclose()` method, it will be called when + the generator finishes or closes. """ - def __init__(self, iterable: AsyncIterable, callback: Callable) -> None: - self.iterator = iterable.__aiter__() - self.callback = callback - self._close_event = Event() - - def __aiter__(self) -> MapAsyncIterable: - """Get the iterator object.""" - return self - - async def __anext__(self) -> Any: - """Get the next value of the iterator.""" - if self.is_closed: - if not isasyncgen(self.iterator): - raise StopAsyncIteration - value = await self.iterator.__anext__() - else: - aclose = ensure_future(self._close_event.wait()) - anext = ensure_future(self.iterator.__anext__()) - - try: - pending: Set[Task] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] - except CancelledError: - # cancel underlying tasks and close - aclose.cancel() - anext.cancel() - await self.aclose() - raise # re-raise the cancellation - - for task in pending: - task.cancel() - - if aclose.done(): - raise StopAsyncIteration - - error = anext.exception() - if error: - raise error - - value = anext.result() - - result = self.callback(value) - - return await result if isawaitable(result) else result - - async def athrow( - self, - type_: Union[BaseException, Type[BaseException]], - value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, - ) -> None: - """Throw an exception into the asynchronous iterator.""" - if self.is_closed: - return - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value - - async def aclose(self) -> None: - """Close the iterator.""" - if not self.is_closed: - aclose = getattr(self.iterator, "aclose", None) - if aclose: - try: - await aclose() - except RuntimeError: - pass - self.is_closed = True - - @property - def is_closed(self) -> bool: - """Check whether the iterator is closed.""" - return self._close_event.is_set() - - @is_closed.setter - def is_closed(self, value: bool) -> None: - """Mark the iterator as closed.""" - if value: - self._close_event.set() - else: - self._close_event.clear() + aiter = iterable.__aiter__() + try: + async for element in aiter: + yield await callback(element) + finally: + if hasattr(aiter, "aclose"): + await aiter.aclose() diff --git a/tests/execution/test_customize.py b/tests/execution/test_customize.py index 3dbc6d00..5b839fc8 100644 --- a/tests/execution/test_customize.py +++ b/tests/execution/test_customize.py @@ -1,6 +1,8 @@ +from inspect import isasyncgen + from pytest import mark -from graphql.execution import ExecutionContext, MapAsyncIterable, execute, subscribe +from graphql.execution import ExecutionContext, execute, subscribe from graphql.language import parse from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString @@ -77,7 +79,7 @@ async def custom_foo(): root_value=Root(), subscribe_field_resolver=lambda root, _info: root.custom_foo(), ) - assert isinstance(subscription, MapAsyncIterable) + assert isasyncgen(subscription) assert await anext(subscription) == ( {"foo": "FooValue"}, @@ -121,6 +123,6 @@ def resolve_foo(message, _info): context_value={}, execution_context_class=TestExecutionContext, ) - assert isinstance(subscription, MapAsyncIterable) + assert isasyncgen(subscription) assert await anext(subscription) == ({"foo": "bar"}, None) diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index 6406f7dd..b465f5f5 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -4,7 +4,7 @@ from pytest import mark, raises -from graphql.execution import MapAsyncIterable +from graphql.execution import map_async_iterable is_pypy = platform.python_implementation() == "PyPy" @@ -18,6 +18,14 @@ async def anext(iterator): return await iterator.__anext__() +async def map_single(x): + return x + + +async def map_doubles(x): + return x + x + + def describe_map_async_iterable(): @mark.asyncio async def maps_over_async_generator(): @@ -26,7 +34,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -48,7 +56,7 @@ async def __anext__(self): except IndexError: raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = map_async_iterable(Iterable(), map_doubles) values = [value async for value in doubles] @@ -62,7 +70,7 @@ async def source(): yield 2 yield 3 - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) values = [value async for value in doubles] @@ -78,7 +86,7 @@ async def source(): async def double(x): return x + x - doubles = MapAsyncIterable(source(), double) + doubles = map_async_iterable(source(), double) values = [value async for value in doubles] @@ -91,7 +99,7 @@ async def source(): yield 2 yield 3 # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -119,7 +127,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = map_async_iterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -133,8 +141,9 @@ async def __anext__(self): with raises(StopAsyncIteration): await anext(doubles) + # async iterators must not yield after aclose() is called @mark.asyncio - async def passes_through_early_return_from_async_values(): + async def ignored_generator_exit(): async def source(): try: yield 1 @@ -142,20 +151,16 @@ async def source(): yield 3 # pragma: no cover finally: yield "Done" - yield "Last" + yield "Last" # pragma: no cover - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 - # Early return - await doubles.aclose() - - # Subsequent next calls may yield from finally block - assert await anext(doubles) == "LastLast" - with raises(GeneratorExit): - assert await anext(doubles) + with raises(RuntimeError) as exc_info: + await doubles.aclose() + assert str(exc_info.value) == "async generator ignored GeneratorExit" @mark.asyncio async def allows_throwing_errors_through_async_iterable(): @@ -171,7 +176,7 @@ async def __anext__(self): except IndexError: # pragma: no cover raise StopAsyncIteration - doubles = MapAsyncIterable(Iterable(), lambda x: x + x) + doubles = map_async_iterable(Iterable(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -197,7 +202,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = map_async_iterable(Iterable(), map_single) assert await anext(one) == 1 @@ -223,7 +228,7 @@ def __aiter__(self): async def __anext__(self): return 1 - one = MapAsyncIterable(Iterable(), lambda x: x) + one = map_async_iterable(Iterable(), map_single) assert await anext(one) == 1 @@ -250,18 +255,14 @@ async def source(): except Exception as e: yield e - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == 2 assert await anext(doubles) == 4 # Throw error - await doubles.athrow(RuntimeError("ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) + with raises(RuntimeError): + await doubles.athrow(RuntimeError("ouch")) @mark.asyncio async def does_not_normally_map_over_thrown_errors(): @@ -269,7 +270,7 @@ async def source(): yield "Hello" raise RuntimeError("Goodbye") - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -283,7 +284,7 @@ async def does_not_normally_map_over_externally_thrown_errors(): async def source(): yield "Hello" - doubles = MapAsyncIterable(source(), lambda x: x + x) + doubles = map_async_iterable(source(), map_doubles) assert await anext(doubles) == "HelloHello" @@ -312,18 +313,18 @@ async def __anext__(self): raise StopAsyncIteration return self.counter - def double(x): + async def double(x): return x + x for iterable in source, Source: - doubles = MapAsyncIterable(iterable(), double) + doubles = map_async_iterable(iterable(), double) await doubles.aclose() with raises(StopAsyncIteration): await anext(doubles) - doubles = MapAsyncIterable(iterable(), double) + doubles = map_async_iterable(iterable(), double) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -332,7 +333,7 @@ def double(x): with raises(StopAsyncIteration): await anext(doubles) - doubles = MapAsyncIterable(iterable(), double) + doubles = map_async_iterable(iterable(), double) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -359,7 +360,7 @@ def double(x): await doubles.aclose() - doubles = MapAsyncIterable(iterable(), double) + doubles = map_async_iterable(iterable(), double) assert await anext(doubles) == 2 assert await anext(doubles) == 4 @@ -384,7 +385,11 @@ async def source(): yield 3 # pragma: no cover singles = source() - doubles = MapAsyncIterable(singles, lambda x: x * 2) + + async def double(x): + return x * 2 + + doubles = map_async_iterable(singles, double) result = await anext(doubles) assert result == 2 @@ -394,65 +399,28 @@ async def source(): await sleep(0.05) assert not doubles_future.done() - # Unblock and watch StopAsyncIteration propagate - await doubles.aclose() - await sleep(0.05) - assert doubles_future.done() - assert isinstance(doubles_future.exception(), StopAsyncIteration) + # with python 3.8 and higher, close() cannot be used to unblock a generator. + # instead, the task should be killed. AsyncGenerators are not re-entrant. + if sys.version_info[:2] >= (3, 8): + with raises(RuntimeError): + await doubles.aclose() + doubles_future.cancel() + await sleep(0.05) + assert doubles_future.done() + with raises(CancelledError): + doubles_future.exception() + + else: + # old behaviour, where aclose() could unblock a Task + # Unblock and watch StopAsyncIteration propagate + await doubles.aclose() + await sleep(0.05) + assert doubles_future.done() + assert isinstance(doubles_future.exception(), StopAsyncIteration) with raises(StopAsyncIteration): await anext(singles) - @mark.asyncio - async def can_unset_closed_state_of_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __init__(self): - self.is_closed = False - - def __aiter__(self): - return self - - async def __anext__(self): - if self.is_closed: - raise StopAsyncIteration - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert not iterable.is_closed - await doubles.aclose() - assert iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert doubles.is_closed - - iterable.is_closed = False - doubles.is_closed = False - assert not doubles.is_closed - - assert await anext(doubles) == 6 - assert not doubles.is_closed - assert not iterable.is_closed - with raises(StopAsyncIteration): - await anext(iterable) - with raises(StopAsyncIteration): - await anext(doubles) - assert not doubles.is_closed - assert not iterable.is_closed - @mark.asyncio async def can_cancel_async_iterable_while_waiting(): class Iterable: @@ -475,7 +443,7 @@ async def aclose(self): self.is_closed = True iterable = Iterable() - doubles = MapAsyncIterable(iterable, lambda x: x + x) # pragma: no cover exit + doubles = map_async_iterable(iterable, map_doubles) # pragma: no cover exit cancelled = False async def iterator_task(): @@ -489,12 +457,10 @@ async def iterator_task(): task = ensure_future(iterator_task()) await sleep(0.05) assert not cancelled - assert not doubles.is_closed assert iterable.value == 1 assert not iterable.is_closed task.cancel() await sleep(0.05) assert cancelled assert iterable.value == -1 - assert doubles.is_closed assert iterable.is_closed From 48eb14aa90c233938570ee23516f9a2b9bc9fc42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 12 May 2023 08:30:40 +0000 Subject: [PATCH 2/7] update tests --- tests/execution/test_map_async_iterable.py | 470 ++------------------- 1 file changed, 36 insertions(+), 434 deletions(-) diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index b465f5f5..a3d9688a 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -1,466 +1,68 @@ -import platform -import sys -from asyncio import CancelledError, Event, ensure_future, sleep - from pytest import mark, raises from graphql.execution import map_async_iterable -is_pypy = platform.python_implementation() == "PyPy" - -try: # pragma: no cover - anext -except NameError: # pragma: no cover (Python < 3.10) - # noinspection PyShadowingBuiltins - async def anext(iterator): - """Return the next item from an async iterator.""" - return await iterator.__anext__() - - -async def map_single(x): - return x - - async def map_doubles(x): return x + x def describe_map_async_iterable(): @mark.asyncio - async def maps_over_async_generator(): - async def source(): - yield 1 - yield 2 - yield 3 - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert await anext(doubles) == 6 - with raises(StopAsyncIteration): - assert await anext(doubles) - - @mark.asyncio - async def maps_over_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: - raise StopAsyncIteration - - doubles = map_async_iterable(Iterable(), map_doubles) - - values = [value async for value in doubles] - - assert not items - assert values == [2, 4, 6] - - @mark.asyncio - async def compatible_with_async_for(): - async def source(): - yield 1 - yield 2 - yield 3 - - doubles = map_async_iterable(source(), map_doubles) - - values = [value async for value in doubles] - - assert values == [2, 4, 6] - - @mark.asyncio - async def maps_over_async_values_with_async_function(): - async def source(): - yield 1 - yield 2 - yield 3 - - async def double(x): - return x + x - - doubles = map_async_iterable(source(), double) - - values = [value async for value in doubles] - - assert values == [2, 4, 6] - - @mark.asyncio - async def allows_returning_early_from_mapped_async_generator(): - async def source(): - yield 1 - yield 2 - yield 3 # pragma: no cover - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Early return - await doubles.aclose() - - # Subsequent next calls - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def allows_returning_early_from_mapped_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: # pragma: no cover - raise StopAsyncIteration - - doubles = map_async_iterable(Iterable(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Early return - await doubles.aclose() - - # Subsequent next calls - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - # async iterators must not yield after aclose() is called - @mark.asyncio - async def ignored_generator_exit(): - async def source(): - try: - yield 1 - yield 2 - yield 3 # pragma: no cover - finally: - yield "Done" - yield "Last" # pragma: no cover - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - with raises(RuntimeError) as exc_info: - await doubles.aclose() - assert str(exc_info.value) == "async generator ignored GeneratorExit" - - @mark.asyncio - async def allows_throwing_errors_through_async_iterable(): - items = [1, 2, 3] - - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - try: - return items.pop(0) - except IndexError: # pragma: no cover - raise StopAsyncIteration - - doubles = map_async_iterable(Iterable(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Throw error - message = "allows throwing errors when mapping async iterable" - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError(message)) - - assert str(exc_info.value) == message - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - @mark.asyncio - async def allows_throwing_errors_with_values_through_async_iterables(): - class Iterable: - def __aiter__(self): - return self - - async def __anext__(self): - return 1 - - one = map_async_iterable(Iterable(), map_single) + async def test_inner_close_called(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when outer is closed + """ - assert await anext(one) == 1 - - # Throw error with value passed separately - try: - raise RuntimeError("Ouch") - except RuntimeError as error: - with raises(RuntimeError, match="Ouch") as exc_info: - await one.athrow(error.__class__, error) - - assert exc_info.value is error - assert exc_info.tb is error.__traceback__ + class Inner: + def __init__(self): + self.closed = False - with raises(StopAsyncIteration): - await anext(one) + async def aclose(self): + self.closed = True - @mark.asyncio - async def allows_throwing_errors_with_traceback_through_async_iterables(): - class Iterable: def __aiter__(self): return self async def __anext__(self): return 1 - one = map_async_iterable(Iterable(), map_single) - - assert await anext(one) == 1 - - # Throw error with traceback passed separately - try: - raise RuntimeError("Ouch") - except RuntimeError as error: - with raises(RuntimeError) as exc_info: - await one.athrow(error.__class__, None, error.__traceback__) - - assert exc_info.tb and error.__traceback__ - assert exc_info.tb.tb_frame is error.__traceback__.tb_frame - - with raises(StopAsyncIteration): - await anext(one) + inner = Inner() + outer = map_async_iterable(inner, map_doubles) + it = outer.__aiter__() + assert await it.__anext__() == 2 + assert not inner.closed + await outer.aclose() + assert inner.closed @mark.asyncio - async def passes_through_caught_errors_through_async_generators(): - async def source(): - try: - yield 1 - yield 2 - yield 3 # pragma: no cover - except Exception as e: - yield e - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Throw error - with raises(RuntimeError): - await doubles.athrow(RuntimeError("ouch")) - - @mark.asyncio - async def does_not_normally_map_over_thrown_errors(): - async def source(): - yield "Hello" - raise RuntimeError("Goodbye") - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == "HelloHello" - - with raises(RuntimeError) as exc_info: - await anext(doubles) + async def test_inner_close_called_on_callback_err(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when the callback errors and the outer iterator aborts. + """ - assert str(exc_info.value) == "Goodbye" - - @mark.asyncio - async def does_not_normally_map_over_externally_thrown_errors(): - async def source(): - yield "Hello" - - doubles = map_async_iterable(source(), map_doubles) - - assert await anext(doubles) == "HelloHello" - - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError("Goodbye")) - - assert str(exc_info.value) == "Goodbye" - - @mark.asyncio - async def can_use_simple_iterable_instead_of_generator(): - async def source(): - yield 1 - yield 2 - yield 3 - - class Source: + class Inner: def __init__(self): - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - self.counter += 1 - if self.counter > 3: - raise StopAsyncIteration - return self.counter - - async def double(x): - return x + x - - for iterable in source, Source: - doubles = map_async_iterable(iterable(), double) - - await doubles.aclose() + self.closed = False - with raises(StopAsyncIteration): - await anext(doubles) - - doubles = map_async_iterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - assert await anext(doubles) == 6 - - with raises(StopAsyncIteration): - await anext(doubles) - - doubles = map_async_iterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - # Throw error - with raises(RuntimeError) as exc_info: - await doubles.athrow(RuntimeError("ouch")) - - assert str(exc_info.value) == "ouch" - - with raises(StopAsyncIteration): - await anext(doubles) - with raises(StopAsyncIteration): - await anext(doubles) - - # no more exceptions should be thrown - if is_pypy: - # need to investigate why this is needed with PyPy - await doubles.aclose() # pragma: no cover - await doubles.athrow(RuntimeError("no more ouch")) - - with raises(StopAsyncIteration): - await anext(doubles) - - await doubles.aclose() - - doubles = map_async_iterable(iterable(), double) - - assert await anext(doubles) == 2 - assert await anext(doubles) == 4 - - try: - raise ValueError("bad") - except ValueError: - tb = sys.exc_info()[2] - - # Throw error - with raises(ValueError): - await doubles.athrow(ValueError, None, tb) - - await sleep(0) - - @mark.asyncio - async def stops_async_iteration_on_close(): - async def source(): - yield 1 - await Event().wait() # Block forever - yield 2 # pragma: no cover - yield 3 # pragma: no cover - - singles = source() - - async def double(x): - return x * 2 - - doubles = map_async_iterable(singles, double) - - result = await anext(doubles) - assert result == 2 - - # Make sure it is blocked - doubles_future = ensure_future(anext(doubles)) - await sleep(0.05) - assert not doubles_future.done() - - # with python 3.8 and higher, close() cannot be used to unblock a generator. - # instead, the task should be killed. AsyncGenerators are not re-entrant. - if sys.version_info[:2] >= (3, 8): - with raises(RuntimeError): - await doubles.aclose() - doubles_future.cancel() - await sleep(0.05) - assert doubles_future.done() - with raises(CancelledError): - doubles_future.exception() - - else: - # old behaviour, where aclose() could unblock a Task - # Unblock and watch StopAsyncIteration propagate - await doubles.aclose() - await sleep(0.05) - assert doubles_future.done() - assert isinstance(doubles_future.exception(), StopAsyncIteration) - - with raises(StopAsyncIteration): - await anext(singles) - - @mark.asyncio - async def can_cancel_async_iterable_while_waiting(): - class Iterable: - def __init__(self): - self.is_closed = False - self.value = 1 + async def aclose(self): + self.closed = True def __aiter__(self): return self async def __anext__(self): - try: - await sleep(0.5) - return self.value # pragma: no cover - except CancelledError: - self.value = -1 - raise - - async def aclose(self): - self.is_closed = True - - iterable = Iterable() - doubles = map_async_iterable(iterable, map_doubles) # pragma: no cover exit - cancelled = False + return 1 - async def iterator_task(): - nonlocal cancelled - try: - async for _ in doubles: - assert False # pragma: no cover - except CancelledError: - cancelled = True + async def callback(v): + raise RuntimeError - task = ensure_future(iterator_task()) - await sleep(0.05) - assert not cancelled - assert iterable.value == 1 - assert not iterable.is_closed - task.cancel() - await sleep(0.05) - assert cancelled - assert iterable.value == -1 - assert iterable.is_closed + inner = Inner() + outer = map_async_iterable(inner, callback) + it = outer.__aiter__() + assert not inner.closed + with raises(RuntimeError): + await it.__anext__() + assert inner.closed From 422644676c354d92e70a17e921dfe0df1d5c932e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 12 May 2023 08:45:15 +0000 Subject: [PATCH 3/7] test inner async generator aclose --- tests/execution/test_map_async_iterable.py | 33 +++++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/tests/execution/test_map_async_iterable.py b/tests/execution/test_map_async_iterable.py index a3d9688a..1462645a 100644 --- a/tests/execution/test_map_async_iterable.py +++ b/tests/execution/test_map_async_iterable.py @@ -57,12 +57,37 @@ async def __anext__(self): return 1 async def callback(v): - raise RuntimeError + raise RuntimeError() inner = Inner() outer = map_async_iterable(inner, callback) - it = outer.__aiter__() - assert not inner.closed with raises(RuntimeError): - await it.__anext__() + async for _ in outer: + pass assert inner.closed + + @mark.asyncio + async def test_inner_exit_on_callback_err(): + """ + Test that a custom iterator with aclose() gets an aclose() call + when the callback errors and the outer iterator aborts. + """ + + inner_exit = False + + async def inner(): + nonlocal inner_exit + try: + while True: + yield 1 + except GeneratorExit: + inner_exit = True + + async def callback(v): + raise RuntimeError + + outer = map_async_iterable(inner(), callback) + with raises(RuntimeError): + async for _ in outer: + pass + assert inner_exit From ac9d08f589eb1a475c2aa2f28e7209a9ce41f2de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 28 May 2023 14:07:27 +0000 Subject: [PATCH 4/7] Fix PyProject.toml on windows, description may not start with \r\n --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c8b6ba18..f9e7e86c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "graphql-core" version = "3.3.0a2" -description = """ +description = """\ GraphQL-core is a Python port of GraphQL.js,\ the JavaScript reference implementation for GraphQL.""" license = "MIT" From fe16e5a39100c867da66a3ccd89fc71ae12743aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 28 May 2023 14:41:09 +0000 Subject: [PATCH 5/7] Move both async generatrs into common iterators module --- src/graphql/execution/__init__.py | 2 +- src/graphql/execution/execute.py | 3 +- ...flatten_async_iterable.py => iterators.py} | 34 +++++++++++++++++-- src/graphql/execution/map_async_iterable.py | 26 -------------- .../execution/test_flatten_async_iterable.py | 2 +- 5 files changed, 35 insertions(+), 32 deletions(-) rename src/graphql/execution/{flatten_async_iterable.py => iterators.py} (51%) delete mode 100644 src/graphql/execution/map_async_iterable.py diff --git a/src/graphql/execution/__init__.py b/src/graphql/execution/__init__.py index 54fb8b5d..9bb89131 100644 --- a/src/graphql/execution/__init__.py +++ b/src/graphql/execution/__init__.py @@ -32,7 +32,7 @@ FormattedIncrementalResult, Middleware, ) -from .map_async_iterable import map_async_iterable +from .iterators import map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 4766287c..a8b52e4a 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -69,8 +69,7 @@ is_object_type, ) from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields -from .flatten_async_iterable import flatten_async_iterable -from .map_async_iterable import map_async_iterable +from .iterators import flatten_async_iterable, map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values diff --git a/src/graphql/execution/flatten_async_iterable.py b/src/graphql/execution/iterators.py similarity index 51% rename from src/graphql/execution/flatten_async_iterable.py rename to src/graphql/execution/iterators.py index 7c0e0721..9a47c2e7 100644 --- a/src/graphql/execution/flatten_async_iterable.py +++ b/src/graphql/execution/iterators.py @@ -1,4 +1,14 @@ -from typing import AsyncGenerator, AsyncIterable, TypeVar, Union +from __future__ import annotations # Python < 3.10 + +from typing import ( + Any, + AsyncGenerator, + AsyncIterable, + Awaitable, + Callable, + TypeVar, + Union, +) try: @@ -18,7 +28,7 @@ async def aclosing(thing): AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]] -__all__ = ["flatten_async_iterable"] +__all__ = ["flatten_async_iterable", "map_async_iterable"] async def flatten_async_iterable( @@ -34,3 +44,23 @@ async def flatten_async_iterable( async with aclosing(sub_iterator) as items: # type: ignore async for item in items: yield item + + +async def map_async_iterable( + iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[Any]] +) -> None: + """Map an AsyncIterable over a callback function. + + Given an AsyncIterable and an async callback callable, return an AsyncGenerator + which produces values mapped via calling the callback. + If the inner iterator supports an `aclose()` method, it will be called when + the generator finishes or closes. + """ + + aiter = iterable.__aiter__() + try: + async for element in aiter: + yield await callback(element) + finally: + if hasattr(aiter, "aclose"): + await aiter.aclose() diff --git a/src/graphql/execution/map_async_iterable.py b/src/graphql/execution/map_async_iterable.py deleted file mode 100644 index 786154b6..00000000 --- a/src/graphql/execution/map_async_iterable.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations # Python < 3.10 - -from typing import Any, AsyncIterable, Awaitable, Callable - - -__all__ = ["map_async_iterable"] - - -async def map_async_iterable( - iterable: AsyncIterable[Any], callback: Callable[[Any], Awaitable[Any]] -) -> None: - """Map an AsyncIterable over a callback function. - - Given an AsyncIterable and an async callback callable, return an AsyncGenerator - which produces values mapped via calling the callback. - If the inner iterator supports an `aclose()` method, it will be called when - the generator finishes or closes. - """ - - aiter = iterable.__aiter__() - try: - async for element in aiter: - yield await callback(element) - finally: - if hasattr(aiter, "aclose"): - await aiter.aclose() diff --git a/tests/execution/test_flatten_async_iterable.py b/tests/execution/test_flatten_async_iterable.py index de9c5499..49ead410 100644 --- a/tests/execution/test_flatten_async_iterable.py +++ b/tests/execution/test_flatten_async_iterable.py @@ -2,7 +2,7 @@ from pytest import mark, raises -from graphql.execution.flatten_async_iterable import flatten_async_iterable +from graphql.execution.iterators import flatten_async_iterable try: # pragma: no cover From 0b3208054ff5438bc195d96c5a8600d8f63d2bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 29 May 2023 15:34:09 +0000 Subject: [PATCH 6/7] Add unit tests for manual async iterator in subscriptions --- tests/execution/test_subscribe.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/execution/test_subscribe.py b/tests/execution/test_subscribe.py index d10edce6..619fce7c 100644 --- a/tests/execution/test_subscribe.py +++ b/tests/execution/test_subscribe.py @@ -1187,3 +1187,51 @@ async def resolve_message(message, _info): assert isinstance(subscription, AsyncIterator) assert await anext(subscription) == ({"newMessage": "Hello"}, None) + + @mark.asyncio + async def custom_async_iterator(): + class CustomAsyncIterator: + def __init__(self, events): + self.events = events + + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(0) + if not self.events: + raise StopAsyncIteration + return self.events.pop(0) + + def generate_messages(_obj, _info): + return CustomAsyncIterator(["Hello", "Dolly"]) + + async def resolve_message(message, _info): + await asyncio.sleep(0) + return message + + schema = GraphQLSchema( + query=QueryType, + subscription=GraphQLObjectType( + "Subscription", + { + "newMessage": GraphQLField( + GraphQLString, + resolve=resolve_message, + subscribe=generate_messages, + ) + }, + ), + ) + + document = parse("subscription { newMessage }") + subscription = subscribe(schema, document) + assert isinstance(subscription, AsyncIterator) + + msgs = [] + async for msg in subscription: + a, b = msg + assert b is None + msgs.append(a["newMessage"]) + assert msgs == ["Hello", "Dolly"] + await subscription.aclose() From 7231bf4528235175c18bda99f77161ec180a28f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 29 May 2023 15:44:04 +0000 Subject: [PATCH 7/7] Fix typing --- src/graphql/execution/iterators.py | 5 +++-- tests/execution/test_subscribe.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/graphql/execution/iterators.py b/src/graphql/execution/iterators.py index 9a47c2e7..c3479175 100644 --- a/src/graphql/execution/iterators.py +++ b/src/graphql/execution/iterators.py @@ -25,6 +25,7 @@ async def aclosing(thing): T = TypeVar("T") +V = TypeVar("V") AsyncIterableOrGenerator = Union[AsyncGenerator[T, None], AsyncIterable[T]] @@ -47,8 +48,8 @@ async def flatten_async_iterable( async def map_async_iterable( - iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[Any]] -) -> None: + iterable: AsyncIterable[T], callback: Callable[[T], Awaitable[V]] +) -> AsyncGenerator[V, None]: """Map an AsyncIterable over a callback function. Given an AsyncIterable and an async callback callable, return an AsyncGenerator diff --git a/tests/execution/test_subscribe.py b/tests/execution/test_subscribe.py index 619fce7c..b3364c9a 100644 --- a/tests/execution/test_subscribe.py +++ b/tests/execution/test_subscribe.py @@ -1229,9 +1229,10 @@ async def resolve_message(message, _info): assert isinstance(subscription, AsyncIterator) msgs = [] - async for msg in subscription: - a, b = msg - assert b is None - msgs.append(a["newMessage"]) + async for result in subscription: + assert result.errors is None + assert result.data is not None + msgs.append(result.data["newMessage"]) assert msgs == ["Hello", "Dolly"] - await subscription.aclose() + if hasattr(subscription, "aclose"): + await subscription.aclose()