-
Notifications
You must be signed in to change notification settings - Fork 1.7k
StreamableHttp client transport #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2b95598
3d790f8
27bc01e
3c4cf10
bce74b3
2011579
2cebf08
6c9c320
ede8cde
2a3bed8
0456b1b
97ca48d
f738cbf
92d4287
aa9f6e5
2fba7f3
45723ea
6b7a616
b1be691
201ec99
46ec72d
1902abb
da1df74
c2be5af
9b096dc
bbe79c2
a0a9c5b
a5ac2e0
2e615f3
110526d
7ffd5ba
029ec56
cae32e2
58745c7
1387929
bccff75
dd007d7
6482120
9a6da2e
b957fad
e087283
08247c4
0484dfb
ff70bd6
179fbc8
a979864
11b7dd9
684af52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
""" | ||
StreamableHTTP Client Transport Module | ||
|
||
This module implements the StreamableHTTP transport for MCP clients, | ||
providing support for HTTP POST requests with optional SSE streaming responses | ||
and session management. | ||
""" | ||
|
||
import logging | ||
from contextlib import asynccontextmanager | ||
from datetime import timedelta | ||
from typing import Any | ||
|
||
import anyio | ||
import httpx | ||
from httpx_sse import EventSource, aconnect_sse | ||
|
||
from mcp.types import ( | ||
ErrorData, | ||
JSONRPCError, | ||
JSONRPCMessage, | ||
JSONRPCNotification, | ||
JSONRPCRequest, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Header names | ||
MCP_SESSION_ID_HEADER = "mcp-session-id" | ||
LAST_EVENT_ID_HEADER = "last-event-id" | ||
|
||
# Content types | ||
CONTENT_TYPE_JSON = "application/json" | ||
CONTENT_TYPE_SSE = "text/event-stream" | ||
|
||
|
||
@asynccontextmanager | ||
async def streamablehttp_client( | ||
url: str, | ||
headers: dict[str, Any] | None = None, | ||
timeout: timedelta = timedelta(seconds=30), | ||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | ||
Comment on lines
+41
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❤️ |
||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have a return value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe i missed this, but all the other transport implementation don't return a terminate_callback. Should we unify this so that the interface is the same? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. streamable http is quite unique here as we have a specific DELETE request to close the session, none of the other transports have it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding return value in #595 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for termination, unlike other transports, streamable http does not close session when client disconnects, it requires explicit termination by DELETE request. Session has no idea what delete request is, hence needed to have a callback to transport. This has it's drawbacks, like users need to know that they need to terminate the session. Alternative can be that we pass a parameter terminate_session_on_exit which defaults to true. In this way if someone wants to have benefits of resuming a long running session later, they can, they just need to set a parameter to False, something like
|
||
""" | ||
Client transport for StreamableHTTP. | ||
|
||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new | ||
event before disconnecting. All other HTTP operations are controlled by `timeout`. | ||
|
||
Yields: | ||
Tuple of (read_stream, write_stream, terminate_callback) | ||
""" | ||
|
||
read_stream_writer, read_stream = anyio.create_memory_object_stream[ | ||
JSONRPCMessage | Exception | ||
](0) | ||
write_stream, write_stream_reader = anyio.create_memory_object_stream[ | ||
JSONRPCMessage | ||
](0) | ||
|
||
async def get_stream(): | ||
""" | ||
Optional GET stream for server-initiated messages | ||
""" | ||
nonlocal session_id | ||
try: | ||
# Only attempt GET if we have a session ID | ||
if not session_id: | ||
return | ||
|
||
get_headers = request_headers.copy() | ||
get_headers[MCP_SESSION_ID_HEADER] = session_id | ||
|
||
async with aconnect_sse( | ||
client, | ||
"GET", | ||
url, | ||
headers=get_headers, | ||
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), | ||
) as event_source: | ||
event_source.response.raise_for_status() | ||
logger.debug("GET SSE connection established") | ||
|
||
async for sse in event_source.aiter_sse(): | ||
if sse.event == "message": | ||
try: | ||
message = JSONRPCMessage.model_validate_json(sse.data) | ||
logger.debug(f"GET message: {message}") | ||
await read_stream_writer.send(message) | ||
except Exception as exc: | ||
logger.error(f"Error parsing GET message: {exc}") | ||
await read_stream_writer.send(exc) | ||
else: | ||
logger.warning(f"Unknown SSE event from GET: {sse.event}") | ||
except Exception as exc: | ||
# GET stream is optional, so don't propagate errors | ||
logger.debug(f"GET stream error (non-fatal): {exc}") | ||
|
||
async def post_writer(client: httpx.AsyncClient): | ||
nonlocal session_id | ||
try: | ||
async with write_stream_reader: | ||
async for message in write_stream_reader: | ||
# Add session ID to headers if we have one | ||
post_headers = request_headers.copy() | ||
if session_id: | ||
post_headers[MCP_SESSION_ID_HEADER] = session_id | ||
|
||
logger.debug(f"Sending client message: {message}") | ||
|
||
# Handle initial initialization request | ||
is_initialization = ( | ||
isinstance(message.root, JSONRPCRequest) | ||
and message.root.method == "initialize" | ||
) | ||
if ( | ||
isinstance(message.root, JSONRPCNotification) | ||
and message.root.method == "notifications/initialized" | ||
): | ||
tg.start_soon(get_stream) | ||
|
||
async with client.stream( | ||
"POST", | ||
url, | ||
json=message.model_dump( | ||
by_alias=True, mode="json", exclude_none=True | ||
), | ||
headers=post_headers, | ||
) as response: | ||
if response.status_code == 202: | ||
logger.debug("Received 202 Accepted") | ||
continue | ||
# Check for 404 (session expired/invalid) | ||
if response.status_code == 404: | ||
if isinstance(message.root, JSONRPCRequest): | ||
jsonrpc_error = JSONRPCError( | ||
jsonrpc="2.0", | ||
id=message.root.id, | ||
error=ErrorData( | ||
code=32600, | ||
message="Session terminated", | ||
), | ||
) | ||
await read_stream_writer.send( | ||
JSONRPCMessage(jsonrpc_error) | ||
) | ||
continue | ||
response.raise_for_status() | ||
|
||
# Extract session ID from response headers | ||
if is_initialization: | ||
new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) | ||
if new_session_id: | ||
session_id = new_session_id | ||
logger.info(f"Received session ID: {session_id}") | ||
|
||
# Handle different response types | ||
content_type = response.headers.get("content-type", "").lower() | ||
|
||
if content_type.startswith(CONTENT_TYPE_JSON): | ||
try: | ||
content = await response.aread() | ||
json_message = JSONRPCMessage.model_validate_json( | ||
content | ||
) | ||
await read_stream_writer.send(json_message) | ||
except Exception as exc: | ||
logger.error(f"Error parsing JSON response: {exc}") | ||
await read_stream_writer.send(exc) | ||
|
||
elif content_type.startswith(CONTENT_TYPE_SSE): | ||
# Parse SSE events from the response | ||
try: | ||
event_source = EventSource(response) | ||
async for sse in event_source.aiter_sse(): | ||
if sse.event == "message": | ||
try: | ||
await read_stream_writer.send( | ||
JSONRPCMessage.model_validate_json( | ||
sse.data | ||
) | ||
) | ||
except Exception as exc: | ||
logger.exception("Error parsing message") | ||
await read_stream_writer.send(exc) | ||
else: | ||
logger.warning(f"Unknown event: {sse.event}") | ||
|
||
except Exception as e: | ||
logger.exception("Error reading SSE stream:") | ||
await read_stream_writer.send(e) | ||
Comment on lines
+160
to
+191
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would feel more intuitive for me if these branches wouldn't be inlined but separate handler functions for non-streamed vs streamed responses. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep, sorry, refactoring on top of the stack |
||
|
||
else: | ||
# For 202 Accepted with no body | ||
if response.status_code == 202: | ||
logger.debug("Received 202 Accepted") | ||
continue | ||
|
||
error_msg = f"Unexpected content type: {content_type}" | ||
logger.error(error_msg) | ||
await read_stream_writer.send(ValueError(error_msg)) | ||
|
||
except Exception as exc: | ||
logger.error(f"Error in post_writer: {exc}") | ||
finally: | ||
await read_stream_writer.aclose() | ||
await write_stream.aclose() | ||
|
||
async def terminate_session(): | ||
""" | ||
Terminate the session by sending a DELETE request. | ||
""" | ||
nonlocal session_id | ||
if not session_id: | ||
return # No session to terminate | ||
|
||
try: | ||
delete_headers = request_headers.copy() | ||
delete_headers[MCP_SESSION_ID_HEADER] = session_id | ||
|
||
response = await client.delete( | ||
url, | ||
headers=delete_headers, | ||
) | ||
|
||
if response.status_code == 405: | ||
# Server doesn't allow client-initiated termination | ||
logger.debug("Server does not allow session termination") | ||
elif response.status_code != 200: | ||
logger.warning(f"Session termination failed: {response.status_code}") | ||
except Exception as exc: | ||
logger.warning(f"Session termination failed: {exc}") | ||
|
||
async with anyio.create_task_group() as tg: | ||
try: | ||
logger.info(f"Connecting to StreamableHTTP endpoint: {url}") | ||
# Set up headers with required Accept header | ||
request_headers = { | ||
"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", | ||
"Content-Type": CONTENT_TYPE_JSON, | ||
**(headers or {}), | ||
} | ||
# Track session ID if provided by server | ||
session_id: str | None = None | ||
|
||
async with httpx.AsyncClient( | ||
headers=request_headers, | ||
timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), | ||
follow_redirects=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are constructing AsyncClient everywhere in the codebase but with various different options. I think it makes sense if we just have factory function that creates an async cleint with correct default values that we want everywhere, like follow_redirects. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added to follow ups |
||
) as client: | ||
tg.start_soon(post_writer, client) | ||
try: | ||
yield read_stream, write_stream, terminate_session | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure i understand why we want terminate session as a callback. We are using a context manager, shouldn't we always be able to just terminate_session if it still exists after we yielded? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need this so that we can have a method on a client to delete/terminate the mcp-session |
||
finally: | ||
tg.cancel_scope.cancel() | ||
finally: | ||
await read_stream_writer.aclose() | ||
await write_stream.aclose() |
Uh oh!
There was an error while loading. Please reload this page.