Skip to content

Prevent stdio connection hang for missing server path. #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=(
{**get_default_environment(), **server.env}
if server.env is not None
else get_default_environment()
),
errlog=errlog,
cwd=server.cwd,
)
try:
command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=(
{**get_default_environment(), **server.env}
if server.env is not None
else get_default_environment()
),
errlog=errlog,
cwd=server.cwd,
)
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise

async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
Expand Down Expand Up @@ -177,12 +185,18 @@ async def stdin_writer():
yield read_stream, write_stream
finally:
# Clean up process to prevent any dangling orphaned processes
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()
try:
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()
except ProcessLookupError:
# Process already exited, which is fine
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()


def _get_executable_command(command: str) -> str:
Expand Down
50 changes: 48 additions & 2 deletions tests/client/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

import pytest

from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.session import ClientSession
from mcp.client.stdio import (
StdioServerParameters,
stdio_client,
)
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

tee: str = shutil.which("tee") # type: ignore
python: str = shutil.which("python") # type: ignore


@pytest.mark.anyio
Expand Down Expand Up @@ -50,3 +56,43 @@ async def test_stdio_client():
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)


@pytest.mark.anyio
async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors."""
server_params = StdioServerParameters(
command="python", args=["-c", "non-existent-file.py"]
)
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
# The session should raise an error when the connection closes
with pytest.raises(McpError) as exc_info:
await session.initialize()

# Check that we got a connection closed error
assert exc_info.value.error.code == CONNECTION_CLOSED
assert "Connection closed" in exc_info.value.error.message


@pytest.mark.anyio
async def test_stdio_client_nonexistent_command():
"""Test that stdio_client raises an error for non-existent commands."""
# Create a server with a non-existent command
server_params = StdioServerParameters(
command="/path/to/nonexistent/command",
args=["--help"],
)

# Should raise an error when trying to start the process
with pytest.raises(Exception) as exc_info:
async with stdio_client(server_params) as (_, _):
pass

# The error should indicate the command was not found
error_message = str(exc_info.value)
assert (
"nonexistent" in error_message
or "not found" in error_message.lower()
or "cannot find the file" in error_message.lower() # Windows error message
)
Loading