Skip to content

Commit 68462e4

Browse files
committed
Async view implementation
1 parent 4f7e9ed commit 68462e4

File tree

9 files changed

+1030
-70
lines changed

9 files changed

+1030
-70
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ jobs:
1212
runs-on: ubuntu-20.04
1313

1414
strategy:
15+
fail-fast: false
1516
matrix:
1617
python-version:
1718
- '3.6'

docs/api-guide/views.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
217217
def view(request):
218218
return Response({"message": "Will not appear in schema!"})
219219

220+
# Async Views
221+
222+
When using Django 4.1 and above, REST framework allows you to work with async class and function based views.
223+
224+
For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.
225+
226+
For example:
227+
228+
class AsyncView(APIView):
229+
async def get(self, request):
230+
return Response({"message": "This is an async class based view."})
231+
232+
233+
@api_view(['GET'])
234+
async def async_view(request):
235+
return Response({"message": "This is an async function based view."})
220236

221237
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
222238
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html

rest_framework/compat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def distinct(queryset, base):
4141
uritemplate = None
4242

4343

44+
# async_to_sync is required for async view support
45+
if django.VERSION >= (4, 1):
46+
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
47+
else:
48+
async_to_sync = None
49+
sync_to_async = None
50+
51+
def iscoroutinefunction(func):
52+
return False
53+
54+
4455
# coreschema is optional
4556
try:
4657
import coreschema

rest_framework/decorators.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from django.forms.utils import pretty_name
1212

13+
from rest_framework.compat import iscoroutinefunction
1314
from rest_framework.views import APIView
1415

1516

@@ -46,8 +47,12 @@ def decorator(func):
4647
allowed_methods = set(http_method_names) | {'options'}
4748
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
4849

49-
def handler(self, *args, **kwargs):
50-
return func(*args, **kwargs)
50+
if iscoroutinefunction(func):
51+
async def handler(self, *args, **kwargs):
52+
return await func(*args, **kwargs)
53+
else:
54+
def handler(self, *args, **kwargs):
55+
return func(*args, **kwargs)
5156

5257
for method in http_method_names:
5358
setattr(WrappedAPIView, method.lower(), handler)

rest_framework/test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from django.test.client import Client as DjangoClient
1212
from django.test.client import ClientHandler
1313
from django.test.client import RequestFactory as DjangoRequestFactory
14+
15+
if django.VERSION >= (4, 1):
16+
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory
17+
1418
from django.utils.encoding import force_bytes
1519
from django.utils.http import urlencode
1620

@@ -240,6 +244,111 @@ def request(self, **kwargs):
240244
return request
241245

242246

247+
if django.VERSION >= (4, 1):
248+
class APIAsyncRequestFactory(DjangoAsyncRequestFactory):
249+
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
250+
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
251+
252+
def __init__(self, enforce_csrf_checks=False, **defaults):
253+
self.enforce_csrf_checks = enforce_csrf_checks
254+
self.renderer_classes = {}
255+
for cls in self.renderer_classes_list:
256+
self.renderer_classes[cls.format] = cls
257+
super().__init__(**defaults)
258+
259+
def _encode_data(self, data, format=None, content_type=None):
260+
"""
261+
Encode the data returning a two tuple of (bytes, content_type)
262+
"""
263+
264+
if data is None:
265+
return ('', content_type)
266+
267+
assert format is None or content_type is None, (
268+
'You may not set both `format` and `content_type`.'
269+
)
270+
271+
if content_type:
272+
# Content type specified explicitly, treat data as a raw bytestring
273+
ret = force_bytes(data, settings.DEFAULT_CHARSET)
274+
275+
else:
276+
format = format or self.default_format
277+
278+
assert format in self.renderer_classes, (
279+
"Invalid format '{}'. Available formats are {}. "
280+
"Set TEST_REQUEST_RENDERER_CLASSES to enable "
281+
"extra request formats.".format(
282+
format,
283+
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
284+
)
285+
)
286+
287+
# Use format and render the data into a bytestring
288+
renderer = self.renderer_classes[format]()
289+
ret = renderer.render(data)
290+
291+
# Determine the content-type header from the renderer
292+
content_type = renderer.media_type
293+
if renderer.charset:
294+
content_type = "{}; charset={}".format(
295+
content_type, renderer.charset
296+
)
297+
298+
# Coerce text to bytes if required.
299+
if isinstance(ret, str):
300+
ret = ret.encode(renderer.charset)
301+
302+
return ret, content_type
303+
304+
def get(self, path, data=None, **extra):
305+
r = {
306+
'QUERY_STRING': urlencode(data or {}, doseq=True),
307+
}
308+
if not data and '?' in path:
309+
# Fix to support old behavior where you have the arguments in the
310+
# url. See #1461.
311+
query_string = force_bytes(path.split('?')[1])
312+
query_string = query_string.decode('iso-8859-1')
313+
r['QUERY_STRING'] = query_string
314+
r.update(extra)
315+
return self.generic('GET', path, **r)
316+
317+
def post(self, path, data=None, format=None, content_type=None, **extra):
318+
data, content_type = self._encode_data(data, format, content_type)
319+
return self.generic('POST', path, data, content_type, **extra)
320+
321+
def put(self, path, data=None, format=None, content_type=None, **extra):
322+
data, content_type = self._encode_data(data, format, content_type)
323+
return self.generic('PUT', path, data, content_type, **extra)
324+
325+
def patch(self, path, data=None, format=None, content_type=None, **extra):
326+
data, content_type = self._encode_data(data, format, content_type)
327+
return self.generic('PATCH', path, data, content_type, **extra)
328+
329+
def delete(self, path, data=None, format=None, content_type=None, **extra):
330+
data, content_type = self._encode_data(data, format, content_type)
331+
return self.generic('DELETE', path, data, content_type, **extra)
332+
333+
def options(self, path, data=None, format=None, content_type=None, **extra):
334+
data, content_type = self._encode_data(data, format, content_type)
335+
return self.generic('OPTIONS', path, data, content_type, **extra)
336+
337+
def generic(self, method, path, data='',
338+
content_type='application/octet-stream', secure=False, **extra):
339+
# Include the CONTENT_TYPE, regardless of whether or not data is empty.
340+
if content_type is not None:
341+
extra['CONTENT_TYPE'] = str(content_type)
342+
343+
return super().generic(
344+
method, path, data, content_type, secure, **extra)
345+
346+
def request(self, **kwargs):
347+
request = super().request(**kwargs)
348+
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
349+
return request
350+
351+
243352
class ForceAuthClientHandler(ClientHandler):
244353
"""
245354
A patched version of ClientHandler that can enforce authentication

rest_framework/throttling.py

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from django.core.cache import cache as default_cache
77
from django.core.exceptions import ImproperlyConfigured
88

9+
from rest_framework.compat import (
10+
async_to_sync, iscoroutinefunction, sync_to_async
11+
)
912
from rest_framework.settings import api_settings
1013

1114

@@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
6467
cache_format = 'throttle_%(scope)s_%(ident)s'
6568
scope = None
6669
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
70+
sync_capable = True
71+
async_capable = True
6772

6873
def __init__(self):
6974
if not getattr(self, 'rate', None):
@@ -113,23 +118,52 @@ def allow_request(self, request, view):
113118
On success calls `throttle_success`.
114119
On failure calls `throttle_failure`.
115120
"""
116-
if self.rate is None:
117-
return True
118-
119-
self.key = self.get_cache_key(request, view)
120-
if self.key is None:
121-
return True
122-
123-
self.history = self.cache.get(self.key, [])
124-
self.now = self.timer()
125-
126-
# Drop any requests from the history which have now passed the
127-
# throttle duration
128-
while self.history and self.history[-1] <= self.now - self.duration:
129-
self.history.pop()
130-
if len(self.history) >= self.num_requests:
131-
return self.throttle_failure()
132-
return self.throttle_success()
121+
if getattr(view, 'view_is_async', False):
122+
123+
async def func():
124+
if self.rate is None:
125+
return True
126+
127+
self.key = self.get_cache_key(request, view)
128+
if self.key is None:
129+
return True
130+
131+
self.history = self.cache.get(self.key, [])
132+
if iscoroutinefunction(self.timer):
133+
self.now = await self.timer()
134+
else:
135+
self.now = await sync_to_async(self.timer)()
136+
137+
# Drop any requests from the history which have now passed the
138+
# throttle duration
139+
while self.history and self.history[-1] <= self.now - self.duration:
140+
self.history.pop()
141+
if len(self.history) >= self.num_requests:
142+
return self.throttle_failure()
143+
return self.throttle_success()
144+
145+
return func()
146+
else:
147+
if self.rate is None:
148+
return True
149+
150+
self.key = self.get_cache_key(request, view)
151+
if self.key is None:
152+
return True
153+
154+
self.history = self.cache.get(self.key, [])
155+
if iscoroutinefunction(self.timer):
156+
self.now = async_to_sync(self.timer)()
157+
else:
158+
self.now = self.timer()
159+
160+
# Drop any requests from the history which have now passed the
161+
# throttle duration
162+
while self.history and self.history[-1] <= self.now - self.duration:
163+
self.history.pop()
164+
if len(self.history) >= self.num_requests:
165+
return self.throttle_failure()
166+
return self.throttle_success()
133167

134168
def throttle_success(self):
135169
"""
@@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
210244
user id of the request, and the scope of the view being accessed.
211245
"""
212246
scope_attr = 'throttle_scope'
247+
sync_capable = True
248+
async_capable = True
213249

214250
def __init__(self):
215251
# Override the usual SimpleRateThrottle, because we can't determine
@@ -220,17 +256,34 @@ def allow_request(self, request, view):
220256
# We can only determine the scope once we're called by the view.
221257
self.scope = getattr(view, self.scope_attr, None)
222258

223-
# If a view does not have a `throttle_scope` always allow the request
224-
if not self.scope:
225-
return True
259+
if getattr(view, 'view_is_async', False):
226260

227-
# Determine the allowed request rate as we normally would during
228-
# the `__init__` call.
229-
self.rate = self.get_rate()
230-
self.num_requests, self.duration = self.parse_rate(self.rate)
261+
async def func(allow_request):
262+
# If a view does not have a `throttle_scope` always allow the request
263+
if not self.scope:
264+
return True
265+
266+
# Determine the allowed request rate as we normally would during
267+
# the `__init__` call.
268+
self.rate = self.get_rate()
269+
self.num_requests, self.duration = self.parse_rate(self.rate)
270+
271+
# We can now proceed as normal.
272+
return await allow_request(request, view)
273+
274+
return func(super().allow_request)
275+
else:
276+
# If a view does not have a `throttle_scope` always allow the request
277+
if not self.scope:
278+
return True
279+
280+
# Determine the allowed request rate as we normally would during
281+
# the `__init__` call.
282+
self.rate = self.get_rate()
283+
self.num_requests, self.duration = self.parse_rate(self.rate)
231284

232-
# We can now proceed as normal.
233-
return super().allow_request(request, view)
285+
# We can now proceed as normal.
286+
return super().allow_request(request, view)
234287

235288
def get_cache_key(self, request, view):
236289
"""

0 commit comments

Comments
 (0)