Skip to content

Handle SSE Disconnects Properly #612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/servers/simple-prompt/mcp_simple_prompt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def get_prompt(
if transport == "sse":
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route

sse = SseServerTransport("/messages/")
Expand All @@ -101,6 +102,7 @@ async def handle_sse(request):
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
return Response()

starlette_app = Starlette(
debug=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes:
if transport == "sse":
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route

sse = SseServerTransport("/messages/")
Expand All @@ -57,11 +58,12 @@ async def handle_sse(request):
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
return Response()

starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
],
)
Expand Down
4 changes: 3 additions & 1 deletion examples/servers/simple-tool/mcp_simple_tool/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]:
if transport == "sse":
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route

sse = SseServerTransport("/messages/")
Expand All @@ -71,11 +72,12 @@ async def handle_sse(request):
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
return Response()

starlette_app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
],
)
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
streams[1],
self._mcp_server.create_initialization_options(),
)
return Response()

# Create routes
routes: list[Route | Mount] = []
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ def __init__(
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand Down Expand Up @@ -137,6 +134,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:

return True

async def _receive_loop(self) -> None:
async with self._incoming_message_stream_writer:
await super()._receive_loop()

async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
Expand Down
27 changes: 22 additions & 5 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# Create Starlette routes for SSE and message handling
routes = [
Route("/sse", endpoint=handle_sse),
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
]

Expand All @@ -22,12 +22,18 @@ async def handle_sse(request):
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
# Return empty response to avoid NoneType error
return Response()

# Create and run Starlette app
starlette_app = Starlette(routes=routes)
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
```

Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
object is not callable" error when client disconnects. The example above returns
an empty Response() after the SSE connection ends to fix this.

See SseServerTransport class documentation for more details.
"""

Expand Down Expand Up @@ -121,11 +127,22 @@ async def sse_writer():
)

async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""
The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)(scope, receive, send)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")

logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send)
tg.start_soon(response_wrapper, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
Expand Down
4 changes: 3 additions & 1 deletion tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route

from mcp.client.session import ClientSession
Expand Down Expand Up @@ -83,13 +84,14 @@ def make_server_app() -> Starlette:
sse = SseServerTransport("/messages/")
server = ServerTest()

async def handle_sse(request: Request) -> None:
async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
return Response()

app = Starlette(
routes=[
Expand Down
Loading