Skip to content

Add progress notification callback for client #721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
if __name__ == "__main__":
# Click will handle CLI arguments
import sys

sys.exit(main()) # type: ignore[call-arg]
11 changes: 8 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -270,18 +270,23 @@ async def call_tool(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
"""Send a tools/call request with optional progress callback support."""

return await self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments),
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)

async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
Expand Down
1 change: 0 additions & 1 deletion src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,6 @@ async def report_progress(
total: Optional total value e.g. 100
message: Optional message e.g. Starting render...
"""

progress_token = (
self.request_context.meta.progressToken
if self.request_context.meta
Expand Down
44 changes: 39 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar

import anyio
import httpx
Expand All @@ -24,6 +24,7 @@
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ProgressNotification,
RequestParams,
ServerNotification,
ServerRequest,
Expand All @@ -42,6 +43,14 @@
RequestId = str | int


class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""

async def __call__(
self, progress: float, total: float | None, message: str | None
) -> None: ...


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.

Expand Down Expand Up @@ -169,6 +178,7 @@ class BaseSession(
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]

def __init__(
self,
Expand All @@ -187,6 +197,7 @@ def __init__(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -214,6 +225,7 @@ async def send_request(
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
Expand All @@ -231,15 +243,25 @@ async def send_request(
](1)
self._response_streams[request_id] = response_stream

# Set up progress token if progress callback is provided
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[properly non-blocking] wondering if we can set this in the pydantic model before dumping? I'm really not fussed either way (it would just give you validation etc)

if progress_callback is not None:
# Use request_id as progress token
if "params" not in request_data:
request_data["params"] = {}
if "_meta" not in request_data["params"]:
request_data["params"]["_meta"] = {}
request_data["params"]["_meta"]["progressToken"] = request_id
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback

try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
**request_data,
)

# TODO: Support progress callbacks

await self._write_stream.send(
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
Expand Down Expand Up @@ -275,6 +297,7 @@ async def send_request(

finally:
self._response_streams.pop(request_id, None)
self._progress_callbacks.pop(request_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()

Expand Down Expand Up @@ -333,7 +356,6 @@ async def _receive_loop(self) -> None:
by_alias=True, mode="json", exclude_none=True
)
)

responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
Expand Down Expand Up @@ -363,6 +385,18 @@ async def _receive_loop(self) -> None:
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_list_methods_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

async def test_list_tools_cursor_parameter():
"""Test that the cursor parameter is accepted for list_tools.

Note: FastMCP doesn't currently implement pagination, so this test
only verifies that the cursor parameter is accepted by the client.
"""
Expand Down
Loading
Loading