Skip to content

Doc update + validation in SseServerTransport + existing test fixes: addresses Issue: #827 #900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,37 @@ class SseServerTransport:
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
messages to the relative path given.

Args:
endpoint: A relative path where messages should be posted
(e.g., "/messages/").

Note:
We use relative paths instead of full URLs for several reasons:
1. Security: Prevents cross-origin requests by ensuring clients only connect
to the same origin they established the SSE connection with
2. Flexibility: The server can be mounted at any path without needing to
know its full URL
3. Portability: The same endpoint configuration works across different
environments (development, staging, production)

Raises:
ValueError: If the endpoint is a full URL instead of a relative path
"""

super().__init__()

# Validate that endpoint is a relative path and not a full URL
if "://" in endpoint or endpoint.startswith("//"):
raise ValueError(
"Endpoint must be a relative path (e.g., '/messages/'), not a full URL."
)

# Ensure endpoint starts with a forward slash
if not endpoint.startswith("/"):
endpoint = "/" + endpoint

self._endpoint = endpoint
self._read_stream_writers = {}
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
Expand Down
5 changes: 1 addition & 4 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import httpx
import pytest
from inline_snapshot import snapshot
from pydantic import AnyHttpUrl

from mcp.client.auth import OAuthClientProvider
Expand Down Expand Up @@ -968,8 +967,7 @@ def test_build_metadata(
revocation_options=RevocationOptions(enabled=True),
)

assert metadata == snapshot(
OAuthMetadata(
assert metadata == OAuthMetadata(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the error you were having?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when I run uv run pytest I get the following error:
image

As a fix I tried running:
uv run pytest --inline-snapshot=fix tests/client/test_auth.py but got the same error again, alternatively I tried using Is for parameterised variables, but that did not seem like fully correct either.

issuer=AnyHttpUrl(issuer_url),
authorization_endpoint=AnyHttpUrl(authorization_endpoint),
token_endpoint=AnyHttpUrl(token_endpoint),
Expand All @@ -982,4 +980,3 @@ def test_build_metadata(
revocation_endpoint_auth_methods_supported=["client_secret_post"],
code_challenge_methods_supported=["S256"],
)
)
23 changes: 16 additions & 7 deletions tests/issues/test_188_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,38 @@
@pytest.mark.anyio
async def test_messages_are_executed_concurrently():
server = FastMCP("test")

call_timestamps = []

@server.tool("sleep")
async def sleep_tool():
call_timestamps.append(("tool_start_time", anyio.current_time()))
await anyio.sleep(_sleep_time_seconds)
call_timestamps.append(("tool_end_time", anyio.current_time()))
return "done"

@server.resource(_resource_name)
async def slow_resource():
call_timestamps.append(("resource_start_time", anyio.current_time()))
await anyio.sleep(_sleep_time_seconds)
call_timestamps.append(("resource_end_time", anyio.current_time()))
return "slow"

async with create_session(server._mcp_server) as client_session:
start_time = anyio.current_time()
async with anyio.create_task_group() as tg:
for _ in range(10):
tg.start_soon(client_session.call_tool, "sleep")
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))

end_time = anyio.current_time()

duration = end_time - start_time
assert duration < 6 * _sleep_time_seconds
print(duration)
active_calls = 0
max_concurrent_calls = 0
for call_type, _ in sorted(call_timestamps, key=lambda x: x[1]):
if "start" in call_type:
active_calls += 1
max_concurrent_calls = max(max_concurrent_calls, active_calls)
else:
active_calls -= 1
print(f"Max concurrent calls: {max_concurrent_calls}")
assert max_concurrent_calls > 1, "No concurrent calls were executed"


def main():
Expand Down
Loading