-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat: implement MCP-Protocol-Version header requirement for HTTP transport #898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2363398
a4793f1
3e27010
0552898
970a034
16c94ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ | |
GetSessionIdCallback = Callable[[], str | None] | ||
|
||
MCP_SESSION_ID = "mcp-session-id" | ||
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version" | ||
LAST_EVENT_ID = "last-event-id" | ||
CONTENT_TYPE = "content-type" | ||
ACCEPT = "Accept" | ||
|
@@ -100,19 +101,20 @@ def __init__( | |
self.sse_read_timeout = sse_read_timeout | ||
self.auth = auth | ||
self.session_id: str | None = None | ||
self.protocol_version: str | None = None | ||
self.request_headers = { | ||
ACCEPT: f"{JSON}, {SSE}", | ||
CONTENT_TYPE: JSON, | ||
**self.headers, | ||
} | ||
|
||
def _update_headers_with_session( | ||
self, base_headers: dict[str, str] | ||
) -> dict[str, str]: | ||
"""Update headers with session ID if available.""" | ||
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: | ||
"""Update headers with session ID and protocol version if available.""" | ||
headers = base_headers.copy() | ||
if self.session_id: | ||
headers[MCP_SESSION_ID] = self.session_id | ||
if self.protocol_version: | ||
headers[MCP_PROTOCOL_VERSION] = self.protocol_version | ||
return headers | ||
|
||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: | ||
|
@@ -139,19 +141,36 @@ def _maybe_extract_session_id_from_response( | |
self.session_id = new_session_id | ||
logger.info(f"Received session ID: {self.session_id}") | ||
|
||
def _maybe_extract_protocol_version_from_message( | ||
self, | ||
message: JSONRPCMessage, | ||
) -> None: | ||
"""Extract protocol version from initialization response message.""" | ||
if isinstance(message.root, JSONRPCResponse) and message.root.result: | ||
# Check if result has protocolVersion field | ||
result = message.root.result | ||
if "protocolVersion" in result: | ||
self.protocol_version = result["protocolVersion"] | ||
logger.info(f"Negotiated protocol version: {self.protocol_version}") | ||
|
||
Comment on lines
+144
to
+155
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inspecting the response in this way feels a bit wrong on the transport, but it looks like we're "peeking" into the response in various parts of the transport already it felt a lot cleaner than doing something like creating a callback, passing that back to the session and have the session jam the version back into the transport, but keen for thoughts here |
||
async def _handle_sse_event( | ||
self, | ||
sse: ServerSentEvent, | ||
read_stream_writer: StreamWriter, | ||
original_request_id: RequestId | None = None, | ||
resumption_callback: Callable[[str], Awaitable[None]] | None = None, | ||
is_initialization: bool = False, | ||
) -> bool: | ||
"""Handle an SSE event, returning True if the response is complete.""" | ||
if sse.event == "message": | ||
try: | ||
message = JSONRPCMessage.model_validate_json(sse.data) | ||
logger.debug(f"SSE message: {message}") | ||
|
||
# Extract protocol version from initialization response | ||
if is_initialization: | ||
self._maybe_extract_protocol_version_from_message(message) | ||
|
||
# If this is a response and we have original_request_id, replace it | ||
if original_request_id is not None and isinstance( | ||
message.root, JSONRPCResponse | JSONRPCError | ||
|
@@ -187,7 +206,7 @@ async def handle_get_stream( | |
if not self.session_id: | ||
return | ||
|
||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = self._prepare_request_headers(self.request_headers) | ||
|
||
async with aconnect_sse( | ||
client, | ||
|
@@ -209,7 +228,7 @@ async def handle_get_stream( | |
|
||
async def _handle_resumption_request(self, ctx: RequestContext) -> None: | ||
"""Handle a resumption request using GET with SSE.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = self._prepare_request_headers(ctx.headers) | ||
if ctx.metadata and ctx.metadata.resumption_token: | ||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token | ||
else: | ||
|
@@ -244,7 +263,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
|
||
async def _handle_post_request(self, ctx: RequestContext) -> None: | ||
"""Handle a POST request with response processing.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = self._prepare_request_headers(ctx.headers) | ||
message = ctx.session_message.message | ||
is_initialization = self._is_initialization_request(message) | ||
|
||
|
@@ -273,9 +292,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: | |
content_type = response.headers.get(CONTENT_TYPE, "").lower() | ||
|
||
if content_type.startswith(JSON): | ||
await self._handle_json_response(response, ctx.read_stream_writer) | ||
await self._handle_json_response( | ||
response, ctx.read_stream_writer, is_initialization | ||
) | ||
elif content_type.startswith(SSE): | ||
await self._handle_sse_response(response, ctx) | ||
await self._handle_sse_response(response, ctx, is_initialization) | ||
else: | ||
await self._handle_unexpected_content_type( | ||
content_type, | ||
|
@@ -286,19 +307,28 @@ async def _handle_json_response( | |
self, | ||
response: httpx.Response, | ||
read_stream_writer: StreamWriter, | ||
is_initialization: bool = False, | ||
) -> None: | ||
"""Handle JSON response from the server.""" | ||
try: | ||
content = await response.aread() | ||
message = JSONRPCMessage.model_validate_json(content) | ||
|
||
# Extract protocol version from initialization response | ||
if is_initialization: | ||
self._maybe_extract_protocol_version_from_message(message) | ||
|
||
session_message = SessionMessage(message) | ||
await read_stream_writer.send(session_message) | ||
except Exception as exc: | ||
logger.error(f"Error parsing JSON response: {exc}") | ||
await read_stream_writer.send(exc) | ||
|
||
async def _handle_sse_response( | ||
self, response: httpx.Response, ctx: RequestContext | ||
self, | ||
response: httpx.Response, | ||
ctx: RequestContext, | ||
is_initialization: bool = False, | ||
) -> None: | ||
"""Handle SSE response from the server.""" | ||
try: | ||
|
@@ -312,6 +342,7 @@ async def _handle_sse_response( | |
if ctx.metadata | ||
else None | ||
), | ||
is_initialization=is_initialization, | ||
) | ||
# If the SSE event indicates completion, like returning respose/error | ||
# break the loop | ||
|
@@ -408,7 +439,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: | |
return | ||
|
||
try: | ||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = self._prepare_request_headers(self.request_headers) | ||
response = await client.delete(self.url, headers=headers) | ||
|
||
if response.status_code == 405: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
capitalization on these headers is inconsistent across the codebase, should we align on
kebab-case-everywhere
for headers?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we should.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.