@@ -39,67 +39,103 @@ <h1 class="title">Module <code>supertokens_python.framework.fastapi.fastapi_midd
39
39
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
40
40
# License for the specific language governing permissions and limitations
41
41
# under the License.
42
- from __future__ import annotations
42
+ from typing import Union
43
43
44
- from typing import TYPE_CHECKING, Union
45
44
46
- from supertokens_python.framework import BaseResponse
45
+ def get_middleware():
46
+ from supertokens_python import Supertokens
47
+ from supertokens_python.utils import default_user_context
48
+ from supertokens_python.exceptions import SuperTokensError
49
+ from supertokens_python.framework import BaseResponse
50
+ from supertokens_python.recipe.session import SessionContainer
51
+ from supertokens_python.supertokens import manage_session_post_response
47
52
48
- if TYPE_CHECKING:
49
- from fastapi import Request
53
+ from starlette.requests import Request
54
+ from starlette.responses import Response
55
+ from starlette.types import ASGIApp, Message, Receive, Scope, Send
50
56
57
+ from supertokens_python.framework.fastapi.fastapi_request import (
58
+ FastApiRequest,
59
+ )
60
+ from supertokens_python.framework.fastapi.fastapi_response import (
61
+ FastApiResponse,
62
+ )
51
63
52
- def get_middleware() :
53
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
54
- from supertokens_python.utils import default_user_context
64
+ class ASGIMiddleware :
65
+ def __init__(self, app: ASGIApp) -> None:
66
+ self.app = app
55
67
56
- class Middleware(BaseHTTPMiddleware):
57
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
58
- from supertokens_python import Supertokens
59
- from supertokens_python.exceptions import SuperTokensError
60
- from supertokens_python.framework.fastapi.fastapi_request import (
61
- FastApiRequest,
62
- )
63
- from supertokens_python.framework.fastapi.fastapi_response import (
64
- FastApiResponse,
65
- )
66
- from supertokens_python.recipe.session import SessionContainer
67
- from supertokens_python.supertokens import manage_session_post_response
68
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
69
+ if scope["type"] != "http": # we pass through the non-http requests, if any
70
+ await self.app(scope, receive, send)
71
+ return
68
72
69
73
st = Supertokens.get_instance()
70
- from fastapi.responses import Response
71
74
75
+ request = Request(scope, receive=receive)
72
76
custom_request = FastApiRequest(request)
73
- response = FastApiResponse(Response())
74
77
user_context = default_user_context(custom_request)
75
78
76
79
try:
80
+ response = FastApiResponse(Response())
77
81
result: Union[BaseResponse, None] = await st.middleware(
78
82
custom_request, response, user_context
79
83
)
80
84
if result is None:
81
- response = await call_next(request)
82
- result = FastApiResponse(response)
85
+ # This means that the supertokens middleware did not handle the request,
86
+ # however, we may need to handle the header changes in the response,
87
+ # based on response mutators used by the session.
88
+ async def send_wrapper(message: Message):
89
+ if message["type"] == "http.response.start":
90
+ # Start message has the headers, so we update the headers here
91
+ # by using `manage_session_post_response` function, which will
92
+ # apply all the Response Mutators. In the end, we just replace
93
+ # the updated headers in the message.
94
+ if hasattr(request.state, "supertokens") and isinstance(
95
+ request.state.supertokens, SessionContainer
96
+ ):
97
+ fapi_response = Response()
98
+ fapi_response.raw_headers = message["headers"]
99
+ response = FastApiResponse(fapi_response)
100
+ manage_session_post_response(
101
+ request.state.supertokens, response, user_context
102
+ )
103
+ message["headers"] = fapi_response.raw_headers
83
104
105
+ # For `http.response.start` message, we might have the headers updated,
106
+ # otherwise, we just send all the messages as is
107
+ await send(message)
108
+
109
+ await self.app(scope, receive, send_wrapper)
110
+ return
111
+
112
+ # This means that the request was handled by the supertokens middleware
113
+ # and hence we respond using the response object returned by the middleware.
84
114
if hasattr(request.state, "supertokens") and isinstance(
85
115
request.state.supertokens, SessionContainer
86
116
):
87
117
manage_session_post_response(
88
118
request.state.supertokens, result, user_context
89
119
)
120
+
90
121
if isinstance(result, FastApiResponse):
91
- return result.response
122
+ await result.response(scope, receive, send)
123
+ return
124
+
125
+ return
126
+
92
127
except SuperTokensError as e:
93
128
response = FastApiResponse(Response())
94
129
result: Union[BaseResponse, None] = await st.handle_supertokens_error(
95
130
FastApiRequest(request), e, response, user_context
96
131
)
97
132
if isinstance(result, FastApiResponse):
98
- return result.response
133
+ await result.response(scope, receive, send)
134
+ return
99
135
100
136
raise Exception("Should never come here")
101
137
102
- return Middleware </ code > </ pre >
138
+ return ASGIMiddleware </ code > </ pre >
103
139
</ details >
104
140
</ section >
105
141
< section >
@@ -119,56 +155,99 @@ <h2 class="section-title" id="header-functions">Functions</h2>
119
155
< span > Expand source code</ span >
120
156
</ summary >
121
157
< pre > < code class ="python "> def get_middleware():
122
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
158
+ from supertokens_python import Supertokens
123
159
from supertokens_python.utils import default_user_context
160
+ from supertokens_python.exceptions import SuperTokensError
161
+ from supertokens_python.framework import BaseResponse
162
+ from supertokens_python.recipe.session import SessionContainer
163
+ from supertokens_python.supertokens import manage_session_post_response
124
164
125
- class Middleware(BaseHTTPMiddleware):
126
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
127
- from supertokens_python import Supertokens
128
- from supertokens_python.exceptions import SuperTokensError
129
- from supertokens_python.framework.fastapi.fastapi_request import (
130
- FastApiRequest,
131
- )
132
- from supertokens_python.framework.fastapi.fastapi_response import (
133
- FastApiResponse,
134
- )
135
- from supertokens_python.recipe.session import SessionContainer
136
- from supertokens_python.supertokens import manage_session_post_response
165
+ from starlette.requests import Request
166
+ from starlette.responses import Response
167
+ from starlette.types import ASGIApp, Message, Receive, Scope, Send
168
+
169
+ from supertokens_python.framework.fastapi.fastapi_request import (
170
+ FastApiRequest,
171
+ )
172
+ from supertokens_python.framework.fastapi.fastapi_response import (
173
+ FastApiResponse,
174
+ )
175
+
176
+ class ASGIMiddleware:
177
+ def __init__(self, app: ASGIApp) -> None:
178
+ self.app = app
179
+
180
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
181
+ if scope["type"] != "http": # we pass through the non-http requests, if any
182
+ await self.app(scope, receive, send)
183
+ return
137
184
138
185
st = Supertokens.get_instance()
139
- from fastapi.responses import Response
140
186
187
+ request = Request(scope, receive=receive)
141
188
custom_request = FastApiRequest(request)
142
- response = FastApiResponse(Response())
143
189
user_context = default_user_context(custom_request)
144
190
145
191
try:
192
+ response = FastApiResponse(Response())
146
193
result: Union[BaseResponse, None] = await st.middleware(
147
194
custom_request, response, user_context
148
195
)
149
196
if result is None:
150
- response = await call_next(request)
151
- result = FastApiResponse(response)
197
+ # This means that the supertokens middleware did not handle the request,
198
+ # however, we may need to handle the header changes in the response,
199
+ # based on response mutators used by the session.
200
+ async def send_wrapper(message: Message):
201
+ if message["type"] == "http.response.start":
202
+ # Start message has the headers, so we update the headers here
203
+ # by using `manage_session_post_response` function, which will
204
+ # apply all the Response Mutators. In the end, we just replace
205
+ # the updated headers in the message.
206
+ if hasattr(request.state, "supertokens") and isinstance(
207
+ request.state.supertokens, SessionContainer
208
+ ):
209
+ fapi_response = Response()
210
+ fapi_response.raw_headers = message["headers"]
211
+ response = FastApiResponse(fapi_response)
212
+ manage_session_post_response(
213
+ request.state.supertokens, response, user_context
214
+ )
215
+ message["headers"] = fapi_response.raw_headers
216
+
217
+ # For `http.response.start` message, we might have the headers updated,
218
+ # otherwise, we just send all the messages as is
219
+ await send(message)
152
220
221
+ await self.app(scope, receive, send_wrapper)
222
+ return
223
+
224
+ # This means that the request was handled by the supertokens middleware
225
+ # and hence we respond using the response object returned by the middleware.
153
226
if hasattr(request.state, "supertokens") and isinstance(
154
227
request.state.supertokens, SessionContainer
155
228
):
156
229
manage_session_post_response(
157
230
request.state.supertokens, result, user_context
158
231
)
232
+
159
233
if isinstance(result, FastApiResponse):
160
- return result.response
234
+ await result.response(scope, receive, send)
235
+ return
236
+
237
+ return
238
+
161
239
except SuperTokensError as e:
162
240
response = FastApiResponse(Response())
163
241
result: Union[BaseResponse, None] = await st.handle_supertokens_error(
164
242
FastApiRequest(request), e, response, user_context
165
243
)
166
244
if isinstance(result, FastApiResponse):
167
- return result.response
245
+ await result.response(scope, receive, send)
246
+ return
168
247
169
248
raise Exception("Should never come here")
170
249
171
- return Middleware </ code > </ pre >
250
+ return ASGIMiddleware </ code > </ pre >
172
251
</ details >
173
252
</ dd >
174
253
</ dl >
0 commit comments