Skip to content

Commit a0f6562

Browse files
refactor: extract protocol version parsing to helper function
- Add extract_protocol_version_from_sse helper function to reduce code duplication - Replace repeated protocol version extraction logic in 5 test functions - Fix line length issues in docstrings to comply with 88 char limit This improves test maintainability by centralizing the SSE response parsing logic.
1 parent 397dcdf commit a0f6562

File tree

1 file changed

+14
-25
lines changed

1 file changed

+14
-25
lines changed

tests/shared/test_streamable_http.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@
6565
}
6666

6767

68+
# Helper functions
69+
def extract_protocol_version_from_sse(response: requests.Response) -> str:
70+
"""Extract the negotiated protocol version from an SSE initialization response."""
71+
assert response.headers.get("Content-Type") == "text/event-stream"
72+
for line in response.text.splitlines():
73+
if line.startswith("data: "):
74+
init_data = json.loads(line[6:])
75+
return init_data["result"]["protocolVersion"]
76+
raise ValueError("Could not extract protocol version from SSE response")
77+
78+
6879
# Simple in-memory event store for testing
6980
class SimpleEventStore(EventStore):
7081
"""Simple in-memory event store for testing."""
@@ -578,14 +589,7 @@ def test_session_termination(basic_server, basic_server_url):
578589
assert response.status_code == 200
579590

580591
# Extract negotiated protocol version from SSE response
581-
init_data = None
582-
assert response.headers.get("Content-Type") == "text/event-stream"
583-
for line in response.text.splitlines():
584-
if line.startswith("data: "):
585-
init_data = json.loads(line[6:])
586-
break
587-
assert init_data is not None
588-
negotiated_version = init_data["result"]["protocolVersion"]
592+
negotiated_version = extract_protocol_version_from_sse(response)
589593

590594
# Now terminate the session
591595
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
@@ -626,14 +630,7 @@ def test_response(basic_server, basic_server_url):
626630
assert response.status_code == 200
627631

628632
# Extract negotiated protocol version from SSE response
629-
init_data = None
630-
assert response.headers.get("Content-Type") == "text/event-stream"
631-
for line in response.text.splitlines():
632-
if line.startswith("data: "):
633-
init_data = json.loads(line[6:])
634-
break
635-
assert init_data is not None
636-
negotiated_version = init_data["result"]["protocolVersion"]
633+
negotiated_version = extract_protocol_version_from_sse(response)
637634

638635
# Now get the session ID
639636
session_id = response.headers.get(MCP_SESSION_ID_HEADER)
@@ -1574,15 +1571,7 @@ async def test_server_validates_protocol_version_header(basic_server, basic_serv
15741571
)
15751572

15761573
# Test request with valid protocol version (should succeed)
1577-
init_data = None
1578-
assert init_response.headers.get("Content-Type") == "text/event-stream"
1579-
for line in init_response.text.splitlines():
1580-
if line.startswith("data: "):
1581-
init_data = json.loads(line[6:])
1582-
break
1583-
1584-
assert init_data is not None
1585-
negotiated_version = init_data["result"]["protocolVersion"]
1574+
negotiated_version = extract_protocol_version_from_sse(init_response)
15861575

15871576
response = requests.post(
15881577
f"{basic_server_url}/mcp",

0 commit comments

Comments
 (0)