Skip to content

Commit 397dcdf

Browse files
feat: implement MCP-Protocol-Version header validation in server
- Add MCP_PROTOCOL_VERSION_HEADER constant - Add _validate_protocol_version method to check header presence and validity - Validate protocol version for all non-initialization requests (POST, GET, DELETE) - Return 400 Bad Request for missing or invalid protocol versions - Update tests to include MCP-Protocol-Version header in requests - Fix test_streamablehttp_client_resumption to pass protocol version when resuming This implements the server-side validation required by the spec change that mandates clients include the negotiated protocol version in all subsequent HTTP requests after initialization. Github-Issue: #548
1 parent e000c7b commit 397dcdf

File tree

2 files changed

+98
-11
lines changed

2 files changed

+98
-11
lines changed

src/mcp/server/streamable_http.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from starlette.types import Receive, Scope, Send
2626

2727
from mcp.shared.message import ServerMessageMetadata, SessionMessage
28+
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2829
from mcp.types import (
2930
INTERNAL_ERROR,
3031
INVALID_PARAMS,
@@ -45,6 +46,7 @@
4546

4647
# Header names
4748
MCP_SESSION_ID_HEADER = "mcp-session-id"
49+
MCP_PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"
4850
LAST_EVENT_ID_HEADER = "last-event-id"
4951

5052
# Content types
@@ -383,9 +385,10 @@ async def _handle_post_request(
383385
)
384386
await response(scope, receive, send)
385387
return
386-
# For non-initialization requests, validate the session
387388
elif not await self._validate_session(request, send):
388389
return
390+
elif not await self._validate_protocol_version(request, send):
391+
return
389392

390393
# For notifications and responses only, return 202 Accepted
391394
if not isinstance(message.root, JSONRPCRequest):
@@ -561,6 +564,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
561564

562565
if not await self._validate_session(request, send):
563566
return
567+
if not await self._validate_protocol_version(request, send):
568+
return
569+
564570
# Handle resumability: check for Last-Event-ID header
565571
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
566572
await self._replay_events(last_event_id, request, send)
@@ -645,6 +651,8 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
645651

646652
if not await self._validate_session(request, send):
647653
return
654+
if not await self._validate_protocol_version(request, send):
655+
return
648656

649657
await self._terminate_session()
650658

@@ -732,6 +740,31 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
732740

733741
return True
734742

743+
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
744+
"""Validate the protocol version header in the request."""
745+
# Get the protocol version from the request headers
746+
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
747+
748+
# If no protocol version provided, return error
749+
if not protocol_version:
750+
response = self._create_error_response(
751+
"Bad Request: Missing MCP-Protocol-Version header",
752+
HTTPStatus.BAD_REQUEST,
753+
)
754+
await response(request.scope, request.receive, send)
755+
return False
756+
757+
# Check if the protocol version is supported
758+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
759+
response = self._create_error_response(
760+
f"Bad Request: Unsupported protocol version: {protocol_version}",
761+
HTTPStatus.BAD_REQUEST,
762+
)
763+
await response(request.scope, request.receive, send)
764+
return False
765+
766+
return True
767+
735768
async def _replay_events(
736769
self, last_event_id: str, request: Request, send: Send
737770
) -> None:

tests/shared/test_streamable_http.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from mcp.client.streamable_http import streamablehttp_client
2727
from mcp.server import Server
2828
from mcp.server.streamable_http import (
29+
MCP_PROTOCOL_VERSION_HEADER,
2930
MCP_SESSION_ID_HEADER,
3031
SESSION_ID_PATTERN,
3132
EventCallback,
@@ -576,11 +577,24 @@ def test_session_termination(basic_server, basic_server_url):
576577
)
577578
assert response.status_code == 200
578579

580+
# Extract negotiated protocol version from SSE response
581+
init_data = None
582+
assert response.headers.get("Content-Type") == "text/event-stream"
583+
for line in response.text.splitlines():
584+
if line.startswith("data: "):
585+
init_data = json.loads(line[6:])
586+
break
587+
assert init_data is not None
588+
negotiated_version = init_data["result"]["protocolVersion"]
589+
579590
# Now terminate the session
580591
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
581592
response = requests.delete(
582593
f"{basic_server_url}/mcp",
583-
headers={MCP_SESSION_ID_HEADER: session_id},
594+
headers={
595+
MCP_SESSION_ID_HEADER: session_id,
596+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
597+
},
584598
)
585599
assert response.status_code == 200
586600

@@ -611,16 +625,27 @@ def test_response(basic_server, basic_server_url):
611625
)
612626
assert response.status_code == 200
613627

614-
# Now terminate the session
628+
# Extract negotiated protocol version from SSE response
629+
init_data = None
630+
assert response.headers.get("Content-Type") == "text/event-stream"
631+
for line in response.text.splitlines():
632+
if line.startswith("data: "):
633+
init_data = json.loads(line[6:])
634+
break
635+
assert init_data is not None
636+
negotiated_version = init_data["result"]["protocolVersion"]
637+
638+
# Now get the session ID
615639
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
616640

617-
# Try to use the terminated session
641+
# Try to use the session with proper headers
618642
tools_response = requests.post(
619643
mcp_url,
620644
headers={
621645
"Accept": "application/json, text/event-stream",
622646
"Content-Type": "application/json",
623647
MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier
648+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
624649
},
625650
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"},
626651
stream=True,
@@ -662,12 +687,23 @@ def test_get_sse_stream(basic_server, basic_server_url):
662687
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
663688
assert session_id is not None
664689

690+
# Extract negotiated protocol version from SSE response
691+
init_data = None
692+
assert init_response.headers.get("Content-Type") == "text/event-stream"
693+
for line in init_response.text.splitlines():
694+
if line.startswith("data: "):
695+
init_data = json.loads(line[6:])
696+
break
697+
assert init_data is not None
698+
negotiated_version = init_data["result"]["protocolVersion"]
699+
665700
# Now attempt to establish an SSE stream via GET
666701
get_response = requests.get(
667702
mcp_url,
668703
headers={
669704
"Accept": "text/event-stream",
670705
MCP_SESSION_ID_HEADER: session_id,
706+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
671707
},
672708
stream=True,
673709
)
@@ -682,6 +718,7 @@ def test_get_sse_stream(basic_server, basic_server_url):
682718
headers={
683719
"Accept": "text/event-stream",
684720
MCP_SESSION_ID_HEADER: session_id,
721+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
685722
},
686723
stream=True,
687724
)
@@ -710,11 +747,22 @@ def test_get_validation(basic_server, basic_server_url):
710747
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
711748
assert session_id is not None
712749

750+
# Extract negotiated protocol version from SSE response
751+
init_data = None
752+
assert init_response.headers.get("Content-Type") == "text/event-stream"
753+
for line in init_response.text.splitlines():
754+
if line.startswith("data: "):
755+
init_data = json.loads(line[6:])
756+
break
757+
assert init_data is not None
758+
negotiated_version = init_data["result"]["protocolVersion"]
759+
713760
# Test without Accept header
714761
response = requests.get(
715762
mcp_url,
716763
headers={
717764
MCP_SESSION_ID_HEADER: session_id,
765+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
718766
},
719767
stream=True,
720768
)
@@ -727,6 +775,7 @@ def test_get_validation(basic_server, basic_server_url):
727775
headers={
728776
"Accept": "application/json",
729777
MCP_SESSION_ID_HEADER: session_id,
778+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
730779
},
731780
)
732781
assert response.status_code == 406
@@ -1038,6 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
10381087
captured_resumption_token = None
10391088
captured_notifications = []
10401089
tool_started = False
1090+
captured_protocol_version = None
10411091

10421092
async def message_handler(
10431093
message: RequestResponder[types.ServerRequest, types.ClientResult]
@@ -1070,6 +1120,8 @@ async def on_resumption_token_update(token: str) -> None:
10701120
assert isinstance(result, InitializeResult)
10711121
captured_session_id = get_session_id()
10721122
assert captured_session_id is not None
1123+
# Capture the negotiated protocol version
1124+
captured_protocol_version = result.protocolVersion
10731125

10741126
# Start a long-running tool in a task
10751127
async with anyio.create_task_group() as tg:
@@ -1104,10 +1156,12 @@ async def run_tool():
11041156
captured_notifications_pre = captured_notifications.copy()
11051157
captured_notifications = []
11061158

1107-
# Now resume the session with the same mcp-session-id
1159+
# Now resume the session with the same mcp-session-id and protocol version
11081160
headers = {}
11091161
if captured_session_id:
11101162
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1163+
if captured_protocol_version:
1164+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
11111165

11121166
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
11131167
read_stream,
@@ -1481,7 +1535,7 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
14811535
)
14821536
assert response.status_code == 400
14831537
assert (
1484-
"MCP-Protocol-Version" in response.text
1538+
MCP_PROTOCOL_VERSION_HEADER in response.text
14851539
or "protocol version" in response.text.lower()
14861540
)
14871541

@@ -1492,13 +1546,13 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
14921546
"Accept": "application/json, text/event-stream",
14931547
"Content-Type": "application/json",
14941548
MCP_SESSION_ID_HEADER: session_id,
1495-
"MCP-Protocol-Version": "invalid-version",
1549+
MCP_PROTOCOL_VERSION_HEADER: "invalid-version",
14961550
},
14971551
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
14981552
)
14991553
assert response.status_code == 400
15001554
assert (
1501-
"MCP-Protocol-Version" in response.text
1555+
MCP_PROTOCOL_VERSION_HEADER in response.text
15021556
or "protocol version" in response.text.lower()
15031557
)
15041558

@@ -1509,13 +1563,13 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
15091563
"Accept": "application/json, text/event-stream",
15101564
"Content-Type": "application/json",
15111565
MCP_SESSION_ID_HEADER: session_id,
1512-
"MCP-Protocol-Version": "1999-01-01", # Very old unsupported version
1566+
MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version
15131567
},
15141568
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
15151569
)
15161570
assert response.status_code == 400
15171571
assert (
1518-
"MCP-Protocol-Version" in response.text
1572+
MCP_PROTOCOL_VERSION_HEADER in response.text
15191573
or "protocol version" in response.text.lower()
15201574
)
15211575

@@ -1536,7 +1590,7 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
15361590
"Accept": "application/json, text/event-stream",
15371591
"Content-Type": "application/json",
15381592
MCP_SESSION_ID_HEADER: session_id,
1539-
"MCP-Protocol-Version": negotiated_version,
1593+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
15401594
},
15411595
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"},
15421596
)

0 commit comments

Comments
 (0)