Skip to content

Commit 97d3902

Browse files
committed
test progress context manager
1 parent 198f96a commit 97d3902

File tree

1 file changed

+148
-1
lines changed

1 file changed

+148
-1
lines changed

tests/shared/test_progress_notifications.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, cast
2+
13
import anyio
24
import pytest
35

@@ -7,7 +9,13 @@
79
from mcp.server.lowlevel import NotificationOptions
810
from mcp.server.models import InitializationOptions
911
from mcp.server.session import ServerSession
10-
from mcp.shared.session import RequestResponder, SessionMessage
12+
from mcp.shared.context import RequestContext
13+
from mcp.shared.progress import progress
14+
from mcp.shared.session import (
15+
BaseSession,
16+
RequestResponder,
17+
SessionMessage,
18+
)
1119

1220

1321
@pytest.mark.anyio
@@ -209,3 +217,142 @@ async def handle_client_message(
209217
assert server_progress_updates[0]["progress"] == 0.33
210218
assert server_progress_updates[0]["message"] == "Client progress 33%"
211219
assert server_progress_updates[2]["progress"] == 1.0
220+
221+
222+
@pytest.mark.anyio
223+
async def test_progress_context_manager():
224+
"""Test client using progress context manager for sending progress notifications."""
225+
# Create memory streams for client/server
226+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
227+
SessionMessage
228+
](5)
229+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
230+
SessionMessage
231+
](5)
232+
233+
# Track progress updates
234+
server_progress_updates = []
235+
236+
server = Server(name="ProgressContextTestServer")
237+
238+
# Register progress handler
239+
@server.progress_notification()
240+
async def handle_progress(
241+
progress_token: str | int,
242+
progress: float,
243+
total: float | None,
244+
message: str | None,
245+
):
246+
server_progress_updates.append(
247+
{
248+
"token": progress_token,
249+
"progress": progress,
250+
"total": total,
251+
"message": message,
252+
}
253+
)
254+
255+
# Run server session to receive progress updates
256+
async def run_server():
257+
# Create a server session
258+
async with ServerSession(
259+
client_to_server_receive,
260+
server_to_client_send,
261+
InitializationOptions(
262+
server_name="ProgressContextTestServer",
263+
server_version="0.1.0",
264+
capabilities=server.get_capabilities(NotificationOptions(), {}),
265+
),
266+
) as server_session:
267+
async for message in server_session.incoming_messages:
268+
try:
269+
await server._handle_message(message, server_session, ())
270+
except Exception as e:
271+
raise e
272+
273+
# Client message handler
274+
async def handle_client_message(
275+
message: RequestResponder[types.ServerRequest, types.ClientResult]
276+
| types.ServerNotification
277+
| Exception,
278+
) -> None:
279+
if isinstance(message, Exception):
280+
raise message
281+
282+
# run client session
283+
async with (
284+
ClientSession(
285+
server_to_client_receive,
286+
client_to_server_send,
287+
message_handler=handle_client_message,
288+
) as client_session,
289+
anyio.create_task_group() as tg,
290+
):
291+
tg.start_soon(run_server)
292+
293+
await client_session.initialize()
294+
295+
progress_token = "client_token_456"
296+
297+
# Create request context
298+
meta = types.RequestParams.Meta(progressToken=progress_token)
299+
request_context = RequestContext(
300+
request_id="test-request",
301+
session=client_session,
302+
meta=meta,
303+
lifespan_context=None,
304+
)
305+
306+
# cast for type checker
307+
typed_context = cast(
308+
RequestContext[
309+
BaseSession[Any, Any, Any, Any, Any],
310+
Any,
311+
],
312+
request_context,
313+
)
314+
315+
# Utilize progress context manager
316+
with progress(typed_context, total=100) as p:
317+
await p.progress(10, message="Loading configuration...")
318+
await anyio.sleep(0.1)
319+
320+
await p.progress(30, message="Connecting to database...")
321+
await anyio.sleep(0.1)
322+
323+
await p.progress(40, message="Fetching data...")
324+
await anyio.sleep(0.1)
325+
326+
await p.progress(20, message="Processing results...")
327+
await anyio.sleep(0.1)
328+
329+
# Wait for all messages to be processed
330+
await anyio.sleep(0.5)
331+
tg.cancel_scope.cancel()
332+
333+
# Verify progress updates were received by server
334+
assert len(server_progress_updates) == 4
335+
336+
# first update
337+
assert server_progress_updates[0]["token"] == progress_token
338+
assert server_progress_updates[0]["progress"] == 10
339+
assert server_progress_updates[0]["total"] == 100
340+
assert server_progress_updates[0]["message"] == "Loading configuration..."
341+
342+
# second update
343+
assert server_progress_updates[1]["token"] == progress_token
344+
assert server_progress_updates[1]["progress"] == 40
345+
assert server_progress_updates[1]["total"] == 100
346+
assert server_progress_updates[1]["message"] == "Connecting to database..."
347+
348+
# third update
349+
assert server_progress_updates[2]["token"] == progress_token
350+
assert server_progress_updates[2]["progress"] == 80
351+
assert server_progress_updates[2]["total"] == 100
352+
assert server_progress_updates[2]["message"] == "Fetching data..."
353+
354+
# final update
355+
assert server_progress_updates[3]["token"] == progress_token
356+
assert server_progress_updates[3]["progress"] == 100
357+
assert server_progress_updates[3]["total"] == 100
358+
assert server_progress_updates[3]["message"] == "Processing results..."

0 commit comments

Comments
 (0)