Description
I'm debugging the issue described in jlowin/fastmcp#393.
Essentially, attempting to use a callback for sampling in conjunction with a Streamable HTTP connection results in an indefinite hang.
I've traced the issue to a specific moment in the low-level interaction. The server correctly sends the sampling request to the client, and the client successfully processes it via its sampling callback and sends a response to the server. However, the server never receives that response, so it blocks forever waiting for it.
(Note that this also happens if the client doesn't support sampling at all and the default sampling callback responds with an appropriate error. The net effect is that any time the server requests a sample, it immediately locks up.)
Unfortunately, I'm not knowledgeable enough about how the server streams work to suggest a solution, but I have created a complete unit test that I think demonstrates the issue. It is based on similar low-level code in https://github.com/modelcontextprotocol/python-sdk/blob/main/tests/shared/test_streamable_http.py and contains a single unit test that can be run with pytest.
If you add debug prints to the session classes, you can see the server and client writing messages, with the last one being the client's attempt to respond to the server.
Apologies for the length of the test, but I wanted to be sure everything was covered and self-contained.
"""
Tests for the StreamableHTTP server and client transport.
Contains tests for both server and client sides of the StreamableHTTP transport.
"""
import multiprocessing
import socket
import time
from collections.abc import Generator
import httpx
import pytest
import uvicorn
from starlette.applications import Starlette
from starlette.routing import Mount
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server.streamable_http import (
EventStore,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import (
TextContent,
Tool,
)
# Test constants
# Name handed to the low-level Server constructor
SERVER_NAME = "test_streamable_http_server"

# Fixed session identifier available to tests
TEST_SESSION_ID = "test-session-id-12345"

# JSON-RPC "initialize" request body available to raw HTTP tests
INIT_REQUEST = {
    "jsonrpc": "2.0",
    "method": "initialize",
    "params": {
        "clientInfo": {"name": "test-client", "version": "1.0"},
        "protocolVersion": "2025-03-26",
        "capabilities": {},
    },
    "id": "init-1",
}
class ServerTest(Server):
    """Test server exposing one tool whose handler issues a sampling request."""

    def __init__(self):
        super().__init__(SERVER_NAME)

        @self.list_tools()
        async def handle_list_tools() -> list[Tool]:
            # The only advertised tool takes no arguments
            sampling_tool = Tool(
                name="sample_tool",
                description="A tool that uses sampling",
                inputSchema={"type": "object", "properties": {}},
            )
            return [sampling_tool]

        @self.call_tool()
        async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
            # Neither `name` nor `args` is inspected: every call triggers the
            # same sampling request through the current request's session.
            session = self.request_context.session
            prompt = types.SamplingMessage(
                role="user",
                content=types.TextContent(type="text", text="Hello, world!"),
            )
            sampled = await session.create_message(
                messages=[prompt],
                max_tokens=100,
            )
            # Return the stringified sampling result as the tool output
            return [TextContent(type="text", text=f"{sampled}")]
def create_app(
    is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
    """Build the Starlette application under test.

    A fresh ServerTest is wrapped in a StreamableHTTPSessionManager, which is
    mounted at ``/mcp`` and also drives the application's lifespan.

    Args:
        is_json_response_enabled: If True, use JSON responses instead of SSE streams.
        event_store: Optional event store for testing resumability.

    Returns:
        The configured Starlette application.
    """
    manager = StreamableHTTPSessionManager(
        app=ServerTest(),
        event_store=event_store,
        json_response=is_json_response_enabled,
    )

    # Everything under /mcp is handled by the session manager
    mcp_routes = [Mount("/mcp", app=manager.handle_request)]

    return Starlette(
        debug=True,
        routes=mcp_routes,
        # The lifespan callable's app argument is unused; the manager's run()
        # context is what gets entered for the app's lifetime.
        lifespan=lambda app: manager.run(),
    )
def run_server(
    port: int, is_json_response_enabled=False, event_store: EventStore | None = None
) -> None:
    """Run the test server on 127.0.0.1 until uvicorn exits.

    Args:
        port: Port to listen on.
        is_json_response_enabled: If True, use JSON responses instead of SSE streams.
        event_store: Optional event store for testing resumability.
    """
    # Forward the transport options so the app honors the caller's settings
    # rather than always being built with create_app's defaults.
    app = create_app(
        is_json_response_enabled=is_json_response_enabled,
        event_store=event_store,
    )

    # Configure server
    config = uvicorn.Config(
        app=app,
        host="127.0.0.1",
        port=port,
        log_level="info",
        limit_concurrency=10,
        timeout_keep_alive=5,
        access_log=False,
    )

    # Start the server
    server = uvicorn.Server(config=config)

    # This is important to catch exceptions and prevent test hangs: report the
    # failure and let the function return instead of propagating.
    try:
        server.run()
    except Exception:
        import traceback

        traceback.print_exc()
# Test fixtures - using same approach as SSE tests
@pytest.fixture
def basic_server_port() -> int:
    """Return a port number the OS just reported as free on 127.0.0.1."""
    probe = socket.socket()
    with probe:
        # Binding to port 0 lets the OS choose the port
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
        return port
@pytest.fixture
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
    """Start the test server in a child process and stop it on teardown.

    Args:
        basic_server_port: Port passed to ``run_server`` in the child process.

    Yields:
        None, once a TCP connection to the server port has succeeded.

    Raises:
        RuntimeError: If the port refuses connections on every attempt.
    """
    proc = multiprocessing.Process(
        target=run_server, kwargs={"port": basic_server_port}, daemon=True
    )
    proc.start()

    # try/finally ensures the child is killed and joined even when the
    # readiness check below fails, not only after a completed test.
    try:
        # Wait for server to be running
        max_attempts = 20
        for _ in range(max_attempts):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.connect(("127.0.0.1", basic_server_port))
            except ConnectionRefusedError:
                time.sleep(0.1)
            else:
                break
        else:
            raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

        yield
    finally:
        # Clean up
        proc.kill()
        proc.join(timeout=2)
@pytest.fixture
def basic_server_url(basic_server_port: int) -> str:
    """Return the base HTTP URL of the basic test server on 127.0.0.1."""
    base_url = f"http://127.0.0.1:{basic_server_port}"
    return base_url
# Client-specific fixtures
@pytest.fixture
async def http_client(basic_server, basic_server_url):
    """Yield an httpx.AsyncClient rooted at the test server's base URL.

    ``basic_server`` is requested only so the server is running first; its
    value is not used.
    """
    async with httpx.AsyncClient(base_url=basic_server_url) as async_client:
        yield async_client
async def sample_callback(*args, **kwargs):
    """Sampling callback for the client session; always returns a fixed result.

    All positional and keyword arguments are accepted and ignored.
    """
    print("IN SAMPLE CALLBACK")
    reply = types.TextContent(type="text", text="Hello, world!")
    return types.CreateMessageResult(role="user", content=reply, model="hi")
@pytest.fixture
async def initialized_client_session(basic_server, basic_server_url):
    """Yield an initialized ClientSession connected over StreamableHTTP.

    The session is created with ``sample_callback`` as its sampling callback.
    ``basic_server`` is requested only so the server is running first.
    """
    endpoint = f"{basic_server_url}/mcp"
    async with (
        # The transport's third item is not needed here
        streamablehttp_client(endpoint) as (reader, writer, _),
        ClientSession(
            reader,
            writer,
            sampling_callback=sample_callback,
        ) as session,
    ):
        await session.initialize()
        yield session
@pytest.mark.anyio
async def test_callback(initialized_client_session):
    """Call the sampling tool end to end and check the sampled text comes back."""
    session = initialized_client_session

    # The server should advertise exactly the one sampling tool
    listing = await session.list_tools()
    advertised = listing.tools
    assert len(advertised) == 1
    assert advertised[0].name == "sample_tool"

    # Calling the tool makes the server send a sampling request to this client
    outcome = await session.call_tool("sample_tool", {})
    assert len(outcome.content) == 1
    first_item = outcome.content[0]
    assert first_item.type == "text"
    assert "Hello, world!" in first_item.text