Skip to content

Commit 1a8c2b2

Browse files
authored
fix: better fastapi middleware (#505)
* fix: better fastapi middleware * fix: changelog and version * fix: comments
1 parent a66f6cc commit 1a8c2b2

File tree

4 files changed

+69
-29
lines changed

4 files changed

+69
-29
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.20.2] - 2024-05-17
12+
13+
- Improves FastAPI middleware performance using recommended ASGI middleware implementation.
14+
1115
## [0.20.1] - 2024-05-10
1216

1317
- Fixes parameter mismatch in generating fake email

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484
setup(
8585
name="supertokens_python",
86-
version="0.20.1",
86+
version="0.20.2",
8787
author="SuperTokens",
8888
license="Apache 2.0",
8989
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["3.0"]
17-
VERSION = "0.20.1"
17+
VERSION = "0.20.2"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"

supertokens_python/framework/fastapi/fastapi_middleware.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,64 +11,100 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from __future__ import annotations
14+
from typing import Union
1515

16-
from typing import TYPE_CHECKING, Union
1716

18-
from supertokens_python.framework import BaseResponse
17+
def get_middleware():
18+
from supertokens_python import Supertokens
19+
from supertokens_python.utils import default_user_context
20+
from supertokens_python.exceptions import SuperTokensError
21+
from supertokens_python.framework import BaseResponse
22+
from supertokens_python.recipe.session import SessionContainer
23+
from supertokens_python.supertokens import manage_session_post_response
1924

20-
if TYPE_CHECKING:
21-
from fastapi import Request
25+
from starlette.requests import Request
26+
from starlette.responses import Response
27+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
2228

29+
from supertokens_python.framework.fastapi.fastapi_request import (
30+
FastApiRequest,
31+
)
32+
from supertokens_python.framework.fastapi.fastapi_response import (
33+
FastApiResponse,
34+
)
2335

24-
def get_middleware():
25-
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
26-
from supertokens_python.utils import default_user_context
36+
class ASGIMiddleware:
37+
def __init__(self, app: ASGIApp) -> None:
38+
self.app = app
2739

28-
class Middleware(BaseHTTPMiddleware):
29-
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
30-
from supertokens_python import Supertokens
31-
from supertokens_python.exceptions import SuperTokensError
32-
from supertokens_python.framework.fastapi.fastapi_request import (
33-
FastApiRequest,
34-
)
35-
from supertokens_python.framework.fastapi.fastapi_response import (
36-
FastApiResponse,
37-
)
38-
from supertokens_python.recipe.session import SessionContainer
39-
from supertokens_python.supertokens import manage_session_post_response
40+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
41+
if scope["type"] != "http": # we pass through the non-http requests, if any
42+
await self.app(scope, receive, send)
43+
return
4044

4145
st = Supertokens.get_instance()
42-
from fastapi.responses import Response
4346

47+
request = Request(scope, receive=receive)
4448
custom_request = FastApiRequest(request)
45-
response = FastApiResponse(Response())
4649
user_context = default_user_context(custom_request)
4750

4851
try:
52+
response = FastApiResponse(Response())
4953
result: Union[BaseResponse, None] = await st.middleware(
5054
custom_request, response, user_context
5155
)
5256
if result is None:
53-
response = await call_next(request)
54-
result = FastApiResponse(response)
57+
# This means that the supertokens middleware did not handle the request,
58+
# however, we may need to handle the header changes in the response,
59+
# based on response mutators used by the session.
60+
async def send_wrapper(message: Message):
61+
if message["type"] == "http.response.start":
62+
# Start message has the headers, so we update the headers here
63+
# by using `manage_session_post_response` function, which will
64+
# apply all the Response Mutators. In the end, we just replace
65+
# the updated headers in the message.
66+
if hasattr(request.state, "supertokens") and isinstance(
67+
request.state.supertokens, SessionContainer
68+
):
69+
fapi_response = Response()
70+
fapi_response.raw_headers = message["headers"]
71+
response = FastApiResponse(fapi_response)
72+
manage_session_post_response(
73+
request.state.supertokens, response, user_context
74+
)
75+
message["headers"] = fapi_response.raw_headers
5576

77+
# For `http.response.start` message, we might have the headers updated,
78+
# otherwise, we just send all the messages as is
79+
await send(message)
80+
81+
await self.app(scope, receive, send_wrapper)
82+
return
83+
84+
# This means that the request was handled by the supertokens middleware
85+
# and hence we respond using the response object returned by the middleware.
5686
if hasattr(request.state, "supertokens") and isinstance(
5787
request.state.supertokens, SessionContainer
5888
):
5989
manage_session_post_response(
6090
request.state.supertokens, result, user_context
6191
)
92+
6293
if isinstance(result, FastApiResponse):
63-
return result.response
94+
await result.response(scope, receive, send)
95+
return
96+
97+
return
98+
6499
except SuperTokensError as e:
65100
response = FastApiResponse(Response())
66101
result: Union[BaseResponse, None] = await st.handle_supertokens_error(
67102
FastApiRequest(request), e, response, user_context
68103
)
69104
if isinstance(result, FastApiResponse):
70-
return result.response
105+
await result.response(scope, receive, send)
106+
return
71107

72108
raise Exception("Should never come here")
73109

74-
return Middleware
110+
return ASGIMiddleware

0 commit comments

Comments
 (0)