Skip to content

Fix SSE server transport to support relative and absolute endpoints #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ async def handle_sse(request):
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
from uuid import UUID, uuid4

import anyio
Expand Down Expand Up @@ -100,7 +99,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
session_uri = f"{self._endpoint}?session_id={session_id.hex}"
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

Expand Down
277 changes: 267 additions & 10 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,54 @@ def run_server(server_port: int) -> None:
time.sleep(0.5)


def make_server_app_with_endpoint(endpoint: str) -> Starlette:
"""Create test Starlette app with SSE transport using the specified endpoint"""
sse = SseServerTransport(endpoint)
server = ServerTest()

async def handle_sse(request: Request) -> Response:
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
return Response()

# For absolute URLs, we route all paths
if endpoint.startswith(("http://", "https://")):
route_path = "/sse"
mount_path = "/"
else:
route_path = "/sse"
mount_path = endpoint

app = Starlette(
routes=[
Route(route_path, endpoint=handle_sse),
Mount(mount_path, app=sse.handle_post_message),
]
)

return app


def run_server_with_endpoint(server_port: int, endpoint: str) -> None:
app = make_server_app_with_endpoint(endpoint)
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port} with endpoint {endpoint}")
server.run()

# Give server time to start
while not server.started:
print("waiting for server to start")
time.sleep(0.5)


@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
Expand Down Expand Up @@ -159,6 +207,129 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N
yield client


@pytest.fixture()
def server_with_relative_endpoint(server_port: int) -> Generator[None, None, None]:
"""Setup a server with a relative endpoint path"""
proc = multiprocessing.Process(
target=run_server_with_endpoint,
kwargs={"server_port": server_port, "endpoint": "/messages/"},
daemon=True,
)
print("starting process with relative endpoint")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

yield

print("killing server")
# Signal the server to stop
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")


@pytest.fixture()
def server_with_absolute_endpoint(
server_port: int, server_url: str
) -> Generator[None, None, None]:
"""Setup a server with an absolute endpoint URL"""
absolute_endpoint = f"{server_url}/messages/"
proc = multiprocessing.Process(
target=run_server_with_endpoint,
kwargs={"server_port": server_port, "endpoint": absolute_endpoint},
daemon=True,
)
print(f"starting process with absolute endpoint: {absolute_endpoint}")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

yield

print("killing server")
# Signal the server to stop
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")


@pytest.fixture()
async def http_client_with_relative_endpoint(
server_with_relative_endpoint, server_url
) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client for server with relative endpoint"""
async with httpx.AsyncClient(base_url=server_url) as client:
yield client


@pytest.fixture()
async def http_client_with_absolute_endpoint(
server_with_absolute_endpoint, server_url
) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create test client for server with absolute endpoint"""
async with httpx.AsyncClient(base_url=server_url) as client:
yield client


@pytest.fixture
async def initialized_sse_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session


@pytest.fixture
async def initialized_sse_client_session_with_relative_endpoint(
server_with_relative_endpoint, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session


@pytest.fixture
async def initialized_sse_client_session_with_absolute_endpoint(
server_with_absolute_endpoint, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session


# Tests
@pytest.mark.anyio
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
Expand Down Expand Up @@ -202,16 +373,6 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
assert isinstance(ping_result, EmptyResult)


@pytest.fixture
async def initialized_sse_client_session(
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session


@pytest.mark.anyio
async def test_sse_client_happy_request_and_response(
initialized_sse_client_session: ClientSession,
Expand Down Expand Up @@ -252,3 +413,99 @@ async def test_sse_client_timeout(
return

pytest.fail("the client should have timed out and returned an error already")


@pytest.mark.anyio
async def test_raw_sse_connection_with_relative_endpoint(
http_client_with_relative_endpoint: httpx.AsyncClient,
) -> None:
"""Test the SSE connection establishment with a relative endpoint URL."""
async with anyio.create_task_group():

async def connection_test() -> None:
async with http_client_with_relative_endpoint.stream(
"GET", "/sse"
) as response:
assert response.status_code == 200
assert (
response.headers["content-type"]
== "text/event-stream; charset=utf-8"
)

line_number = 0
async for line in response.aiter_lines():
if line_number == 0:
assert line == "event: endpoint"
elif line_number == 1:
assert line.startswith("data: /messages/?session_id=")
# Verify it's a relative URL
endpoint_data = line.removeprefix("data: ")
assert not endpoint_data.startswith(("http://", "https://"))
assert endpoint_data.startswith("/messages/?session_id=")
else:
return
line_number += 1

# Add timeout to prevent test from hanging if it fails
with anyio.fail_after(3):
await connection_test()


@pytest.mark.anyio
async def test_raw_sse_connection_with_absolute_endpoint(
http_client_with_absolute_endpoint: httpx.AsyncClient,
) -> None:
"""Test the SSE connection establishment with an absolute endpoint URL."""
async with anyio.create_task_group():

async def connection_test() -> None:
async with http_client_with_absolute_endpoint.stream(
"GET", "/sse"
) as response:
assert response.status_code == 200
assert (
response.headers["content-type"]
== "text/event-stream; charset=utf-8"
)

line_number = 0
async for line in response.aiter_lines():
if line_number == 0:
assert line == "event: endpoint"
elif line_number == 1:
# Verify it's an absolute URL
assert line.startswith("data: http://")
assert "/messages/?session_id=" in line
else:
return
line_number += 1

# Add timeout to prevent test from hanging if it fails
with anyio.fail_after(3):
await connection_test()


@pytest.mark.anyio
async def test_sse_client_with_relative_endpoint(
initialized_sse_client_session_with_relative_endpoint: ClientSession,
) -> None:
"""Test that a client session works properly with a relative endpoint."""
session = initialized_sse_client_session_with_relative_endpoint
# Test basic functionality
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
assert len(response.contents) == 1
assert isinstance(response.contents[0], TextResourceContents)
assert response.contents[0].text == "Read should-work"


@pytest.mark.anyio
async def test_sse_client_with_absolute_endpoint(
initialized_sse_client_session_with_absolute_endpoint: ClientSession,
) -> None:
"""Test that a client session works properly with an absolute endpoint."""
session = initialized_sse_client_session_with_absolute_endpoint
# Test basic functionality
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
assert len(response.contents) == 1
assert isinstance(response.contents[0], TextResourceContents)
assert response.contents[0].text == "Read should-work"
Loading