Skip to content

Merge aiohttp-graphql #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions graphql_server/aiohttp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .graphqlview import GraphQLView

__all__ = ["GraphQLView"]
217 changes: 217 additions & 0 deletions graphql_server/aiohttp/graphqlview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import copy
from collections.abc import MutableMapping
from functools import partial

from aiohttp import web
from graphql import GraphQLError
from graphql.type.schema import GraphQLSchema

from graphql_server import (
HttpQueryError,
encode_execution_results,
format_error_default,
json_encode,
load_json_body,
run_http_query,
)

from .render_graphiql import render_graphiql


class GraphQLView:
schema = None
root_value = None
context = None
pretty = False
graphiql = False
graphiql_version = None
graphiql_template = None
middleware = None
batch = False
jinja_env = None
max_age = 86400
enable_async = False
subscriptions = None

accepted_methods = ["GET", "POST", "PUT", "DELETE"]

format_error = staticmethod(format_error_default)
encode = staticmethod(json_encode)

def __init__(self, **kwargs):
super(GraphQLView, self).__init__()
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self, request):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

# This method can be static
async def parse_body(self, request):
content_type = request.content_type
# request.text() is the aiohttp equivalent to
# request.body.decode("utf8")
if content_type == "application/graphql":
r_text = await request.text()
return {"query": r_text}

if content_type == "application/json":
text = await request.text()
return load_json_body(text)

if content_type in (
"application/x-www-form-urlencoded",
"multipart/form-data",
):
# TODO: seems like a multidict would be more appropriate
# than casting it and de-duping variables. Alas, it's what
# graphql-python wants.
return dict(await request.post())

return {}

def render_graphiql(self, params, result):
return render_graphiql(
jinja_env=self.jinja_env,
params=params,
result=result,
graphiql_version=self.graphiql_version,
graphiql_template=self.graphiql_template,
subscriptions=self.subscriptions,
)

# TODO:
# use this method to replace flask and sanic
# checks as this is equivalent to `should_display_graphiql` and
# `request_wants_html` methods.
def is_graphiql(self, request):
return all(
[
self.graphiql,
request.method.lower() == "get",
"raw" not in request.query,
any(
[
"text/html" in request.headers.get("accept", {}),
"*/*" in request.headers.get("accept", {}),
]
),
]
)

# TODO: Same stuff as above method.
def is_pretty(self, request):
return any(
[self.pretty, self.is_graphiql(request), request.query.get("pretty")]
)

async def __call__(self, request):
try:
data = await self.parse_body(request)
request_method = request.method.lower()
is_graphiql = self.is_graphiql(request)
is_pretty = self.is_pretty(request)

# TODO: way better than if-else so better
# implement this too on flask and sanic
if request_method == "options":
return self.process_preflight(request)

execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.query,
batch_enabled=self.batch,
catch=is_graphiql,
# Execute options
run_sync=not self.enable_async,
root_value=self.get_root_value(),
context_value=self.get_context(request),
middleware=self.get_middleware(),
)

exec_res = (
[await ex for ex in execution_results]
if self.enable_async
else execution_results
)
result, status_code = encode_execution_results(
exec_res,
is_batch=isinstance(data, list),
format_error=self.format_error,
encode=partial(self.encode, pretty=is_pretty), # noqa: ignore
)

if is_graphiql:
return await self.render_graphiql(params=all_params[0], result=result)

return web.Response(
text=result, status=status_code, content_type="application/json",
)

except HttpQueryError as err:
parsed_error = GraphQLError(err.message)
return web.Response(
body=self.encode(dict(errors=[self.format_error(parsed_error)])),
status=err.status_code,
headers=err.headers,
content_type="application/json",
)

def process_preflight(self, request):
"""
Preflight request support for apollo-client
https://www.w3.org/TR/cors/#resource-preflight-requests
"""
headers = request.headers
origin = headers.get("Origin", "")
method = headers.get("Access-Control-Request-Method", "").upper()

if method and method in self.accepted_methods:
return web.Response(
status=200,
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": ", ".join(self.accepted_methods),
"Access-Control-Max-Age": str(self.max_age),
},
)
return web.Response(status=400)

@classmethod
def attach(cls, app, *, route_path="/graphql", route_name="graphql", **kwargs):
view = cls(**kwargs)
app.router.add_route("*", route_path, _asyncify(view), name=route_name)


def _asyncify(handler):
"""Return an async version of the given handler.

This is mainly here because ``aiohttp`` can't infer the async definition of
:py:meth:`.GraphQLView.__call__` and raises a :py:class:`DeprecationWarning`
in tests. Wrapping it into an async function avoids the noisy warning.
"""

async def _dispatch(request):
return await handler(request)

return _dispatch
Loading