Skip to content

Fix: Prevent session manager shutdown on individual session crash #841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ async def sse_writer():
):
break
except Exception as e:
logger.exception(f"Error in SSE writer: {e}")
logger.warning(f"Error in SSE writer: {e}", exc_info=True)
finally:
logger.debug("Closing SSE writer")
await self._clean_up_memory_streams(request_id)
Expand Down Expand Up @@ -517,13 +517,13 @@ async def sse_writer():
session_message = SessionMessage(message, metadata=metadata)
await writer.send(session_message)
except Exception:
logger.exception("SSE response error")
logger.warning("SSE response error", exc_info=True)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)

except Exception as err:
logger.exception("Error handling POST request")
logger.warning("Error handling POST request", exc_info=True)
response = self._create_error_response(
f"Error handling POST request: {err}",
HTTPStatus.INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -610,7 +610,7 @@ async def standalone_sse_writer():
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception as e:
logger.exception(f"Error in standalone SSE writer: {e}")
logger.warning(f"Error in standalone SSE writer: {e}", exc_info=True)
finally:
logger.debug("Closing standalone SSE writer")
await self._clean_up_memory_streams(GET_STREAM_KEY)
Expand All @@ -626,7 +626,7 @@ async def standalone_sse_writer():
# This will send headers immediately and establish the SSE connection
await response(request.scope, request.receive, send)
except Exception as e:
logger.exception(f"Error in standalone SSE response: {e}")
logger.warning(f"Error in standalone SSE response: {e}", exc_info=True)
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(GET_STREAM_KEY)
Expand Down
53 changes: 40 additions & 13 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class StreamableHTTPSessionManager:
json_response: Whether to use JSON responses instead of SSE streams
stateless: If True, creates a completely fresh transport for each request
with no session tracking or state persistence between requests.
"""

def __init__(
Expand Down Expand Up @@ -171,12 +170,15 @@ async def run_stateless_server(
async with http_transport.connect() as streams:
read_stream, write_stream = streams
task_status.started()
await self.app.run(
read_stream,
write_stream,
self.app.create_initialization_options(),
stateless=True,
)
try:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we look into .run function to see how to handle the error there?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NAVNAV221 That should also happen, but there should be no scenarios in which a per-session error is allowed to destabilize the entire server until reboot. This is a catch-all to make sure that the server as a whole survives any errors where proper error handling was missed.

await self.app.run(
read_stream,
write_stream,
self.app.create_initialization_options(),
stateless=True,
)
except Exception as e:
logger.warning(f"Stateless session crashed: {e}", exc_info=True)

# Assert task group is not None for type checking
assert self._task_group is not None
Expand Down Expand Up @@ -235,12 +237,37 @@ async def run_server(
async with http_transport.connect() as streams:
read_stream, write_stream = streams
task_status.started()
await self.app.run(
read_stream,
write_stream,
self.app.create_initialization_options(),
stateless=False, # Stateful mode
)
try:
await self.app.run(
read_stream,
write_stream,
self.app.create_initialization_options(),
stateless=False, # Stateful mode
)
except Exception as e:
logger.warning(
f"Session {http_transport.mcp_session_id} crashed: {e}",
exc_info=True,
)
finally:
# Only remove from instances if not terminated
if (
http_transport.mcp_session_id
and http_transport.mcp_session_id
in self._server_instances
and not (
hasattr(http_transport, "_terminated")
and http_transport._terminated # pyright: ignore
)
):
logger.info(
"Cleaning up crashed session "
f"{http_transport.mcp_session_id} from "
"active instances."
)
del self._server_instances[
http_transport.mcp_session_id
]

# Assert task group is not None for type checking
assert self._task_group is not None
Expand Down
128 changes: 128 additions & 0 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Tests for StreamableHTTPSessionManager."""

from unittest.mock import AsyncMock

import anyio
import pytest

from mcp.server.lowlevel import Server
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager


Expand Down Expand Up @@ -79,3 +82,128 @@ async def send(message):
assert "Task group is not initialized. Make sure to use run()." in str(
excinfo.value
)


class TestException(Exception):
__test__ = False # Prevent pytest from collecting this as a test class
pass


@pytest.fixture
async def running_manager():
app = Server("test-cleanup-server")
# It's important that the app instance used by the manager is the one we can patch
manager = StreamableHTTPSessionManager(app=app)
async with manager.run():
# Patch app.run here if it's simpler, or patch it within the test
yield manager, app


@pytest.mark.anyio
async def test_stateful_session_cleanup_on_graceful_exit(running_manager):
manager, app = running_manager

mock_mcp_run = AsyncMock(return_value=None)
# This will be called by StreamableHTTPSessionManager's run_server -> self.app.run
app.run = mock_mcp_run

sent_messages = []

async def mock_send(message):
sent_messages.append(message)

scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []}

async def mock_receive():
return {"type": "http.request", "body": b"", "more_body": False}

# Trigger session creation
await manager.handle_request(scope, mock_receive, mock_send)

# Extract session ID from response headers
session_id = None
for msg in sent_messages:
if msg["type"] == "http.response.start":
for header_name, header_value in msg.get("headers", []):
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
session_id = header_value.decode()
break
if session_id: # Break outer loop if session_id is found
break

assert session_id is not None, "Session ID not found in response headers"

# Ensure MCPServer.run was called
mock_mcp_run.assert_called_once()

# At this point, mock_mcp_run has completed, and the finally block in
# StreamableHTTPSessionManager's run_server should have executed.

# To ensure the task spawned by handle_request finishes and cleanup occurs:
# Give other tasks a chance to run. This is important for the finally block.
await anyio.sleep(0.01)

assert (
session_id not in manager._server_instances
), "Session ID should be removed from _server_instances after graceful exit"
assert (
not manager._server_instances
), "No sessions should be tracked after the only session exits gracefully"


@pytest.mark.anyio
async def test_stateful_session_cleanup_on_exception(running_manager):
manager, app = running_manager

mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash"))
app.run = mock_mcp_run

sent_messages = []

async def mock_send(message):
sent_messages.append(message)
# If an exception occurs, the transport might try to send an error response
# For this test, we mostly care that the session is established enough
# to get an ID
if message["type"] == "http.response.start" and message["status"] >= 500:
pass # Expected if TestException propagates that far up the transport

scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []}

async def mock_receive():
return {"type": "http.request", "body": b"", "more_body": False}

# It's possible handle_request itself might raise an error if the TestException
# isn't caught by the transport layer before propagating.
# The key is that the session manager's internal task for MCPServer.run
# encounters the exception.
try:
await manager.handle_request(scope, mock_receive, mock_send)
except TestException:
# This might be caught here if not handled by StreamableHTTPServerTransport's
# error handling
pass

session_id = None
for msg in sent_messages:
if msg["type"] == "http.response.start":
for header_name, header_value in msg.get("headers", []):
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
session_id = header_value.decode()
break
if session_id: # Break outer loop if session_id is found
break

assert session_id is not None, "Session ID not found in response headers"

mock_mcp_run.assert_called_once()

# Give other tasks a chance to run to ensure the finally block executes
await anyio.sleep(0.01)

assert (
session_id not in manager._server_instances
), "Session ID should be removed from _server_instances after an exception"
assert (
not manager._server_instances
), "No sessions should be tracked after the only session crashes"
Loading