diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index 878a291b22..07f775f71a 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation. def view(request): return Response({"message": "Will not appear in schema!"}) +# Async Views + +When using Django 4.1 and above, REST framework allows you to work with async class and function based views. + +For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async. + +For example: + + class AsyncView(APIView): + async def get(self, request): + return Response({"message": "This is an async class based view."}) + + + @api_view(['GET']) + async def async_view(request): + return Response({"message": "This is an async function based view."}) [cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html @@ -224,4 +240,3 @@ You may pass `None` in order to exclude the view from schema generation. [throttling]: throttling.md [schemas]: schemas.md [classy-drf]: http://www.cdrf.co - diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ac5cbc572a..9cb5b76f33 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -41,6 +41,14 @@ def distinct(queryset, base): uritemplate = None +# async_to_sync is required for async view support +if django.VERSION >= (4, 1): + from asgiref.sync import async_to_sync, sync_to_async +else: + async_to_sync = None + sync_to_async = None + + # coreschema is optional try: import coreschema diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09ef..1a56f7fa26 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -6,6 +6,7 @@ based views, as well as the `@action` decorator, which is used to annotate methods on viewsets that should be included by routers. """ +import asyncio import types from django.forms.utils import pretty_name @@ -46,8 +47,14 @@ def decorator(func): allowed_methods = set(http_method_names) | {'options'} WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] - def handler(self, *args, **kwargs): - return func(*args, **kwargs) + view_is_async = asyncio.iscoroutinefunction(func) + + if view_is_async: + async def handler(self, *args, **kwargs): + return await func(*args, **kwargs) + else: + def handler(self, *args, **kwargs): + return func(*args, **kwargs) for method in http_method_names: setattr(WrappedAPIView, method.lower(), handler) diff --git a/rest_framework/views.py b/rest_framework/views.py index 5b06220691..c8afa2b802 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -11,6 +11,7 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View +from rest_framework.compat import sync_to_async from rest_framework import exceptions, status from rest_framework.request import Request from rest_framework.response import Response @@ -482,9 +483,9 @@ def raise_uncaught_exception(self, exc): # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. - def dispatch(self, request, *args, **kwargs): + def sync_dispatch(self, request, *args, **kwargs): """ - `.dispatch()` is pretty much the same as Django's regular dispatch, + `.sync_dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ self.args = args @@ -511,11 +512,60 @@ def dispatch(self, request, *args, **kwargs): self.response = self.finalize_response(request, response, *args, **kwargs) return self.response + async def async_dispatch(self, request, *args, **kwargs): + """ + `.async_dispatch()` is pretty much the same as Django's regular dispatch, + except for awaiting the handler function and with extra hooks for startup, + finalize, and exception handling. + """ + self.args = args + self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request + self.headers = self.default_response_headers # deprecate? + + try: + await sync_to_async(self.initial)(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = await handler(request, *args, **kwargs) + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + return self.response + + def dispatch(self, request, *args, **kwargs): + """ + Dispatch checks if the view is async or not and uses the respective + async or sync dispatch method. + """ + if getattr(self, 'view_is_async', False): + return self.async_dispatch(request, *args, **kwargs) + else: + return self.sync_dispatch(request, *args, **kwargs) + def options(self, request, *args, **kwargs): """ Handler method for HTTP 'OPTIONS' request. """ - if self.metadata_class is None: - return self.http_method_not_allowed(request, *args, **kwargs) - data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) + def func(): + if self.metadata_class is None: + return self.http_method_not_allowed(request, *args, **kwargs) + data = self.metadata_class().determine_metadata(request, self) + return Response(data, status=status.HTTP_200_OK) + + if getattr(self, 'view_is_async', False): + async def handler(): + return await sync_to_async(func)() + else: + def handler(): + return func() + return handler() diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb38..49fdbe4760 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,8 +1,12 @@ import copy +import django +import pytest from django.test import TestCase +from django.contrib.auth.models import User from rest_framework import status +from rest_framework.compat import async_to_sync from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings @@ -22,16 +26,36 @@ def post(self, request, *args, **kwargs): return Response({'method': 'POST', 'data': request.data}) +class BasicAsyncView(APIView): + async def get(self, request, *args, **kwargs): + return Response({'method': 'GET'}) + + async def post(self, request, *args, **kwargs): + return Response({'method': 'POST', 'data': request.data}) + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': - return {'method': 'GET'} + return Response({'method': 'GET'}) elif request.method == 'POST': - return {'method': 'POST', 'data': request.data} + return Response({'method': 'POST', 'data': request.data}) elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.data} + return Response({'method': 'PUT', 'data': request.data}) elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.data} + return Response({'method': 'PATCH', 'data': request.data}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +async def basic_async_view(request): + if request.method == 'GET': + return Response({'method': 'GET'}) + elif request.method == 'POST': + return Response({'method': 'POST', 'data': request.data}) + elif request.method == 'PUT': + return Response({'method': 'PUT', 'data': request.data}) + elif request.method == 'PATCH': + return Response({'method': 'PATCH', 'data': request.data}) class ErrorView(APIView): @@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase): def setUp(self): self.view = BasicView.as_view() + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -86,6 +140,36 @@ class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): self.view = basic_view + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -96,6 +180,102 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class ClassBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = BasicAsyncView.as_view() + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class FunctionBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = basic_async_view + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class TestCustomExceptionHandler(TestCase): def setUp(self): self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER