Skip to content

Commit 1dfc074

Browse files
committed
Async view implementation
1 parent 99e8b40 commit 1dfc074

File tree

5 files changed

+395
-43
lines changed

5 files changed

+395
-43
lines changed

docs/api-guide/views.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
217217
def view(request):
218218
return Response({"message": "Will not appear in schema!"})
219219

220+
# Async Views
221+
222+
When using Django 4.1 and above, REST framework allows you to work with async class and function based views.
223+
224+
For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.
225+
226+
For example:
227+
228+
class AsyncView(APIView):
229+
async def get(self, request):
230+
return Response({"message": "This is an async class based view."})
231+
232+
233+
@api_view(['GET'])
234+
async def async_view(request):
235+
return Response({"message": "This is an async function based view."})
220236

221237
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
222238
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html

rest_framework/compat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def distinct(queryset, base):
4141
uritemplate = None
4242

4343

44+
# async_to_sync is required for async view support
45+
if django.VERSION >= (4, 1):
46+
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
47+
else:
48+
async_to_sync = None
49+
sync_to_async = None
50+
51+
def iscoroutinefunction(func):
52+
return False
53+
54+
4455
# coreschema is optional
4556
try:
4657
import coreschema

rest_framework/decorators.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from django.forms.utils import pretty_name
1212

13+
from rest_framework.compat import iscoroutinefunction
1314
from rest_framework.views import APIView
1415

1516

@@ -46,8 +47,12 @@ def decorator(func):
4647
allowed_methods = set(http_method_names) | {'options'}
4748
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
4849

49-
def handler(self, *args, **kwargs):
50-
return func(*args, **kwargs)
50+
if iscoroutinefunction(func):
51+
async def handler(self, *args, **kwargs):
52+
return await func(*args, **kwargs)
53+
else:
54+
def handler(self, *args, **kwargs):
55+
return func(*args, **kwargs)
5156

5257
for method in http_method_names:
5358
setattr(WrappedAPIView, method.lower(), handler)

rest_framework/views.py

Lines changed: 177 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Provides an APIView class that is the base of all views in REST framework.
33
"""
4+
import asyncio
5+
46
from django.conf import settings
57
from django.core.exceptions import PermissionDenied
68
from django.db import connections, models
@@ -12,6 +14,9 @@
1214
from django.views.generic import View
1315

1416
from rest_framework import exceptions, status
17+
from rest_framework.compat import (
18+
async_to_sync, iscoroutinefunction, sync_to_async
19+
)
1520
from rest_framework.request import Request
1621
from rest_framework.response import Response
1722
from rest_framework.schemas import DefaultSchema
@@ -328,13 +333,52 @@ def check_permissions(self, request):
328333
Check if the request should be permitted.
329334
Raises an appropriate exception if the request is not permitted.
330335
"""
336+
async_permissions, sync_permissions = [], []
331337
for permission in self.get_permissions():
332-
if not permission.has_permission(request, self):
333-
self.permission_denied(
334-
request,
335-
message=getattr(permission, 'message', None),
336-
code=getattr(permission, 'code', None)
337-
)
338+
if iscoroutinefunction(permission.has_permission):
339+
async_permissions.append(permission)
340+
else:
341+
sync_permissions.append(permission)
342+
343+
async def check_async():
344+
results = await asyncio.gather(
345+
*(permission.has_permission(request, self) for permission in
346+
async_permissions), return_exceptions=True
347+
)
348+
349+
for idx in range(len(async_permissions)):
350+
if isinstance(results[idx], Exception):
351+
raise results[idx]
352+
elif not results[idx]:
353+
self.permission_denied(
354+
request,
355+
message=getattr(async_permissions[idx], "message", None),
356+
code=getattr(async_permissions[idx], "code", None),
357+
)
358+
359+
def check_sync():
360+
for permission in sync_permissions:
361+
if not permission.has_permission(request, self):
362+
self.permission_denied(
363+
request,
364+
message=getattr(permission, 'message', None),
365+
code=getattr(permission, 'code', None)
366+
)
367+
368+
if getattr(self, 'view_is_async', False):
369+
370+
async def func():
371+
if async_permissions:
372+
await check_async()
373+
if sync_permissions:
374+
await sync_to_async(check_sync)()
375+
376+
return func()
377+
else:
378+
if sync_permissions:
379+
check_sync()
380+
if async_permissions:
381+
async_to_sync(check_async)
338382

339383
def check_object_permissions(self, request, obj):
340384
"""
@@ -354,21 +398,65 @@ def check_throttles(self, request):
354398
Check if request should be throttled.
355399
Raises an appropriate exception if the request is throttled.
356400
"""
357-
throttle_durations = []
401+
async_throttle_durations, sync_throttle_durations = [], []
358402
for throttle in self.get_throttles():
359-
if not throttle.allow_request(request, self):
360-
throttle_durations.append(throttle.wait())
403+
if iscoroutinefunction(throttle.allow_request):
404+
async_throttle_durations.append(throttle)
405+
else:
406+
sync_throttle_durations.append(throttle)
407+
408+
async def async_throttles():
409+
for throttle in async_throttle_durations:
410+
if not await throttle.allow_request(request, self):
411+
yield throttle.wait()
412+
413+
def sync_throttles():
414+
for throttle in sync_throttle_durations:
415+
if not throttle.allow_request(request, self):
416+
yield throttle.wait()
417+
418+
if getattr(self, 'view_is_async', False):
361419

362-
if throttle_durations:
363-
# Filter out `None` values which may happen in case of config / rate
364-
# changes, see #1438
365-
durations = [
366-
duration for duration in throttle_durations
367-
if duration is not None
368-
]
420+
async def func():
421+
throttle_durations = []
369422

370-
duration = max(durations, default=None)
371-
self.throttled(request, duration)
423+
if async_throttle_durations:
424+
throttle_durations.extend(duration async for duration in async_throttles())
425+
426+
if sync_throttle_durations:
427+
throttle_durations.extend(duration async for duration in await sync_to_async(sync_throttles)())
428+
429+
if throttle_durations:
430+
# Filter out `None` values which may happen in case of config / rate
431+
# changes, see #1438
432+
durations = [
433+
duration for duration in throttle_durations
434+
if duration is not None
435+
]
436+
437+
duration = max(durations, default=None)
438+
self.throttled(request, duration)
439+
440+
return func()
441+
else:
442+
throttle_durations = []
443+
444+
if sync_throttle_durations:
445+
throttle_durations.extend(sync_throttles())
446+
447+
if async_throttle_durations:
448+
throttle_durations.extend(async_to_sync(async_throttles)())
449+
450+
if throttle_durations:
451+
# Filter out `None` values which may happen in case of config / rate
452+
# changes, see #1438
453+
durations = [
454+
duration for duration in throttle_durations
455+
if duration is not None
456+
]
457+
458+
duration = max(durations, default=None)
459+
self.throttled(request, duration)
372460

373461
def determine_version(self, request, *args, **kwargs):
374462
"""
@@ -410,10 +498,20 @@ def initial(self, request, *args, **kwargs):
410498
version, scheme = self.determine_version(request, *args, **kwargs)
411499
request.version, request.versioning_scheme = version, scheme
412500

413-
# Ensure that the incoming request is permitted
414-
self.perform_authentication(request)
415-
self.check_permissions(request)
416-
self.check_throttles(request)
501+
if getattr(self, 'view_is_async', False):
502+
503+
async def func():
504+
# Ensure that the incoming request is permitted
505+
await sync_to_async(self.perform_authentication)(request)
506+
await self.check_permissions(request)
507+
await self.check_throttles(request)
508+
509+
return func()
510+
else:
511+
# Ensure that the incoming request is permitted
512+
self.perform_authentication(request)
513+
self.check_permissions(request)
514+
self.check_throttles(request)
417515

418516
def finalize_response(self, request, response, *args, **kwargs):
419517
"""
@@ -469,7 +567,15 @@ def handle_exception(self, exc):
469567
self.raise_uncaught_exception(exc)
470568

471569
response.exception = True
472-
return response
570+
571+
if getattr(self, 'view_is_async', False):
572+
573+
async def func():
574+
return response
575+
576+
return func()
577+
else:
578+
return response
473579

474580
def raise_uncaught_exception(self, exc):
475581
if settings.DEBUG:
@@ -493,23 +599,49 @@ def dispatch(self, request, *args, **kwargs):
493599
self.request = request
494600
self.headers = self.default_response_headers # deprecate?
495601

496-
try:
497-
self.initial(request, *args, **kwargs)
602+
if getattr(self, 'view_is_async', False):
498603

499-
# Get the appropriate handler method
500-
if request.method.lower() in self.http_method_names:
501-
handler = getattr(self, request.method.lower(),
502-
self.http_method_not_allowed)
503-
else:
504-
handler = self.http_method_not_allowed
604+
async def func():
605+
606+
try:
607+
await self.initial(request, *args, **kwargs)
608+
609+
# Get the appropriate handler method
610+
if request.method.lower() in self.http_method_names:
611+
handler = getattr(self, request.method.lower(),
612+
self.http_method_not_allowed)
613+
else:
614+
handler = self.http_method_not_allowed
615+
616+
response = await handler(request, *args, **kwargs)
505617

506-
response = handler(request, *args, **kwargs)
618+
except Exception as exc:
619+
response = await self.handle_exception(exc)
507620

508-
except Exception as exc:
509-
response = self.handle_exception(exc)
621+
return self.finalize_response(request, response, *args, **kwargs)
510622

511-
self.response = self.finalize_response(request, response, *args, **kwargs)
512-
return self.response
623+
self.response = func()
624+
625+
return self.response
626+
else:
627+
try:
628+
self.initial(request, *args, **kwargs)
629+
630+
# Get the appropriate handler method
631+
if request.method.lower() in self.http_method_names:
632+
handler = getattr(self, request.method.lower(),
633+
self.http_method_not_allowed)
634+
else:
635+
handler = self.http_method_not_allowed
636+
637+
response = handler(request, *args, **kwargs)
638+
639+
except Exception as exc:
640+
response = self.handle_exception(exc)
641+
642+
self.response = self.finalize_response(request, response, *args, **kwargs)
643+
644+
return self.response
513645

514646
def options(self, request, *args, **kwargs):
515647
"""
@@ -518,4 +650,12 @@ def options(self, request, *args, **kwargs):
518650
if self.metadata_class is None:
519651
return self.http_method_not_allowed(request, *args, **kwargs)
520652
data = self.metadata_class().determine_metadata(request, self)
521-
return Response(data, status=status.HTTP_200_OK)
653+
654+
if getattr(self, 'view_is_async', False):
655+
656+
async def func():
657+
return Response(data, status=status.HTTP_200_OK)
658+
659+
return func()
660+
else:
661+
return Response(data, status=status.HTTP_200_OK)

0 commit comments

Comments
 (0)