Skip to content

Commit e2af695

Browse files
authored
fix: drf fix (#492)
1 parent 744e56b commit e2af695

File tree

2 files changed

+60
-60
lines changed

2 files changed

+60
-60
lines changed

tests/frontendIntegration/drf_async/polls/views.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from rest_framework import status # type: ignore
3333
from rest_framework.decorators import api_view as api_view_sync, renderer_classes # type: ignore
3434
from adrf.decorators import api_view # type: ignore
35-
from rest_framework.renderers import JSONRenderer, StaticHTMLRenderer, BaseRenderer # type: ignore
35+
from rest_framework.renderers import StaticHTMLRenderer, BaseRenderer # type: ignore
3636
from rest_framework.request import Request # type: ignore
3737
from rest_framework.response import Response # type: ignore
3838
from supertokens_python import get_all_cors_headers
@@ -81,6 +81,15 @@
8181
last_set_enable_jwt = False
8282

8383

84+
class JsonTextRenderer(BaseRenderer): # type: ignore
85+
media_type = "application/json"
86+
87+
def render(self, data, media_type=None, renderer_context=None): # type: ignore
88+
if isinstance(data, str):
89+
return data.encode("utf-8") # type: ignore
90+
return json.dumps(data).encode("utf-8")
91+
92+
8493
def custom_decorator_for_test(): # type: ignore
8594
def session_verify_custom_test(f): # type: ignore
8695
@wraps(f) # type: ignore
@@ -221,7 +230,7 @@ async def wrapped_function(request: Request, *args, **kwargs): # type: ignore
221230

222231

223232
@api_view_sync(["GET", "POST"])
224-
@renderer_classes([JSONRenderer]) # type: ignore
233+
@renderer_classes([JsonTextRenderer]) # type: ignore
225234
def try_refresh_token(_): # type: ignore
226235
return Response(
227236
{"error": "try refresh token"},
@@ -230,7 +239,7 @@ def try_refresh_token(_): # type: ignore
230239

231240

232241
@api_view_sync(["GET", "POST"])
233-
@renderer_classes([JSONRenderer]) # type: ignore
242+
@renderer_classes([JsonTextRenderer]) # type: ignore
234243
def unauthorised(_): # type: ignore
235244
return Response(
236245
{"error": "unauthorised"},
@@ -416,15 +425,6 @@ def send_options_api_response(): # type: ignore
416425
return Response("") # type: ignore
417426

418427

419-
class JsonTextRenderer(BaseRenderer): # type: ignore
420-
media_type = "application/json"
421-
422-
def render(self, data, media_type=None, renderer_context=None): # type: ignore
423-
if isinstance(data, dict):
424-
return json.dumps(data).encode("utf-8")
425-
return data.encode("utf-8") # type: ignore
426-
427-
428428
@api_view(["GET", "POST"])
429429
@renderer_classes([JsonTextRenderer]) # type: ignore
430430
async def login(request: Request): # type: ignore
@@ -480,7 +480,7 @@ async def login_218(request: Request): # type: ignore
480480

481481

482482
@api_view(["GET", "POST"])
483-
@renderer_classes([JSONRenderer]) # type: ignore
483+
@renderer_classes([JsonTextRenderer]) # type: ignore
484484
async def before_each(request: Request): # type: ignore
485485
config(True, False, None) # type: ignore
486486
if request.method == "POST": # type: ignore
@@ -491,7 +491,7 @@ async def before_each(request: Request): # type: ignore
491491

492492

493493
@api_view(["GET", "POST"])
494-
@renderer_classes([JSONRenderer]) # type: ignore
494+
@renderer_classes([JsonTextRenderer]) # type: ignore
495495
async def test_config(request: Request): # type: ignore
496496
if request.method == "POST": # type: ignore
497497
return Response("") # type: ignore
@@ -500,7 +500,7 @@ async def test_config(request: Request): # type: ignore
500500

501501

502502
@api_view(["GET", "POST"])
503-
@renderer_classes([JSONRenderer]) # type: ignore
503+
@renderer_classes([JsonTextRenderer]) # type: ignore
504504
async def multiple_interceptors(request: Request): # type: ignore
505505
if request.method == "POST": # type: ignore
506506
result_bool = (
@@ -538,7 +538,7 @@ async def update_jwt(request: Request): # type: ignore
538538

539539

540540
@api_view(["GET", "POST"])
541-
@renderer_classes([JSONRenderer]) # type: ignore
541+
@renderer_classes([JsonTextRenderer]) # type: ignore
542542
@custom_decorator_for_update_jwt_with_handle()
543543
@verify_session()
544544
async def update_jwt_with_handle(request: Request): # type: ignore
@@ -557,21 +557,21 @@ async def validate(self, payload: JSONObject, user_context: Dict[str, Any]):
557557

558558

559559
@api_view(["GET", "POST"])
560-
@renderer_classes([JSONRenderer]) # type: ignore
560+
@renderer_classes([JsonTextRenderer]) # type: ignore
561561
@verify_session(override_global_claim_validators=gcv_for_session_claim_err) # type: ignore
562562
async def session_claim_error_api(request: Request): # type: ignore
563563
return Response({}) # type: ignore
564564

565565

566566
@api_view(["GET", "POST"])
567-
@renderer_classes([JSONRenderer]) # type: ignore
567+
@renderer_classes([JsonTextRenderer]) # type: ignore
568568
async def without_body_403(request: Request): # type: ignore
569569
if request.method == "POST": # type: ignore
570570
return Response("", status=403) # type: ignore
571571

572572

573573
@api_view(["GET", "POST", "PUT", "DELETE"])
574-
@renderer_classes([JSONRenderer]) # type: ignore
574+
@renderer_classes([JsonTextRenderer]) # type: ignore
575575
async def testing(request: Request): # type: ignore
576576
if request.method in ["GET", "PUT", "POST", "DELETE"]: # type: ignore
577577
if "testing" in request.headers: # type: ignore
@@ -593,7 +593,7 @@ async def logout(request: Request): # type: ignore
593593

594594

595595
@api_view(["GET", "POST"])
596-
@renderer_classes([JSONRenderer]) # type: ignore
596+
@renderer_classes([JsonTextRenderer]) # type: ignore
597597
@verify_session()
598598
async def revoke_all(request: Request): # type: ignore
599599
if request.method: # type: ignore
@@ -616,15 +616,15 @@ def refresh_attempted_time(request: Request): # type: ignore
616616

617617

618618
@api_view(["GET", "PUT", "POST"])
619-
@renderer_classes([JSONRenderer]) # type: ignore
619+
@renderer_classes([JsonTextRenderer]) # type: ignore
620620
@custom_decorator_for_test()
621621
@verify_session()
622622
async def refresh(request: Request): # type: ignore
623623
return Response("refresh success") # type: ignore
624624

625625

626626
@api_view(["GET", "POST"])
627-
@renderer_classes([JSONRenderer]) # type: ignore
627+
@renderer_classes([JsonTextRenderer]) # type: ignore
628628
def set_anti_csrf(request: Request): # type: ignore
629629
global last_set_enable_anti_csrf
630630
data = request.data # type: ignore
@@ -663,7 +663,7 @@ def set_enable_jwt(request: Request): # type: ignore
663663

664664

665665
@api_view(["GET", "POST"])
666-
@renderer_classes([JSONRenderer]) # type: ignore
666+
@renderer_classes([JsonTextRenderer]) # type: ignore
667667
def feature_flags(request: Request): # type: ignore
668668
global last_set_enable_jwt
669669
return Response(
@@ -676,7 +676,7 @@ def feature_flags(request: Request): # type: ignore
676676

677677

678678
@api_view(["GET", "POST"])
679-
@renderer_classes([JSONRenderer]) # type: ignore
679+
@renderer_classes([JsonTextRenderer]) # type: ignore
680680
async def reinitialize(request: Request): # type: ignore
681681
global last_set_enable_jwt
682682
global last_set_enable_anti_csrf
@@ -711,7 +711,7 @@ async def get_session_called_time(request: Request): # type: ignore
711711

712712

713713
@api_view(["GET", "POST"])
714-
@renderer_classes([JSONRenderer]) # type: ignore
714+
@renderer_classes([JsonTextRenderer]) # type: ignore
715715
async def ping(request: Request): # type: ignore
716716
if request.method == "GET": # type: ignore
717717
return Response("success") # type: ignore
@@ -720,7 +720,7 @@ async def ping(request: Request): # type: ignore
720720

721721

722722
@api_view(["GET", "POST"])
723-
@renderer_classes([JSONRenderer]) # type: ignore
723+
@renderer_classes([JsonTextRenderer]) # type: ignore
724724
async def test_header(request: Request): # type: ignore
725725
if request.method == "GET": # type: ignore
726726
success_info = request.headers.get("st-custom-header") # type: ignore
@@ -730,7 +730,7 @@ async def test_header(request: Request): # type: ignore
730730

731731

732732
@api_view(["GET", "POST"])
733-
@renderer_classes([JSONRenderer]) # type: ignore
733+
@renderer_classes([JsonTextRenderer]) # type: ignore
734734
async def check_device_info(request: Request): # type: ignore
735735
if request.method == "GET": # type: ignore
736736
sdk_name = request.headers.get("supertokens-sdk-name") # type: ignore
@@ -745,14 +745,14 @@ async def check_device_info(request: Request): # type: ignore
745745

746746

747747
@api_view(["GET", "POST"])
748-
@renderer_classes([JSONRenderer]) # type: ignore
748+
@renderer_classes([JsonTextRenderer]) # type: ignore
749749
async def check_rid(request: Request): # type: ignore
750750
rid = request.headers.get("rid") # type: ignore
751751
return Response("fail" if rid is None else "success") # type: ignore
752752

753753

754754
@api_view(["GET", "POST"])
755-
@renderer_classes([JSONRenderer]) # type: ignore
755+
@renderer_classes([JsonTextRenderer]) # type: ignore
756756
async def check_allow_credentials(request: Request): # type: ignore
757757
if request.method == "GET": # type: ignore
758758
return Response("allow-credentials" in request.headers) # type: ignore
@@ -761,7 +761,7 @@ async def check_allow_credentials(request: Request): # type: ignore
761761

762762

763763
@api_view(["GET", "POST", "OPTIONS"])
764-
@renderer_classes([JSONRenderer]) # type: ignore
764+
@renderer_classes([JsonTextRenderer]) # type: ignore
765765
async def test_error(request: Request): # type: ignore
766766
if request.method == "OPTIONS": # type: ignore
767767
return send_options_api_response() # type: ignore

0 commit comments

Comments
 (0)