|
11 | 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
12 | 12 | # License for the specific language governing permissions and limitations
|
13 | 13 | # under the License.
|
14 |
| -from __future__ import annotations |
| 14 | +from typing import Union |
15 | 15 |
|
16 |
| -from typing import TYPE_CHECKING, Union |
17 | 16 |
|
18 |
| -from supertokens_python.framework import BaseResponse |
| 17 | +def get_middleware(): |
| 18 | + from supertokens_python import Supertokens |
| 19 | + from supertokens_python.utils import default_user_context |
| 20 | + from supertokens_python.exceptions import SuperTokensError |
| 21 | + from supertokens_python.framework import BaseResponse |
| 22 | + from supertokens_python.recipe.session import SessionContainer |
| 23 | + from supertokens_python.supertokens import manage_session_post_response |
19 | 24 |
|
20 |
| -if TYPE_CHECKING: |
21 |
| - from fastapi import Request |
| 25 | + from starlette.requests import Request |
| 26 | + from starlette.responses import Response |
| 27 | + from starlette.types import ASGIApp, Message, Receive, Scope, Send |
22 | 28 |
|
| 29 | + from supertokens_python.framework.fastapi.fastapi_request import ( |
| 30 | + FastApiRequest, |
| 31 | + ) |
| 32 | + from supertokens_python.framework.fastapi.fastapi_response import ( |
| 33 | + FastApiResponse, |
| 34 | + ) |
23 | 35 |
|
24 |
| -def get_middleware(): |
25 |
| - from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint |
26 |
| - from supertokens_python.utils import default_user_context |
| 36 | + class ASGIMiddleware: |
| 37 | + def __init__(self, app: ASGIApp) -> None: |
| 38 | + self.app = app |
27 | 39 |
|
28 |
| - class Middleware(BaseHTTPMiddleware): |
29 |
| - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): |
30 |
| - from supertokens_python import Supertokens |
31 |
| - from supertokens_python.exceptions import SuperTokensError |
32 |
| - from supertokens_python.framework.fastapi.fastapi_request import ( |
33 |
| - FastApiRequest, |
34 |
| - ) |
35 |
| - from supertokens_python.framework.fastapi.fastapi_response import ( |
36 |
| - FastApiResponse, |
37 |
| - ) |
38 |
| - from supertokens_python.recipe.session import SessionContainer |
39 |
| - from supertokens_python.supertokens import manage_session_post_response |
| 40 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 41 | + if scope["type"] != "http": # we pass through the non-http requests, if any |
| 42 | + await self.app(scope, receive, send) |
| 43 | + return |
40 | 44 |
|
41 | 45 | st = Supertokens.get_instance()
|
42 |
| - from fastapi.responses import Response |
43 | 46 |
|
| 47 | + request = Request(scope, receive=receive) |
44 | 48 | custom_request = FastApiRequest(request)
|
45 |
| - response = FastApiResponse(Response()) |
46 | 49 | user_context = default_user_context(custom_request)
|
47 | 50 |
|
48 | 51 | try:
|
| 52 | + response = FastApiResponse(Response()) |
49 | 53 | result: Union[BaseResponse, None] = await st.middleware(
|
50 | 54 | custom_request, response, user_context
|
51 | 55 | )
|
52 | 56 | if result is None:
|
53 |
| - response = await call_next(request) |
54 |
| - result = FastApiResponse(response) |
| 57 | + # This means that the supertokens middleware did not handle the request, |
| 58 | + # however, we may need to handle the header changes in the response, |
| 59 | + # based on response mutators used by the session. |
| 60 | + async def send_wrapper(message: Message): |
| 61 | + if message["type"] == "http.response.start": |
| 62 | + # Start message has the headers, so we update the headers here |
| 63 | + # by using `manage_session_post_response` function, which will |
| 64 | + # apply all the Response Mutators. In the end, we just replace |
| 65 | + # the updated headers in the message. |
| 66 | + if hasattr(request.state, "supertokens") and isinstance( |
| 67 | + request.state.supertokens, SessionContainer |
| 68 | + ): |
| 69 | + fapi_response = Response() |
| 70 | + fapi_response.raw_headers = message["headers"] |
| 71 | + response = FastApiResponse(fapi_response) |
| 72 | + manage_session_post_response( |
| 73 | + request.state.supertokens, response, user_context |
| 74 | + ) |
| 75 | + message["headers"] = fapi_response.raw_headers |
55 | 76 |
|
| 77 | + # For `http.response.start` message, we might have the headers updated, |
| 78 | + # otherwise, we just send all the messages as is |
| 79 | + await send(message) |
| 80 | + |
| 81 | + await self.app(scope, receive, send_wrapper) |
| 82 | + return |
| 83 | + |
| 84 | + # This means that the request was handled by the supertokens middleware |
| 85 | + # and hence we respond using the response object returned by the middleware. |
56 | 86 | if hasattr(request.state, "supertokens") and isinstance(
|
57 | 87 | request.state.supertokens, SessionContainer
|
58 | 88 | ):
|
59 | 89 | manage_session_post_response(
|
60 | 90 | request.state.supertokens, result, user_context
|
61 | 91 | )
|
| 92 | + |
62 | 93 | if isinstance(result, FastApiResponse):
|
63 |
| - return result.response |
| 94 | + await result.response(scope, receive, send) |
| 95 | + return |
| 96 | + |
| 97 | + return |
| 98 | + |
64 | 99 | except SuperTokensError as e:
|
65 | 100 | response = FastApiResponse(Response())
|
66 | 101 | result: Union[BaseResponse, None] = await st.handle_supertokens_error(
|
67 | 102 | FastApiRequest(request), e, response, user_context
|
68 | 103 | )
|
69 | 104 | if isinstance(result, FastApiResponse):
|
70 |
| - return result.response |
| 105 | + await result.response(scope, receive, send) |
| 106 | + return |
71 | 107 |
|
72 | 108 | raise Exception("Should never come here")
|
73 | 109 |
|
74 |
| - return Middleware |
| 110 | + return ASGIMiddleware |
0 commit comments