Skip to content

chore: asynchronously gather execution results #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions graphql_server/aiohttp/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from aiohttp import web
from graphql import ExecutionResult, GraphQLError, specified_rules
from graphql import GraphQLError, specified_rules
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema

from graphql_server import (
Expand All @@ -22,6 +24,7 @@
GraphiQLOptions,
render_graphiql_async,
)
from graphql_server.utils import wrap_in_async


class GraphQLView:
Expand Down Expand Up @@ -166,10 +169,14 @@ async def __call__(self, request):
)

exec_res = (
[
ex if ex is None or isinstance(ex, ExecutionResult) else await ex
for ex in execution_results
]
await asyncio.gather(
*(
ex
if ex is not None and is_awaitable(ex)
else wrap_in_async(lambda: ex)()
for ex in execution_results
)
)
if self.enable_async
else execution_results
)
Expand Down
17 changes: 12 additions & 5 deletions graphql_server/quart/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from graphql import ExecutionResult, specified_rules
from graphql import specified_rules
from graphql.error import GraphQLError
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema
from quart import Response, render_template_string, request
from quart.views import View
Expand All @@ -24,6 +26,7 @@
GraphiQLOptions,
render_graphiql_sync,
)
from graphql_server.utils import wrap_in_async


class GraphQLView(View):
Expand Down Expand Up @@ -113,10 +116,14 @@ async def dispatch_request(self):
execution_context_class=self.get_execution_context_class(),
)
exec_res = (
[
ex if ex is None or isinstance(ex, ExecutionResult) else await ex
for ex in execution_results
]
await asyncio.gather(
*(
ex
if ex is not None and is_awaitable(ex)
else wrap_in_async(lambda: ex)()
for ex in execution_results
)
)
if self.enable_async
else execution_results
)
Expand Down
19 changes: 12 additions & 7 deletions graphql_server/sanic/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import copy
from cgi import parse_header
from collections.abc import MutableMapping
from functools import partial
from typing import List

from graphql import ExecutionResult, GraphQLError, specified_rules
from graphql import GraphQLError, specified_rules
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema
from sanic.response import HTTPResponse, html
from sanic.views import HTTPMethodView
Expand All @@ -24,6 +26,7 @@
GraphiQLOptions,
render_graphiql_async,
)
from graphql_server.utils import wrap_in_async


class GraphQLView(HTTPMethodView):
Expand Down Expand Up @@ -119,12 +122,14 @@ async def __handle_request(self, request, *args, **kwargs):
execution_context_class=self.get_execution_context_class(),
)
exec_res = (
[
ex
if ex is None or isinstance(ex, ExecutionResult)
else await ex
for ex in execution_results
]
await asyncio.gather(
*(
ex
if ex is not None and is_awaitable(ex)
else wrap_in_async(lambda: ex)()
for ex in execution_results
)
)
if self.enable_async
else execution_results
)
Expand Down
25 changes: 25 additions & 0 deletions graphql_server/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import sys
from typing import Awaitable, Callable, TypeVar

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


__all__ = ["wrap_in_async"]

P = ParamSpec("P")
R = TypeVar("R")


def wrap_in_async(f: Callable[P, R]) -> Callable[P, Awaitable[R]]:
"""Convert a sync callable (normal def or lambda) to a coroutine (async def).

This is similar to asyncio.coroutine which was deprecated in Python 3.8.
"""

async def f_async(*args: P.args, **kwargs: P.kwargs) -> R:
return f(*args, **kwargs)

return f_async
4 changes: 2 additions & 2 deletions tests/quart/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from tests.quart.schema import Schema


def create_app(path="/graphql", **kwargs):
def create_app(path="/graphql", schema=Schema, **kwargs):
server = Quart(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
path, view_func=GraphQLView.as_view("graphql", schema=schema, **kwargs)
)
return server

Expand Down
50 changes: 49 additions & 1 deletion tests/quart/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from graphql.type.definition import (
GraphQLArgument,
GraphQLField,
Expand All @@ -12,6 +14,7 @@ def resolve_raises(*_):
raise Exception("Throws!")


# Sync schema
QueryRootType = GraphQLObjectType(
name="QueryRoot",
fields={
Expand All @@ -36,7 +39,7 @@ def resolve_raises(*_):
"test": GraphQLField(
type_=GraphQLString,
args={"who": GraphQLArgument(GraphQLString)},
resolve=lambda obj, info, who="World": "Hello %s" % who,
resolve=lambda obj, info, who="World": f"Hello {who}",
),
},
)
Expand All @@ -49,3 +52,48 @@ def resolve_raises(*_):
)

Schema = GraphQLSchema(QueryRootType, MutationRootType)


# Schema with async methods
async def resolver_field_async_1(_obj, info):
await asyncio.sleep(0.001)
return "hey"


async def resolver_field_async_2(_obj, info):
await asyncio.sleep(0.003)
return "hey2"


def resolver_field_sync(_obj, info):
return "hey3"


AsyncQueryType = GraphQLObjectType(
name="AsyncQueryType",
fields={
"a": GraphQLField(GraphQLString, resolve=resolver_field_async_1),
"b": GraphQLField(GraphQLString, resolve=resolver_field_async_2),
"c": GraphQLField(GraphQLString, resolve=resolver_field_sync),
},
)


def resolver_field_sync_1(_obj, info):
return "synced_one"


def resolver_field_sync_2(_obj, info):
return "synced_two"


SyncQueryType = GraphQLObjectType(
"SyncQueryType",
{
"a": GraphQLField(GraphQLString, resolve=resolver_field_sync_1),
"b": GraphQLField(GraphQLString, resolve=resolver_field_sync_2),
},
)

AsyncSchema = GraphQLSchema(AsyncQueryType)
SyncSchema = GraphQLSchema(SyncQueryType)
15 changes: 15 additions & 0 deletions tests/quart/test_graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ..utils import RepeatExecutionContext
from .app import create_app
from .schema import AsyncSchema


@pytest.fixture
Expand Down Expand Up @@ -736,6 +737,20 @@ async def test_batch_allows_post_with_operation_name(
]


@pytest.mark.asyncio
@pytest.mark.parametrize("app", [create_app(schema=AsyncSchema, enable_async=True)])
async def test_async_schema(app, client):
response = await execute_client(
app,
client,
query="{a,b,c}",
)

assert response.status_code == 200
result = await response.get_data(as_text=True)
assert response_json(result) == {"data": {"a": "hey", "b": "hey2", "c": "hey3"}}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"app", [create_app(execution_context_class=RepeatExecutionContext)]
Expand Down