|
| 1 | +from typing import Any, cast |
| 2 | + |
1 | 3 | import anyio
|
2 | 4 | import pytest
|
3 | 5 |
|
|
7 | 9 | from mcp.server.lowlevel import NotificationOptions
|
8 | 10 | from mcp.server.models import InitializationOptions
|
9 | 11 | from mcp.server.session import ServerSession
|
10 |
| -from mcp.shared.session import RequestResponder, SessionMessage |
| 12 | +from mcp.shared.context import RequestContext |
| 13 | +from mcp.shared.progress import progress |
| 14 | +from mcp.shared.session import ( |
| 15 | + BaseSession, |
| 16 | + RequestResponder, |
| 17 | + SessionMessage, |
| 18 | +) |
11 | 19 |
|
12 | 20 |
|
13 | 21 | @pytest.mark.anyio
|
@@ -209,3 +217,142 @@ async def handle_client_message(
|
209 | 217 | assert server_progress_updates[0]["progress"] == 0.33
|
210 | 218 | assert server_progress_updates[0]["message"] == "Client progress 33%"
|
211 | 219 | assert server_progress_updates[2]["progress"] == 1.0
|
| 220 | + |
| 221 | + |
| 222 | +@pytest.mark.anyio |
| 223 | +async def test_progress_context_manager(): |
| 224 | + """Test client using progress context manager for sending progress notifications.""" |
| 225 | + # Create memory streams for client/server |
| 226 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ |
| 227 | + SessionMessage |
| 228 | + ](5) |
| 229 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ |
| 230 | + SessionMessage |
| 231 | + ](5) |
| 232 | + |
| 233 | + # Track progress updates |
| 234 | + server_progress_updates = [] |
| 235 | + |
| 236 | + server = Server(name="ProgressContextTestServer") |
| 237 | + |
| 238 | + # Register progress handler |
| 239 | + @server.progress_notification() |
| 240 | + async def handle_progress( |
| 241 | + progress_token: str | int, |
| 242 | + progress: float, |
| 243 | + total: float | None, |
| 244 | + message: str | None, |
| 245 | + ): |
| 246 | + server_progress_updates.append( |
| 247 | + { |
| 248 | + "token": progress_token, |
| 249 | + "progress": progress, |
| 250 | + "total": total, |
| 251 | + "message": message, |
| 252 | + } |
| 253 | + ) |
| 254 | + |
| 255 | + # Run server session to receive progress updates |
| 256 | + async def run_server(): |
| 257 | + # Create a server session |
| 258 | + async with ServerSession( |
| 259 | + client_to_server_receive, |
| 260 | + server_to_client_send, |
| 261 | + InitializationOptions( |
| 262 | + server_name="ProgressContextTestServer", |
| 263 | + server_version="0.1.0", |
| 264 | + capabilities=server.get_capabilities(NotificationOptions(), {}), |
| 265 | + ), |
| 266 | + ) as server_session: |
| 267 | + async for message in server_session.incoming_messages: |
| 268 | + try: |
| 269 | + await server._handle_message(message, server_session, ()) |
| 270 | + except Exception as e: |
| 271 | + raise e |
| 272 | + |
| 273 | + # Client message handler |
| 274 | + async def handle_client_message( |
| 275 | + message: RequestResponder[types.ServerRequest, types.ClientResult] |
| 276 | + | types.ServerNotification |
| 277 | + | Exception, |
| 278 | + ) -> None: |
| 279 | + if isinstance(message, Exception): |
| 280 | + raise message |
| 281 | + |
| 282 | + # run client session |
| 283 | + async with ( |
| 284 | + ClientSession( |
| 285 | + server_to_client_receive, |
| 286 | + client_to_server_send, |
| 287 | + message_handler=handle_client_message, |
| 288 | + ) as client_session, |
| 289 | + anyio.create_task_group() as tg, |
| 290 | + ): |
| 291 | + tg.start_soon(run_server) |
| 292 | + |
| 293 | + await client_session.initialize() |
| 294 | + |
| 295 | + progress_token = "client_token_456" |
| 296 | + |
| 297 | + # Create request context |
| 298 | + meta = types.RequestParams.Meta(progressToken=progress_token) |
| 299 | + request_context = RequestContext( |
| 300 | + request_id="test-request", |
| 301 | + session=client_session, |
| 302 | + meta=meta, |
| 303 | + lifespan_context=None, |
| 304 | + ) |
| 305 | + |
| 306 | + # cast for type checker |
| 307 | + typed_context = cast( |
| 308 | + RequestContext[ |
| 309 | + BaseSession[Any, Any, Any, Any, Any], |
| 310 | + Any, |
| 311 | + ], |
| 312 | + request_context, |
| 313 | + ) |
| 314 | + |
| 315 | + # Utilize progress context manager |
| 316 | + with progress(typed_context, total=100) as p: |
| 317 | + await p.progress(10, message="Loading configuration...") |
| 318 | + await anyio.sleep(0.1) |
| 319 | + |
| 320 | + await p.progress(30, message="Connecting to database...") |
| 321 | + await anyio.sleep(0.1) |
| 322 | + |
| 323 | + await p.progress(40, message="Fetching data...") |
| 324 | + await anyio.sleep(0.1) |
| 325 | + |
| 326 | + await p.progress(20, message="Processing results...") |
| 327 | + await anyio.sleep(0.1) |
| 328 | + |
| 329 | + # Wait for all messages to be processed |
| 330 | + await anyio.sleep(0.5) |
| 331 | + tg.cancel_scope.cancel() |
| 332 | + |
| 333 | + # Verify progress updates were received by server |
| 334 | + assert len(server_progress_updates) == 4 |
| 335 | + |
| 336 | + # first update |
| 337 | + assert server_progress_updates[0]["token"] == progress_token |
| 338 | + assert server_progress_updates[0]["progress"] == 10 |
| 339 | + assert server_progress_updates[0]["total"] == 100 |
| 340 | + assert server_progress_updates[0]["message"] == "Loading configuration..." |
| 341 | + |
| 342 | + # second update |
| 343 | + assert server_progress_updates[1]["token"] == progress_token |
| 344 | + assert server_progress_updates[1]["progress"] == 40 |
| 345 | + assert server_progress_updates[1]["total"] == 100 |
| 346 | + assert server_progress_updates[1]["message"] == "Connecting to database..." |
| 347 | + |
| 348 | + # third update |
| 349 | + assert server_progress_updates[2]["token"] == progress_token |
| 350 | + assert server_progress_updates[2]["progress"] == 80 |
| 351 | + assert server_progress_updates[2]["total"] == 100 |
| 352 | + assert server_progress_updates[2]["message"] == "Fetching data..." |
| 353 | + |
| 354 | + # final update |
| 355 | + assert server_progress_updates[3]["token"] == progress_token |
| 356 | + assert server_progress_updates[3]["progress"] == 100 |
| 357 | + assert server_progress_updates[3]["total"] == 100 |
| 358 | + assert server_progress_updates[3]["message"] == "Processing results..." |
0 commit comments