9
9
"""
10
10
11
11
import contextlib
12
+ import logging
12
13
from collections .abc import Callable
13
14
from datetime import timedelta
15
+ from types import TracebackType
14
16
from typing import Any , TypeAlias
15
17
18
+ import anyio
16
19
from pydantic import BaseModel
20
+ from typing_extensions import Self
17
21
18
22
import mcp
19
23
from mcp import types
@@ -67,11 +71,18 @@ class ClientSessionGroup:
67
71
"""Client for managing connections to multiple MCP servers.
68
72
69
73
This class is responsible for encapsulating management of server connections.
70
- It it aggregates tools, resources, and prompts from all connected servers.
74
+ It aggregates tools, resources, and prompts from all connected servers.
71
75
72
76
For auxiliary handlers, such as resource subscription, this is delegated to
73
- the client and can be accessed via the session. For example:
74
- mcp_session_group.get_session("server_name").subscribe_to_resource(...)
77
+ the client and can be accessed via the session.
78
+
79
+ Example Usage:
80
+ name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
81
+ async with ClientSessionGroup(component_name_hook=name_fn) as group:
82
+ for server_params in server_params:
83
+ group.connect_to_server(server_param)
84
+ ...
85
+
75
86
"""
76
87
77
88
class _ComponentNames (BaseModel ):
@@ -90,6 +101,7 @@ class _ComponentNames(BaseModel):
90
101
_sessions : dict [mcp .ClientSession , _ComponentNames ]
91
102
_tool_to_session : dict [str , mcp .ClientSession ]
92
103
_exit_stack : contextlib .AsyncExitStack
104
+ _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93
105
94
106
# Optional fn consuming (component_name, serverInfo) for custom names.
95
107
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +111,7 @@ class _ComponentNames(BaseModel):
99
111
100
112
def __init__ (
101
113
self ,
102
- exit_stack : contextlib .AsyncExitStack = contextlib . AsyncExitStack () ,
114
+ exit_stack : contextlib .AsyncExitStack | None = None ,
103
115
component_name_hook : _ComponentNameHook | None = None ,
104
116
) -> None :
105
117
"""Initializes the MCP client."""
@@ -110,9 +122,43 @@ def __init__(
110
122
111
123
self ._sessions = {}
112
124
self ._tool_to_session = {}
113
- self ._exit_stack = exit_stack
125
+ if exit_stack is None :
126
+ self ._exit_stack = contextlib .AsyncExitStack ()
127
+ self ._owns_exit_stack = True
128
+ else :
129
+ self ._exit_stack = exit_stack
130
+ self ._owns_exit_stack = False
131
+ self ._session_exit_stacks = {}
114
132
self ._component_name_hook = component_name_hook
115
133
134
+ async def __aenter__ (self ) -> Self :
135
+ # Enter the exit stack only if we created it ourselves
136
+ if self ._owns_exit_stack :
137
+ await self ._exit_stack .__aenter__ ()
138
+ return self
139
+
140
+ async def __aexit__ (
141
+ self ,
142
+ _exc_type : type [BaseException ] | None ,
143
+ _exc_val : BaseException | None ,
144
+ _exc_tb : TracebackType | None ,
145
+ ) -> bool | None :
146
+ """Closes session exit stacks and main exit stack upon completion."""
147
+
148
+ # Concurrently close session stacks.
149
+ async with anyio .create_task_group () as tg :
150
+ for exit_stack in self ._session_exit_stacks .values ():
151
+ tg .start_soon (exit_stack .aclose )
152
+
153
+ # Only close the main exit stack if we created it
154
+ if self ._owns_exit_stack :
155
+ await self ._exit_stack .aclose ()
156
+
157
+ @property
158
+ def sessions (self ) -> list [mcp .ClientSession ]:
159
+ """Returns the list of sessions being managed."""
160
+ return list (self ._sessions .keys ())
161
+
116
162
@property
117
163
def prompts (self ) -> dict [str , types .Prompt ]:
118
164
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,42 +177,113 @@ def tools(self) -> dict[str, types.Tool]:
131
177
async def call_tool (self , name : str , args : dict [str , Any ]) -> types .CallToolResult :
132
178
"""Executes a tool given its name and arguments."""
133
179
session = self ._tool_to_session [name ]
134
- return await session .call_tool (name , args )
180
+ session_tool_name = self .tools [name ].name
181
+ return await session .call_tool (session_tool_name , args )
135
182
136
- def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
183
+ async def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
137
184
"""Disconnects from a single MCP server."""
138
185
139
- if session not in self ._sessions :
186
+ session_known_for_components = session in self ._sessions
187
+ session_known_for_stack = session in self ._session_exit_stacks
188
+
189
+ if not session_known_for_components and not session_known_for_stack :
140
190
raise McpError (
141
191
types .ErrorData (
142
192
code = types .INVALID_PARAMS ,
143
- message = "Provided session is not being managed." ,
193
+ message = "Provided session is not managed or already disconnected ." ,
144
194
)
145
195
)
146
- component_names = self ._sessions [session ]
147
-
148
- # Remove prompts associated with the session.
149
- for name in component_names .prompts :
150
- del self ._prompts [name ]
151
196
152
- # Remove resources associated with the session.
153
- for name in component_names .resources :
154
- del self ._resources [name ]
155
-
156
- # Remove tools associated with the session.
157
- for name in component_names .tools :
158
- del self ._tools [name ]
159
-
160
- del self ._sessions [session ]
197
+ if session_known_for_components :
198
+ component_names = self ._sessions .pop (session ) # Pop from _sessions tracking
199
+
200
+ # Remove prompts associated with the session.
201
+ for name in component_names .prompts :
202
+ if name in self ._prompts :
203
+ del self ._prompts [name ]
204
+ # Remove resources associated with the session.
205
+ for name in component_names .resources :
206
+ if name in self ._resources :
207
+ del self ._resources [name ]
208
+ # Remove tools associated with the session.
209
+ for name in component_names .tools :
210
+ if name in self ._tools :
211
+ del self ._tools [name ]
212
+ if name in self ._tool_to_session :
213
+ del self ._tool_to_session [name ]
214
+
215
+ # Clean up the session's resources via its dedicated exit stack
216
+ if session_known_for_stack :
217
+ session_stack_to_close = self ._session_exit_stacks .pop (session )
218
+ await session_stack_to_close .aclose ()
219
+
220
+ async def connect_with_session (
221
+ self , server_info : types .Implementation , session : mcp .ClientSession
222
+ ) -> mcp .ClientSession :
223
+ """Connects to a single MCP server."""
224
+ await self ._aggregate_components (server_info , session )
225
+ return session
161
226
162
227
async def connect_to_server (
163
228
self ,
164
229
server_params : ServerParameters ,
165
230
) -> mcp .ClientSession :
166
231
"""Connects to a single MCP server."""
167
-
168
- # Establish server connection and create session.
169
232
server_info , session = await self ._establish_session (server_params )
233
+ return await self .connect_with_session (server_info , session )
234
+
235
+ async def _establish_session (
236
+ self , server_params : ServerParameters
237
+ ) -> tuple [types .Implementation , mcp .ClientSession ]:
238
+ """Establish a client session to an MCP server."""
239
+
240
+ session_stack = contextlib .AsyncExitStack ()
241
+ try :
242
+ # Create read and write streams that facilitate io with the server.
243
+ if isinstance (server_params , StdioServerParameters ):
244
+ client = mcp .stdio_client (server_params )
245
+ read , write = await session_stack .enter_async_context (client )
246
+ elif isinstance (server_params , SseServerParameters ):
247
+ client = sse_client (
248
+ url = server_params .url ,
249
+ headers = server_params .headers ,
250
+ timeout = server_params .timeout ,
251
+ sse_read_timeout = server_params .sse_read_timeout ,
252
+ )
253
+ read , write = await session_stack .enter_async_context (client )
254
+ else :
255
+ client = streamablehttp_client (
256
+ url = server_params .url ,
257
+ headers = server_params .headers ,
258
+ timeout = server_params .timeout ,
259
+ sse_read_timeout = server_params .sse_read_timeout ,
260
+ terminate_on_close = server_params .terminate_on_close ,
261
+ )
262
+ read , write , _ = await session_stack .enter_async_context (client )
263
+
264
+ session = await session_stack .enter_async_context (
265
+ mcp .ClientSession (read , write )
266
+ )
267
+ result = await session .initialize ()
268
+
269
+ # Session successfully initialized.
270
+ # Store its stack and register the stack with the main group stack.
271
+ self ._session_exit_stacks [session ] = session_stack
272
+ # session_stack itself becomes a resource managed by the
273
+ # main _exit_stack.
274
+ await self ._exit_stack .enter_async_context (session_stack )
275
+
276
+ return result .serverInfo , session
277
+ except Exception :
278
+ # If anything during this setup fails, ensure the session-specific
279
+ # stack is closed.
280
+ await session_stack .aclose ()
281
+ raise
282
+
283
+ async def _aggregate_components (
284
+ self , server_info : types .Implementation , session : mcp .ClientSession
285
+ ) -> None :
286
+ """Aggregates prompts, resources, and tools from a given session."""
170
287
171
288
# Create a reverse index so we can find all prompts, resources, and
172
289
# tools belonging to this session. Used for removing components from
@@ -181,47 +298,66 @@ async def connect_to_server(
181
298
tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
182
299
183
300
# Query the server for its prompts and aggregate to list.
184
- prompts = (await session .list_prompts ()).prompts
185
- for prompt in prompts :
186
- name = self ._component_name (prompt .name , server_info )
187
- if name in self ._prompts :
188
- raise McpError (
189
- types .ErrorData (
190
- code = types .INVALID_PARAMS ,
191
- message = f"{ name } already exists in group prompts." ,
192
- )
193
- )
194
- prompts_temp [name ] = prompt
195
- component_names .prompts .add (name )
301
+ try :
302
+ prompts = (await session .list_prompts ()).prompts
303
+ for prompt in prompts :
304
+ name = self ._component_name (prompt .name , server_info )
305
+ prompts_temp [name ] = prompt
306
+ component_names .prompts .add (name )
307
+ except McpError as err :
308
+ logging .warning (f"Could not fetch prompts: { err } " )
196
309
197
310
# Query the server for its resources and aggregate to list.
198
- resources = (await session .list_resources ()).resources
199
- for resource in resources :
200
- name = self ._component_name (resource .name , server_info )
201
- if name in self ._resources :
202
- raise McpError (
203
- types .ErrorData (
204
- code = types .INVALID_PARAMS ,
205
- message = f"{ name } already exists in group resources." ,
206
- )
207
- )
208
- resources_temp [name ] = resource
209
- component_names .resources .add (name )
311
+ try :
312
+ resources = (await session .list_resources ()).resources
313
+ for resource in resources :
314
+ name = self ._component_name (resource .name , server_info )
315
+ resources_temp [name ] = resource
316
+ component_names .resources .add (name )
317
+ except McpError as err :
318
+ logging .warning (f"Could not fetch resources: { err } " )
210
319
211
320
# Query the server for its tools and aggregate to list.
212
- tools = (await session .list_tools ()).tools
213
- for tool in tools :
214
- name = self ._component_name (tool .name , server_info )
215
- if name in self ._tools :
216
- raise McpError (
217
- types .ErrorData (
218
- code = types .INVALID_PARAMS ,
219
- message = f"{ name } already exists in group tools." ,
220
- )
321
+ try :
322
+ tools = (await session .list_tools ()).tools
323
+ for tool in tools :
324
+ name = self ._component_name (tool .name , server_info )
325
+ tools_temp [name ] = tool
326
+ tool_to_session_temp [name ] = session
327
+ component_names .tools .add (name )
328
+ except McpError as err :
329
+ logging .warning (f"Could not fetch tools: { err } " )
330
+
331
+ # Clean up exit stack for session if we couldn't retrieve anything
332
+ # from the server.
333
+ if not any ((prompts_temp , resources_temp , tools_temp )):
334
+ del self ._session_exit_stacks [session ]
335
+
336
+ # Check for duplicates.
337
+ matching_prompts = prompts_temp .keys () & self ._prompts .keys ()
338
+ if matching_prompts :
339
+ raise McpError (
340
+ types .ErrorData (
341
+ code = types .INVALID_PARAMS ,
342
+ message = f"{ matching_prompts } already exist in group prompts." ,
343
+ )
344
+ )
345
+ matching_resources = resources_temp .keys () & self ._resources .keys ()
346
+ if matching_resources :
347
+ raise McpError (
348
+ types .ErrorData (
349
+ code = types .INVALID_PARAMS ,
350
+ message = f"{ matching_resources } already exist in group resources." ,
351
+ )
352
+ )
353
+ matching_tools = tools_temp .keys () & self ._tools .keys ()
354
+ if matching_tools :
355
+ raise McpError (
356
+ types .ErrorData (
357
+ code = types .INVALID_PARAMS ,
358
+ message = f"{ matching_tools } already exist in group tools." ,
221
359
)
222
- tools_temp [name ] = tool
223
- tool_to_session_temp [name ] = session
224
- component_names .tools .add (name )
360
+ )
225
361
226
362
# Aggregate components.
227
363
self ._sessions [session ] = component_names
@@ -230,41 +366,6 @@ async def connect_to_server(
230
366
self ._tools .update (tools_temp )
231
367
self ._tool_to_session .update (tool_to_session_temp )
232
368
233
- return session
234
-
235
- async def _establish_session (
236
- self , server_params : ServerParameters
237
- ) -> tuple [types .Implementation , mcp .ClientSession ]:
238
- """Establish a client session to an MCP server."""
239
-
240
- # Create read and write streams that facilitate io with the server.
241
- if isinstance (server_params , StdioServerParameters ):
242
- client = mcp .stdio_client (server_params )
243
- read , write = await self ._exit_stack .enter_async_context (client )
244
- elif isinstance (server_params , SseServerParameters ):
245
- client = sse_client (
246
- url = server_params .url ,
247
- headers = server_params .headers ,
248
- timeout = server_params .timeout ,
249
- sse_read_timeout = server_params .sse_read_timeout ,
250
- )
251
- read , write = await self ._exit_stack .enter_async_context (client )
252
- else :
253
- client = streamablehttp_client (
254
- url = server_params .url ,
255
- headers = server_params .headers ,
256
- timeout = server_params .timeout ,
257
- sse_read_timeout = server_params .sse_read_timeout ,
258
- terminate_on_close = server_params .terminate_on_close ,
259
- )
260
- read , write , _ = await self ._exit_stack .enter_async_context (client )
261
-
262
- session = await self ._exit_stack .enter_async_context (
263
- mcp .ClientSession (read , write )
264
- )
265
- result = await session .initialize ()
266
- return result .serverInfo , session
267
-
268
369
def _component_name (self , name : str , server_info : types .Implementation ) -> str :
269
370
if self ._component_name_hook :
270
371
return self ._component_name_hook (name , server_info )
0 commit comments