Skip to content

Fix streamable http sampling #693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mcp/cli/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None:
return path
return None


def get_uv_path() -> str:
"""Get the full path to the uv executable."""
uv_path = shutil.which("uv")
Expand All @@ -42,6 +43,7 @@ def get_uv_path() -> str:
return "uv" # Fall back to just "uv" if not found
return uv_path


def update_claude_config(
file_spec: str,
server_name: str,
Expand Down
24 changes: 19 additions & 5 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse

Expand Down Expand Up @@ -239,7 +240,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
break

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
Expand Down Expand Up @@ -300,7 +301,7 @@ async def _handle_sse_response(
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
await self._handle_sse_event(
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(
Expand All @@ -309,6 +310,10 @@ async def _handle_sse_response(
else None
),
)
# If the SSE event indicates completion (e.g. a response or error was
# returned), break out of the loop
if is_complete:
break
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why break here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`_handle_sse_event` returns True when the event carries a response or an error. That means the request is complete, so we need to stop reading the stream.

except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
Expand Down Expand Up @@ -344,6 +349,7 @@ async def post_writer(
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
Expand Down Expand Up @@ -375,10 +381,17 @@ async def post_writer(
sse_read_timeout=self.sse_read_timeout,
)

if is_resumption:
await self._handle_resumption_request(ctx)
async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)

# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await self._handle_post_request(ctx)
await handle_request_async()

except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
Expand Down Expand Up @@ -466,6 +479,7 @@ def start_get_stream() -> None:
read_stream_writer,
write_stream,
start_get_stream,
tg,
)

try:
Expand Down
10 changes: 7 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:

import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
RequestResponder,
Expand Down Expand Up @@ -230,10 +230,11 @@ async def create_message(
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
return await self.send_request(
types.ServerRequest(
request=types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=types.CreateMessageRequestParams(
Expand All @@ -248,7 +249,10 @@ async def create_message(
),
)
),
types.CreateMessageResult,
result_type=types.CreateMessageResult,
metadata=ServerMessageMetadata(
related_request_id=related_request_id,
),
)

async def list_roots(self) -> types.ListRootsResult:
Expand Down
21 changes: 14 additions & 7 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
Expand Down Expand Up @@ -849,9 +848,15 @@ async def message_router():
# Determine which request stream(s) should receive this message
message = session_message.message
target_request_id = None
if isinstance(
message.root, JSONRPCNotification | JSONRPCRequest
):
# Check if this is a response
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
response_id = str(message.root.id)
# If this response is for an existing request stream,
# send it there
if response_id in self._request_streams:
target_request_id = response_id

else:
# Extract related_request_id from meta if it exists
if (
session_message.metadata is not None
Expand All @@ -865,10 +870,12 @@ async def message_router():
target_request_id = str(
session_message.metadata.related_request_id
)
else:
target_request_id = str(message.root.id)

request_stream_id = target_request_id or GET_STREAM_KEY
request_stream_id = (
target_request_id
if target_request_id is not None
else GET_STREAM_KEY
)

# Store the event if we have an event store,
# regardless of whether a client is connected
Expand Down
1 change: 0 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ async def send_request(
Do not use this method to emit notifications! Use send_notification()
instead.
"""

request_id = self._request_id
self._request_id = request_id + 1

Expand Down
6 changes: 3 additions & 3 deletions tests/client/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_absolute_uv_path(mock_config_path: Path):
"""Test that the absolute path to uv is used when available."""
# Mock the shutil.which function to return a fake path
mock_uv_path = "/usr/local/bin/uv"

with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path):
# Setup
server_name = "test_server"
Expand All @@ -71,5 +71,5 @@ def test_absolute_uv_path(mock_config_path: Path):
# Verify the command is the absolute path
server_config = config["mcpServers"][server_name]
command = server_config["command"]
assert command == mock_uv_path

assert command == mock_uv_path
111 changes: 107 additions & 4 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import socket
import time
from collections.abc import Generator
from typing import Any

import anyio
import httpx
Expand All @@ -33,6 +34,7 @@
StreamId,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import (
ClientMessageMetadata,
Expand Down Expand Up @@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]:
description="A long-running tool that sends periodic notifications",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="test_sampling_tool",
description="A tool that triggers server-side sampling",
inputSchema={"type": "object", "properties": {}},
),
]

@self.call_tool()
Expand Down Expand Up @@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:

return [TextContent(type="text", text="Completed!")]

elif name == "test_sampling_tool":
# Test sampling by requesting the client to sample a message
sampling_result = await ctx.session.create_message(
messages=[
types.SamplingMessage(
role="user",
content=types.TextContent(
type="text", text="Server needs client sampling"
),
)
],
max_tokens=100,
related_request_id=ctx.request_id,
)

# Return the sampling result in the tool response
response = (
sampling_result.content.text
if sampling_result.content.type == "text"
else None
)
return [
TextContent(
type="text",
text=f"Response from sampling: {response}",
)
]

return [TextContent(type="text", text=f"Called {name}")]


Expand Down Expand Up @@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
"""Test client tool invocation."""
# First list tools
tools = await initialized_client_session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4
assert tools.tools[0].name == "test_tool"

# Call the tool
Expand Down Expand Up @@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(

# Make multiple requests to verify session persistence
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4

# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
Expand Down Expand Up @@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(

# Check tool listing
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4

# Call a tool and verify JSON response handling
result = await session.call_tool("test_tool", {})
Expand Down Expand Up @@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(

# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 3
assert len(tools.tools) == 4

headers = {}
if captured_session_id:
Expand Down Expand Up @@ -1054,3 +1089,71 @@ async def run_tool():
assert not any(
n in captured_notifications_pre for n in captured_notifications
)


@pytest.mark.anyio
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
    """Check that a server-initiated sampling request round-trips over HTTP.

    The client registers a sampling callback, calls ``test_sampling_tool``,
    and then verifies both the tool result and what the callback received.
    """
    print("Testing server sampling...")
    # Shared record of what the sampling callback observed.
    seen: dict[str, Any] = {"invoked": False, "params": None}

    async def sampling_callback(
        context: RequestContext[ClientSession, Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        """Record the incoming params and answer with a canned assistant reply."""
        seen["invoked"] = True
        seen["params"] = params

        # Only text content carries a ``text`` attribute we can echo back.
        if params.messages[0].content.type == "text":
            message_received = params.messages[0].content.text
        else:
            message_received = None

        reply = types.TextContent(
            type="text",
            text=f"Received message from server: {message_received}",
        )
        return types.CreateMessageResult(
            role="assistant",
            content=reply,
            model="test-model",
            stopReason="endTurn",
        )

    server_url = f"{basic_server_url}/mcp"

    # Connect with the sampling callback wired into the client session.
    async with streamablehttp_client(server_url) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(
            read_stream,
            write_stream,
            sampling_callback=sampling_callback,
        ) as session:
            # The session must be initialized before any tool call.
            init_result = await session.initialize()
            assert isinstance(init_result, InitializeResult)

            # This tool makes the server send a sampling request to the client.
            tool_result = await session.call_tool("test_sampling_tool", {})

            # The tool result should wrap the text produced by the callback.
            assert len(tool_result.content) == 1
            only_item = tool_result.content[0]
            assert only_item.type == "text"
            expected_fragment = (
                "Response from sampling: Received message from server"
            )
            assert expected_fragment in only_item.text

            # The callback must have run and seen the server's single message.
            assert seen["invoked"]
            received_params = seen["params"]
            assert received_params is not None
            assert len(received_params.messages) == 1
            assert (
                received_params.messages[0].content.text
                == "Server needs client sampling"
            )
Loading