Skip to content

feat: implement MCP-Protocol-Version header requirement for HTTP transport #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
GetSessionIdCallback = Callable[[], str | None]

MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

capitalization on these headers is inconsistent across the codebase, should we align on kebab-case-everywhere for headers?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"

LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
Expand Down Expand Up @@ -100,19 +101,20 @@ def __init__(
self.sse_read_timeout = sse_read_timeout
self.auth = auth
self.session_id: str | None = None
self.protocol_version: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}

def _update_headers_with_session(
self, base_headers: dict[str, str]
) -> dict[str, str]:
"""Update headers with session ID if available."""
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and protocol version if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers

def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
Expand All @@ -139,19 +141,36 @@ def _maybe_extract_session_id_from_response(
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")

def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result:
# Check if result has protocolVersion field
result = message.root.result
if "protocolVersion" in result:
self.protocol_version = result["protocolVersion"]
logger.info(f"Negotiated protocol version: {self.protocol_version}")

Comment on lines +144 to +155
Copy link
Author

@felixweinberger felixweinberger Jun 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inspecting the response in this way feels a bit wrong on the transport, but it looks like we're "peeking" into the response in various parts of the transport already

it felt a lot cleaner than doing something like creating a callback, passing that back to the session and have the session jam the version back into the transport, but keen for thoughts here

async def _handle_sse_event(
self,
sse: ServerSentEvent,
read_stream_writer: StreamWriter,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
is_initialization: bool = False,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)

# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(
message.root, JSONRPCResponse | JSONRPCError
Expand Down Expand Up @@ -187,7 +206,7 @@ async def handle_get_stream(
if not self.session_id:
return

headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)

async with aconnect_sse(
client,
Expand All @@ -209,7 +228,7 @@ async def handle_get_stream(

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
Expand Down Expand Up @@ -244,7 +263,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand Down Expand Up @@ -273,9 +292,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
content_type = response.headers.get(CONTENT_TYPE, "").lower()

if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer)
await self._handle_json_response(
response, ctx.read_stream_writer, is_initialization
)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx)
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type(
content_type,
Expand All @@ -286,19 +307,28 @@ async def _handle_json_response(
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
) -> None:
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc)

async def _handle_sse_response(
self, response: httpx.Response, ctx: RequestContext
self,
response: httpx.Response,
ctx: RequestContext,
is_initialization: bool = False,
) -> None:
"""Handle SSE response from the server."""
try:
Expand All @@ -312,6 +342,7 @@ async def _handle_sse_response(
if ctx.metadata
else None
),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
Expand Down Expand Up @@ -408,7 +439,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
return

try:
headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)
response = await client.delete(self.url, headers=headers)

if response.status_code == 405:
Expand Down
42 changes: 38 additions & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from starlette.types import Receive, Scope, Send

from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
Expand All @@ -45,6 +46,7 @@

# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"
LAST_EVENT_ID_HEADER = "last-event-id"

# Content types
Expand Down Expand Up @@ -383,8 +385,7 @@ async def _handle_post_request(
)
await response(scope, receive, send)
return
# For non-initialization requests, validate the session
elif not await self._validate_session(request, send):
elif not await self._validate_request_headers(request, send):
return

# For notifications and responses only, return 202 Accepted
Expand Down Expand Up @@ -559,8 +560,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
await response(request.scope, request.receive, send)
return

if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
Expand Down Expand Up @@ -643,7 +645,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
await response(request.scope, request.receive, send)
return

if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return

await self._terminate_session()
Expand Down Expand Up @@ -703,6 +705,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
)
await response(request.scope, request.receive, send)

async def _validate_request_headers(self, request: Request, send: Send) -> bool:
if not await self._validate_session(request, send):
return False
if not await self._validate_protocol_version(request, send):
return False
return True

async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id:
Expand Down Expand Up @@ -732,6 +741,31 @@ async def _validate_session(self, request: Request, send: Send) -> bool:

return True

async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
"""Validate the protocol version header in the request."""
# Get the protocol version from the request headers
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, return error
if not protocol_version:
response = self._create_error_response(
"Bad Request: Missing MCP-Protocol-Version header",
HTTPStatus.BAD_REQUEST,
)
await response(request.scope, request.receive, send)
return False

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}",
HTTPStatus.BAD_REQUEST,
)
await response(request.scope, request.receive, send)
return False

return True

async def _replay_events(
self, last_event_id: str, request: Request, send: Send
) -> None:
Expand Down
Loading
Loading