|
13 | 13 | # under the License.
|
14 | 14 |
|
15 | 15 | import json
|
| 16 | +from unittest.mock import Mock |
16 | 17 | from urllib.parse import urlencode
|
17 | 18 | from datetime import datetime
|
18 | 19 | from inspect import isawaitable
|
19 | 20 | from base64 import b64encode
|
20 | 21 | from typing import Any, Dict, Union
|
21 | 22 |
|
22 | 23 | from django.http import HttpRequest, HttpResponse, JsonResponse
|
23 |
| -from django.test import RequestFactory, TestCase |
| 24 | +from django.test import RequestFactory, TestCase, override_settings |
24 | 25 |
|
25 | 26 | from supertokens_python import InputAppInfo, SupertokensConfig, init
|
26 | 27 | from supertokens_python.framework.django import middleware
|
| 28 | +from supertokens_python.framework.django.django_request import DjangoRequest |
27 | 29 | from supertokens_python.framework.django.django_response import (
|
28 | 30 | DjangoResponse as SuperTokensDjangoWrapper,
|
29 | 31 | )
|
@@ -1000,3 +1002,164 @@ def test_remove_header_works():
|
1000 | 1002 | assert st_response.get_header("foo") == "bar"
|
1001 | 1003 | st_response.remove_header("foo")
|
1002 | 1004 | assert st_response.get_header("foo") is None
|
| 1005 | + |
| 1006 | + |
| 1007 | +class DjangoRequestTest(TestCase): |
| 1008 | + def setUp(self): |
| 1009 | + self.factory = RequestFactory() |
| 1010 | + |
| 1011 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1012 | + def test_get_original_url(self): |
| 1013 | + request = self.factory.get("/some/url/path") |
| 1014 | + custom_request = DjangoRequest(request) |
| 1015 | + assert custom_request.get_original_url() == "http://testserver/some/url/path" |
| 1016 | + |
| 1017 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1018 | + def test_get_query_param(self): |
| 1019 | + request = self.factory.get( |
| 1020 | + "/some/url/path", data={"key1": "value1", "key2": "value2"} |
| 1021 | + ) |
| 1022 | + custom_request = DjangoRequest(request) |
| 1023 | + assert custom_request.get_query_param("key1") == "value1" |
| 1024 | + assert ( |
| 1025 | + custom_request.get_query_param("key3", default="default_value") |
| 1026 | + == "default_value" |
| 1027 | + ) |
| 1028 | + |
| 1029 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1030 | + def test_get_query_params(self): |
| 1031 | + request = self.factory.get( |
| 1032 | + "/some/url/path", data={"key1": "value1", "key2": "value2"} |
| 1033 | + ) |
| 1034 | + custom_request = DjangoRequest(request) |
| 1035 | + assert custom_request.get_query_params() == {"key1": "value1", "key2": "value2"} |
| 1036 | + |
| 1037 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1038 | + async def test_json(self): |
| 1039 | + request = self.factory.post( |
| 1040 | + "/some/url/path", |
| 1041 | + data=json.dumps({"key": "value"}), |
| 1042 | + content_type="application/json", |
| 1043 | + ) |
| 1044 | + custom_request = DjangoRequest(request) |
| 1045 | + assert await custom_request.json() == {"key": "value"} |
| 1046 | + |
| 1047 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1048 | + def test_method(self): |
| 1049 | + request = self.factory.get("/some/url/path") |
| 1050 | + custom_request = DjangoRequest(request) |
| 1051 | + assert custom_request.method() == "GET" |
| 1052 | + |
| 1053 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1054 | + def test_get_cookie(self): |
| 1055 | + request = self.factory.get("/some/url/path") |
| 1056 | + request.COOKIES["cookie_key"] = "cookie_value" |
| 1057 | + custom_request = DjangoRequest(request) |
| 1058 | + assert custom_request.get_cookie("cookie_key") == "cookie_value" |
| 1059 | + |
| 1060 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1061 | + def test_get_header(self): |
| 1062 | + request = self.factory.get("/some/url/path", HTTP_CUSTOM_HEADER="header_value") |
| 1063 | + custom_request = DjangoRequest(request) |
| 1064 | + assert custom_request.get_header("custom-header") == "header_value" |
| 1065 | + |
| 1066 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1067 | + def test_get_session(self): |
| 1068 | + request = self.factory.get("/some/url/path") |
| 1069 | + session = Mock() |
| 1070 | + request.supertokens = session # type: ignore |
| 1071 | + custom_request = DjangoRequest(request) |
| 1072 | + assert custom_request.get_session() == session |
| 1073 | + |
| 1074 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1075 | + def test_set_session(self): |
| 1076 | + request = self.factory.get("/some/url/path") |
| 1077 | + custom_request = DjangoRequest(request) |
| 1078 | + session = Mock() |
| 1079 | + custom_request.set_session(session) |
| 1080 | + assert custom_request.get_session() == session |
| 1081 | + |
| 1082 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1083 | + def test_set_session_as_none(self): |
| 1084 | + request = self.factory.get("/some/url/path") |
| 1085 | + custom_request = DjangoRequest(request) |
| 1086 | + session = Mock() |
| 1087 | + custom_request.set_session(session) |
| 1088 | + custom_request.set_session_as_none() |
| 1089 | + assert custom_request.get_session() is None |
| 1090 | + |
| 1091 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1092 | + def test_get_path(self): |
| 1093 | + request = self.factory.get("/some/url/path") |
| 1094 | + custom_request = DjangoRequest(request) |
| 1095 | + assert custom_request.get_path() == "/some/url/path" |
| 1096 | + |
| 1097 | + @override_settings(ALLOWED_HOSTS=["testserver"]) |
| 1098 | + async def test_form_data(self): |
| 1099 | + data = {"key": "value"} |
| 1100 | + request = self.factory.post( |
| 1101 | + "/some/url/path", |
| 1102 | + data=urlencode(data), |
| 1103 | + content_type="application/x-www-form-urlencoded", |
| 1104 | + ) |
| 1105 | + custom_request = DjangoRequest(request) |
| 1106 | + assert await custom_request.form_data() == data |
| 1107 | + |
| 1108 | + |
| 1109 | +class DjangoResponseTest(TestCase): |
| 1110 | + def setUp(self): |
| 1111 | + self.factory = RequestFactory() |
| 1112 | + |
| 1113 | + def test_set_html_content(self): |
| 1114 | + response = HttpResponse() |
| 1115 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1116 | + custom_response.set_html_content("<html><body>Hello, World!</body></html>") |
| 1117 | + |
| 1118 | + self.assertEqual(response["Content-Type"], "text/html") |
| 1119 | + self.assertEqual(response.content, b"<html><body>Hello, World!</body></html>") |
| 1120 | + |
| 1121 | + def test_set_cookie(self): |
| 1122 | + response = HttpResponse() |
| 1123 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1124 | + custom_response.set_cookie("cookie_key", "cookie_value", expires=1000) |
| 1125 | + |
| 1126 | + self.assertIn("cookie_key", response.cookies) |
| 1127 | + self.assertEqual(response.cookies["cookie_key"].value, "cookie_value") |
| 1128 | + self.assertIsNotNone(response.cookies["cookie_key"]["expires"]) |
| 1129 | + |
| 1130 | + def test_set_status_code(self): |
| 1131 | + response = HttpResponse() |
| 1132 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1133 | + custom_response.set_status_code(404) |
| 1134 | + |
| 1135 | + self.assertEqual(response.status_code, 404) |
| 1136 | + |
| 1137 | + def test_set_header(self): |
| 1138 | + response = HttpResponse() |
| 1139 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1140 | + custom_response.set_header("Custom-Header", "Custom-Value") |
| 1141 | + |
| 1142 | + self.assertEqual(response["Custom-Header"], "Custom-Value") |
| 1143 | + |
| 1144 | + def test_get_header(self): |
| 1145 | + response = HttpResponse() |
| 1146 | + response["Custom-Header"] = "Custom-Value" |
| 1147 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1148 | + |
| 1149 | + self.assertEqual(custom_response.get_header("Custom-Header"), "Custom-Value") |
| 1150 | + |
| 1151 | + def test_remove_header(self): |
| 1152 | + response = HttpResponse() |
| 1153 | + response["Custom-Header"] = "Custom-Value" |
| 1154 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1155 | + |
| 1156 | + custom_response.remove_header("Custom-Header") |
| 1157 | + self.assertNotIn("Custom-Header", response) |
| 1158 | + |
| 1159 | + def test_set_json_content(self): |
| 1160 | + response = HttpResponse() |
| 1161 | + custom_response = SuperTokensDjangoWrapper(response) |
| 1162 | + custom_response.set_json_content({"key": "value"}) |
| 1163 | + |
| 1164 | + self.assertEqual(response["Content-Type"], "application/json; charset=utf-8") |
| 1165 | + self.assertEqual(response.content, b'{"key":"value"}') |
0 commit comments