Skip to content

fix: drf fix #492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions tests/frontendIntegration/drf_async/polls/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from rest_framework import status # type: ignore
from rest_framework.decorators import api_view as api_view_sync, renderer_classes # type: ignore
from adrf.decorators import api_view # type: ignore
from rest_framework.renderers import JSONRenderer, StaticHTMLRenderer, BaseRenderer # type: ignore
from rest_framework.renderers import StaticHTMLRenderer, BaseRenderer # type: ignore
from rest_framework.request import Request # type: ignore
from rest_framework.response import Response # type: ignore
from supertokens_python import get_all_cors_headers
Expand Down Expand Up @@ -81,6 +81,15 @@
last_set_enable_jwt = False


class JsonTextRenderer(BaseRenderer): # type: ignore
media_type = "application/json"

def render(self, data, media_type=None, renderer_context=None): # type: ignore
if isinstance(data, str):
return data.encode("utf-8") # type: ignore
return json.dumps(data).encode("utf-8")


def custom_decorator_for_test(): # type: ignore
def session_verify_custom_test(f): # type: ignore
@wraps(f) # type: ignore
Expand Down Expand Up @@ -221,7 +230,7 @@ async def wrapped_function(request: Request, *args, **kwargs): # type: ignore


@api_view_sync(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
def try_refresh_token(_): # type: ignore
return Response(
{"error": "try refresh token"},
Expand All @@ -230,7 +239,7 @@ def try_refresh_token(_): # type: ignore


@api_view_sync(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
def unauthorised(_): # type: ignore
return Response(
{"error": "unauthorised"},
Expand Down Expand Up @@ -416,15 +425,6 @@ def send_options_api_response(): # type: ignore
return Response("") # type: ignore


class JsonTextRenderer(BaseRenderer): # type: ignore
media_type = "application/json"

def render(self, data, media_type=None, renderer_context=None): # type: ignore
if isinstance(data, dict):
return json.dumps(data).encode("utf-8")
return data.encode("utf-8") # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JsonTextRenderer]) # type: ignore
async def login(request: Request): # type: ignore
Expand Down Expand Up @@ -480,7 +480,7 @@ async def login_218(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def before_each(request: Request): # type: ignore
config(True, False, None) # type: ignore
if request.method == "POST": # type: ignore
Expand All @@ -491,7 +491,7 @@ async def before_each(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def test_config(request: Request): # type: ignore
if request.method == "POST": # type: ignore
return Response("") # type: ignore
Expand All @@ -500,7 +500,7 @@ async def test_config(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def multiple_interceptors(request: Request): # type: ignore
if request.method == "POST": # type: ignore
result_bool = (
Expand Down Expand Up @@ -538,7 +538,7 @@ async def update_jwt(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
@custom_decorator_for_update_jwt_with_handle()
@verify_session()
async def update_jwt_with_handle(request: Request): # type: ignore
Expand All @@ -557,21 +557,21 @@ async def validate(self, payload: JSONObject, user_context: Dict[str, Any]):


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
@verify_session(override_global_claim_validators=gcv_for_session_claim_err) # type: ignore
async def session_claim_error_api(request: Request): # type: ignore
return Response({}) # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def without_body_403(request: Request): # type: ignore
if request.method == "POST": # type: ignore
return Response("", status=403) # type: ignore


@api_view(["GET", "POST", "PUT", "DELETE"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def testing(request: Request): # type: ignore
if request.method in ["GET", "PUT", "POST", "DELETE"]: # type: ignore
if "testing" in request.headers: # type: ignore
Expand All @@ -593,7 +593,7 @@ async def logout(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
@verify_session()
async def revoke_all(request: Request): # type: ignore
if request.method: # type: ignore
Expand All @@ -616,15 +616,15 @@ def refresh_attempted_time(request: Request): # type: ignore


@api_view(["GET", "PUT", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
@custom_decorator_for_test()
@verify_session()
async def refresh(request: Request): # type: ignore
return Response("refresh success") # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
def set_anti_csrf(request: Request): # type: ignore
global last_set_enable_anti_csrf
data = request.data # type: ignore
Expand Down Expand Up @@ -663,7 +663,7 @@ def set_enable_jwt(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
def feature_flags(request: Request): # type: ignore
global last_set_enable_jwt
return Response(
Expand All @@ -676,7 +676,7 @@ def feature_flags(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def reinitialize(request: Request): # type: ignore
global last_set_enable_jwt
global last_set_enable_anti_csrf
Expand Down Expand Up @@ -711,7 +711,7 @@ async def get_session_called_time(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def ping(request: Request): # type: ignore
if request.method == "GET": # type: ignore
return Response("success") # type: ignore
Expand All @@ -720,7 +720,7 @@ async def ping(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def test_header(request: Request): # type: ignore
if request.method == "GET": # type: ignore
success_info = request.headers.get("st-custom-header") # type: ignore
Expand All @@ -730,7 +730,7 @@ async def test_header(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def check_device_info(request: Request): # type: ignore
if request.method == "GET": # type: ignore
sdk_name = request.headers.get("supertokens-sdk-name") # type: ignore
Expand All @@ -745,14 +745,14 @@ async def check_device_info(request: Request): # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def check_rid(request: Request): # type: ignore
rid = request.headers.get("rid") # type: ignore
return Response("fail" if rid is None else "success") # type: ignore


@api_view(["GET", "POST"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def check_allow_credentials(request: Request): # type: ignore
if request.method == "GET": # type: ignore
return Response("allow-credentials" in request.headers) # type: ignore
Expand All @@ -761,7 +761,7 @@ async def check_allow_credentials(request: Request): # type: ignore


@api_view(["GET", "POST", "OPTIONS"])
@renderer_classes([JSONRenderer]) # type: ignore
@renderer_classes([JsonTextRenderer]) # type: ignore
async def test_error(request: Request): # type: ignore
if request.method == "OPTIONS": # type: ignore
return send_options_api_response() # type: ignore
Expand Down
Loading