Skip to content

Commit 979bb24

Browse files
committed
Async view implementation
1 parent 4f7e9ed commit 979bb24

File tree

9 files changed

+935
-71
lines changed

9 files changed

+935
-71
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ jobs:
1212
runs-on: ubuntu-20.04
1313

1414
strategy:
15+
fail-fast: false
1516
matrix:
1617
python-version:
1718
- '3.6'

docs/api-guide/views.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
217217
def view(request):
218218
return Response({"message": "Will not appear in schema!"})
219219

220+
# Async Views
221+
222+
When using Django 4.1 and above, REST framework allows you to work with async class and function based views.
223+
224+
For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.
225+
226+
For example:
227+
228+
class AsyncView(APIView):
229+
async def get(self, request):
230+
return Response({"message": "This is an async class based view."})
231+
232+
233+
@api_view(['GET'])
234+
async def async_view(request):
235+
return Response({"message": "This is an async function based view."})
220236

221237
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
222238
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html

rest_framework/compat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def distinct(queryset, base):
4141
uritemplate = None
4242

4343

44+
# async_to_sync is required for async view support
45+
if django.VERSION >= (4, 1):
46+
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
47+
else:
48+
async_to_sync = None
49+
sync_to_async = None
50+
51+
def iscoroutinefunction(func):
52+
return False
53+
54+
4455
# coreschema is optional
4556
try:
4657
import coreschema

rest_framework/decorators.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from django.forms.utils import pretty_name
1212

13+
from rest_framework.compat import iscoroutinefunction
1314
from rest_framework.views import APIView
1415

1516

@@ -46,8 +47,12 @@ def decorator(func):
4647
allowed_methods = set(http_method_names) | {'options'}
4748
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
4849

49-
def handler(self, *args, **kwargs):
50-
return func(*args, **kwargs)
50+
if iscoroutinefunction(func):
51+
async def handler(self, *args, **kwargs):
52+
return await func(*args, **kwargs)
53+
else:
54+
def handler(self, *args, **kwargs):
55+
return func(*args, **kwargs)
5156

5257
for method in http_method_names:
5358
setattr(WrappedAPIView, method.lower(), handler)

rest_framework/test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from django.test.client import Client as DjangoClient
1212
from django.test.client import ClientHandler
1313
from django.test.client import RequestFactory as DjangoRequestFactory
14+
15+
if django.VERSION >= (4, 1):
16+
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory
17+
1418
from django.utils.encoding import force_bytes
1519
from django.utils.http import urlencode
1620

@@ -136,7 +140,7 @@ def CoreAPIClient(*args, **kwargs):
136140
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
137141

138142

139-
class APIRequestFactory(DjangoRequestFactory):
143+
class APIRequestFactoryMixin:
140144
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
141145
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
142146

@@ -240,6 +244,15 @@ def request(self, **kwargs):
240244
return request
241245

242246

247+
class APIRequestFactory(APIRequestFactoryMixin, DjangoRequestFactory):
248+
pass
249+
250+
251+
if django.VERSION >= (4, 1):
252+
class APIAsyncRequestFactory(APIRequestFactoryMixin, DjangoAsyncRequestFactory):
253+
pass
254+
255+
243256
class ForceAuthClientHandler(ClientHandler):
244257
"""
245258
A patched version of ClientHandler that can enforce authentication

rest_framework/throttling.py

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from django.core.cache import cache as default_cache
77
from django.core.exceptions import ImproperlyConfigured
88

9+
from rest_framework.compat import (
10+
async_to_sync, iscoroutinefunction, sync_to_async
11+
)
912
from rest_framework.settings import api_settings
1013

1114

@@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
6467
cache_format = 'throttle_%(scope)s_%(ident)s'
6568
scope = None
6669
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
70+
sync_capable = True
71+
async_capable = True
6772

6873
def __init__(self):
6974
if not getattr(self, 'rate', None):
@@ -113,23 +118,52 @@ def allow_request(self, request, view):
113118
On success calls `throttle_success`.
114119
On failure calls `throttle_failure`.
115120
"""
116-
if self.rate is None:
117-
return True
118-
119-
self.key = self.get_cache_key(request, view)
120-
if self.key is None:
121-
return True
122-
123-
self.history = self.cache.get(self.key, [])
124-
self.now = self.timer()
125-
126-
# Drop any requests from the history which have now passed the
127-
# throttle duration
128-
while self.history and self.history[-1] <= self.now - self.duration:
129-
self.history.pop()
130-
if len(self.history) >= self.num_requests:
131-
return self.throttle_failure()
132-
return self.throttle_success()
121+
if getattr(view, 'view_is_async', False):
122+
123+
async def func():
124+
if self.rate is None:
125+
return True
126+
127+
self.key = self.get_cache_key(request, view)
128+
if self.key is None:
129+
return True
130+
131+
self.history = self.cache.get(self.key, [])
132+
if iscoroutinefunction(self.timer):
133+
self.now = await self.timer()
134+
else:
135+
self.now = await sync_to_async(self.timer)()
136+
137+
# Drop any requests from the history which have now passed the
138+
# throttle duration
139+
while self.history and self.history[-1] <= self.now - self.duration:
140+
self.history.pop()
141+
if len(self.history) >= self.num_requests:
142+
return self.throttle_failure()
143+
return self.throttle_success()
144+
145+
return func()
146+
else:
147+
if self.rate is None:
148+
return True
149+
150+
self.key = self.get_cache_key(request, view)
151+
if self.key is None:
152+
return True
153+
154+
self.history = self.cache.get(self.key, [])
155+
if iscoroutinefunction(self.timer):
156+
self.now = async_to_sync(self.timer)()
157+
else:
158+
self.now = self.timer()
159+
160+
# Drop any requests from the history which have now passed the
161+
# throttle duration
162+
while self.history and self.history[-1] <= self.now - self.duration:
163+
self.history.pop()
164+
if len(self.history) >= self.num_requests:
165+
return self.throttle_failure()
166+
return self.throttle_success()
133167

134168
def throttle_success(self):
135169
"""
@@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
210244
user id of the request, and the scope of the view being accessed.
211245
"""
212246
scope_attr = 'throttle_scope'
247+
sync_capable = True
248+
async_capable = True
213249

214250
def __init__(self):
215251
# Override the usual SimpleRateThrottle, because we can't determine
@@ -220,17 +256,34 @@ def allow_request(self, request, view):
220256
# We can only determine the scope once we're called by the view.
221257
self.scope = getattr(view, self.scope_attr, None)
222258

223-
# If a view does not have a `throttle_scope` always allow the request
224-
if not self.scope:
225-
return True
259+
if getattr(view, 'view_is_async', False):
226260

227-
# Determine the allowed request rate as we normally would during
228-
# the `__init__` call.
229-
self.rate = self.get_rate()
230-
self.num_requests, self.duration = self.parse_rate(self.rate)
261+
async def func(allow_request):
262+
# If a view does not have a `throttle_scope` always allow the request
263+
if not self.scope:
264+
return True
265+
266+
# Determine the allowed request rate as we normally would during
267+
# the `__init__` call.
268+
self.rate = self.get_rate()
269+
self.num_requests, self.duration = self.parse_rate(self.rate)
270+
271+
# We can now proceed as normal.
272+
return await allow_request(request, view)
273+
274+
return func(super().allow_request)
275+
else:
276+
# If a view does not have a `throttle_scope` always allow the request
277+
if not self.scope:
278+
return True
279+
280+
# Determine the allowed request rate as we normally would during
281+
# the `__init__` call.
282+
self.rate = self.get_rate()
283+
self.num_requests, self.duration = self.parse_rate(self.rate)
231284

232-
# We can now proceed as normal.
233-
return super().allow_request(request, view)
285+
# We can now proceed as normal.
286+
return super().allow_request(request, view)
234287

235288
def get_cache_key(self, request, view):
236289
"""

0 commit comments

Comments
 (0)