|
65 | 65 | }
|
66 | 66 |
|
67 | 67 |
|
| 68 | +# Helper functions |
| 69 | +def extract_protocol_version_from_sse(response: requests.Response) -> str: |
| 70 | + """Extract the negotiated protocol version from an SSE initialization response.""" |
| 71 | + assert response.headers.get("Content-Type") == "text/event-stream" |
| 72 | + for line in response.text.splitlines(): |
| 73 | + if line.startswith("data: "): |
| 74 | + init_data = json.loads(line[6:]) |
| 75 | + return init_data["result"]["protocolVersion"] |
| 76 | + raise ValueError("Could not extract protocol version from SSE response") |
| 77 | + |
| 78 | + |
68 | 79 | # Simple in-memory event store for testing
|
69 | 80 | class SimpleEventStore(EventStore):
|
70 | 81 | """Simple in-memory event store for testing."""
|
@@ -578,14 +589,7 @@ def test_session_termination(basic_server, basic_server_url):
|
578 | 589 | assert response.status_code == 200
|
579 | 590 |
|
580 | 591 | # Extract negotiated protocol version from SSE response
|
581 |
| - init_data = None |
582 |
| - assert response.headers.get("Content-Type") == "text/event-stream" |
583 |
| - for line in response.text.splitlines(): |
584 |
| - if line.startswith("data: "): |
585 |
| - init_data = json.loads(line[6:]) |
586 |
| - break |
587 |
| - assert init_data is not None |
588 |
| - negotiated_version = init_data["result"]["protocolVersion"] |
| 592 | + negotiated_version = extract_protocol_version_from_sse(response) |
589 | 593 |
|
590 | 594 | # Now terminate the session
|
591 | 595 | session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
@@ -626,14 +630,7 @@ def test_response(basic_server, basic_server_url):
|
626 | 630 | assert response.status_code == 200
|
627 | 631 |
|
628 | 632 | # Extract negotiated protocol version from SSE response
|
629 |
| - init_data = None |
630 |
| - assert response.headers.get("Content-Type") == "text/event-stream" |
631 |
| - for line in response.text.splitlines(): |
632 |
| - if line.startswith("data: "): |
633 |
| - init_data = json.loads(line[6:]) |
634 |
| - break |
635 |
| - assert init_data is not None |
636 |
| - negotiated_version = init_data["result"]["protocolVersion"] |
| 633 | + negotiated_version = extract_protocol_version_from_sse(response) |
637 | 634 |
|
638 | 635 | # Now get the session ID
|
639 | 636 | session_id = response.headers.get(MCP_SESSION_ID_HEADER)
|
@@ -1574,15 +1571,7 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
|
1574 | 1571 | )
|
1575 | 1572 |
|
1576 | 1573 | # Test request with valid protocol version (should succeed)
|
1577 |
| - init_data = None |
1578 |
| - assert init_response.headers.get("Content-Type") == "text/event-stream" |
1579 |
| - for line in init_response.text.splitlines(): |
1580 |
| - if line.startswith("data: "): |
1581 |
| - init_data = json.loads(line[6:]) |
1582 |
| - break |
1583 |
| - |
1584 |
| - assert init_data is not None |
1585 |
| - negotiated_version = init_data["result"]["protocolVersion"] |
| 1574 | + negotiated_version = extract_protocol_version_from_sse(init_response) |
1586 | 1575 |
|
1587 | 1576 | response = requests.post(
|
1588 | 1577 | f"{basic_server_url}/mcp",
|
|
0 commit comments