1
1
"""
2
2
Provides an APIView class that is the base of all views in REST framework.
3
3
"""
4
+ import asyncio
5
+
4
6
from django .conf import settings
5
7
from django .core .exceptions import PermissionDenied
6
8
from django .db import connections , models
12
14
from django .views .generic import View
13
15
14
16
from rest_framework import exceptions , status
17
+ from rest_framework .compat import (
18
+ async_to_sync , iscoroutinefunction , sync_to_async
19
+ )
15
20
from rest_framework .request import Request
16
21
from rest_framework .response import Response
17
22
from rest_framework .schemas import DefaultSchema
@@ -328,13 +333,52 @@ def check_permissions(self, request):
328
333
Check if the request should be permitted.
329
334
Raises an appropriate exception if the request is not permitted.
330
335
"""
336
+ async_permissions , sync_permissions = [], []
331
337
for permission in self .get_permissions ():
332
- if not permission .has_permission (request , self ):
333
- self .permission_denied (
334
- request ,
335
- message = getattr (permission , 'message' , None ),
336
- code = getattr (permission , 'code' , None )
337
- )
338
+ if iscoroutinefunction (permission .has_permission ):
339
+ async_permissions .append (permission )
340
+ else :
341
+ sync_permissions .append (permission )
342
+
343
+ async def check_async ():
344
+ results = await asyncio .gather (
345
+ * (permission .has_permission (request , self ) for permission in
346
+ async_permissions ), return_exceptions = True
347
+ )
348
+
349
+ for idx in range (len (async_permissions )):
350
+ if isinstance (results [idx ], Exception ):
351
+ raise results [idx ]
352
+ elif not results [idx ]:
353
+ self .permission_denied (
354
+ request ,
355
+ message = getattr (async_permissions [idx ], "message" , None ),
356
+ code = getattr (async_permissions [idx ], "code" , None ),
357
+ )
358
+
359
+ def check_sync ():
360
+ for permission in sync_permissions :
361
+ if not permission .has_permission (request , self ):
362
+ self .permission_denied (
363
+ request ,
364
+ message = getattr (permission , 'message' , None ),
365
+ code = getattr (permission , 'code' , None )
366
+ )
367
+
368
+ if getattr (self , 'view_is_async' , False ):
369
+
370
+ async def func ():
371
+ if async_permissions :
372
+ await check_async ()
373
+ if sync_permissions :
374
+ await sync_to_async (check_sync )()
375
+
376
+ return func ()
377
+ else :
378
+ if sync_permissions :
379
+ check_sync ()
380
+ if async_permissions :
381
+ async_to_sync (check_async )
338
382
339
383
def check_object_permissions (self , request , obj ):
340
384
"""
@@ -354,21 +398,65 @@ def check_throttles(self, request):
354
398
Check if request should be throttled.
355
399
Raises an appropriate exception if the request is throttled.
356
400
"""
357
- throttle_durations = []
401
+ async_throttle_durations , sync_throttle_durations = [], []
358
402
for throttle in self .get_throttles ():
359
- if not throttle .allow_request (request , self ):
360
- throttle_durations .append (throttle .wait ())
403
+ if iscoroutinefunction (throttle .allow_request ):
404
+ async_throttle_durations .append (throttle )
405
+ else :
406
+ sync_throttle_durations .append (throttle )
407
+
408
+ async def async_throttles ():
409
+ for throttle in async_throttle_durations :
410
+ if not await throttle .allow_request (request , self ):
411
+ yield throttle .wait ()
412
+
413
+ def sync_throttles ():
414
+ for throttle in sync_throttle_durations :
415
+ if not throttle .allow_request (request , self ):
416
+ yield throttle .wait ()
417
+
418
+ if getattr (self , 'view_is_async' , False ):
361
419
362
- if throttle_durations :
363
- # Filter out `None` values which may happen in case of config / rate
364
- # changes, see #1438
365
- durations = [
366
- duration for duration in throttle_durations
367
- if duration is not None
368
- ]
420
+ async def func ():
421
+ throttle_durations = []
369
422
370
- duration = max (durations , default = None )
371
- self .throttled (request , duration )
423
+ if async_throttle_durations :
424
+ throttle_durations .extend (duration async for duration in async_throttles ())
425
+
426
+ if sync_throttle_durations :
427
+ throttle_durations .extend (duration async for duration in await sync_to_async (sync_throttles )())
428
+
429
+ if throttle_durations :
430
+ # Filter out `None` values which may happen in case of config / rate
431
+ # changes, see #1438
432
+ durations = [
433
+ duration for duration in throttle_durations
434
+ if duration is not None
435
+ ]
436
+
437
+ duration = max (durations , default = None )
438
+ self .throttled (request , duration )
439
+
440
+ return func ()
441
+ else :
442
+ throttle_durations = []
443
+
444
+ if sync_throttle_durations :
445
+ throttle_durations .extend (sync_throttles ())
446
+
447
+ if async_throttle_durations :
448
+ throttle_durations .extend (async_to_sync (async_throttles )())
449
+
450
+ if throttle_durations :
451
+ # Filter out `None` values which may happen in case of config / rate
452
+ # changes, see #1438
453
+ durations = [
454
+ duration for duration in throttle_durations
455
+ if duration is not None
456
+ ]
457
+
458
+ duration = max (durations , default = None )
459
+ self .throttled (request , duration )
372
460
373
461
def determine_version (self , request , * args , ** kwargs ):
374
462
"""
@@ -410,10 +498,20 @@ def initial(self, request, *args, **kwargs):
410
498
version , scheme = self .determine_version (request , * args , ** kwargs )
411
499
request .version , request .versioning_scheme = version , scheme
412
500
413
- # Ensure that the incoming request is permitted
414
- self .perform_authentication (request )
415
- self .check_permissions (request )
416
- self .check_throttles (request )
501
+ if getattr (self , 'view_is_async' , False ):
502
+
503
+ async def func ():
504
+ # Ensure that the incoming request is permitted
505
+ await sync_to_async (self .perform_authentication )(request )
506
+ await self .check_permissions (request )
507
+ await self .check_throttles (request )
508
+
509
+ return func ()
510
+ else :
511
+ # Ensure that the incoming request is permitted
512
+ self .perform_authentication (request )
513
+ self .check_permissions (request )
514
+ self .check_throttles (request )
417
515
418
516
def finalize_response (self , request , response , * args , ** kwargs ):
419
517
"""
@@ -469,7 +567,15 @@ def handle_exception(self, exc):
469
567
self .raise_uncaught_exception (exc )
470
568
471
569
response .exception = True
472
- return response
570
+
571
+ if getattr (self , 'view_is_async' , False ):
572
+
573
+ async def func ():
574
+ return response
575
+
576
+ return func ()
577
+ else :
578
+ return response
473
579
474
580
def raise_uncaught_exception (self , exc ):
475
581
if settings .DEBUG :
@@ -493,23 +599,49 @@ def dispatch(self, request, *args, **kwargs):
493
599
self .request = request
494
600
self .headers = self .default_response_headers # deprecate?
495
601
496
- try :
497
- self .initial (request , * args , ** kwargs )
602
+ if getattr (self , 'view_is_async' , False ):
498
603
499
- # Get the appropriate handler method
500
- if request .method .lower () in self .http_method_names :
501
- handler = getattr (self , request .method .lower (),
502
- self .http_method_not_allowed )
503
- else :
504
- handler = self .http_method_not_allowed
604
+ async def func ():
605
+
606
+ try :
607
+ await self .initial (request , * args , ** kwargs )
608
+
609
+ # Get the appropriate handler method
610
+ if request .method .lower () in self .http_method_names :
611
+ handler = getattr (self , request .method .lower (),
612
+ self .http_method_not_allowed )
613
+ else :
614
+ handler = self .http_method_not_allowed
615
+
616
+ response = await handler (request , * args , ** kwargs )
505
617
506
- response = handler (request , * args , ** kwargs )
618
+ except Exception as exc :
619
+ response = await self .handle_exception (exc )
507
620
508
- except Exception as exc :
509
- response = self .handle_exception (exc )
621
+ return self .finalize_response (request , response , * args , ** kwargs )
510
622
511
- self .response = self .finalize_response (request , response , * args , ** kwargs )
512
- return self .response
623
+ self .response = func ()
624
+
625
+ return self .response
626
+ else :
627
+ try :
628
+ self .initial (request , * args , ** kwargs )
629
+
630
+ # Get the appropriate handler method
631
+ if request .method .lower () in self .http_method_names :
632
+ handler = getattr (self , request .method .lower (),
633
+ self .http_method_not_allowed )
634
+ else :
635
+ handler = self .http_method_not_allowed
636
+
637
+ response = handler (request , * args , ** kwargs )
638
+
639
+ except Exception as exc :
640
+ response = self .handle_exception (exc )
641
+
642
+ self .response = self .finalize_response (request , response , * args , ** kwargs )
643
+
644
+ return self .response
513
645
514
646
def options (self , request , * args , ** kwargs ):
515
647
"""
@@ -518,4 +650,12 @@ def options(self, request, *args, **kwargs):
518
650
if self .metadata_class is None :
519
651
return self .http_method_not_allowed (request , * args , ** kwargs )
520
652
data = self .metadata_class ().determine_metadata (request , self )
521
- return Response (data , status = status .HTTP_200_OK )
653
+
654
+ if getattr (self , 'view_is_async' , False ):
655
+
656
+ async def func ():
657
+ return Response (data , status = status .HTTP_200_OK )
658
+
659
+ return func ()
660
+ else :
661
+ return Response (data , status = status .HTTP_200_OK )
0 commit comments