Skip to content

Commit c0d05ab

Browse files
authored
fix: capture launch errors in client-runtime (#19)
1 parent 1dd3f65 commit c0d05ab

File tree

4 files changed

+67
-60
lines changed

4 files changed

+67
-60
lines changed

crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,47 @@ impl McpClient for ClientRuntime {
8585
self.set_message_sender(sender).await;
8686

8787
let self_clone = Arc::clone(&self);
88-
self_clone.initialize_request().await?;
89-
9088
let self_clone_err = Arc::clone(&self);
9189

90+
let err_task = tokio::spawn(async move {
91+
let self_ref = &*self_clone_err;
92+
93+
if let IoStream::Readable(error_input) = error_io {
94+
let mut reader = BufReader::new(error_input).lines();
95+
loop {
96+
tokio::select! {
97+
should_break = self_ref.transport.is_shut_down() =>{
98+
if should_break {
99+
break;
100+
}
101+
}
102+
line = reader.next_line() =>{
103+
match line {
104+
Ok(Some(error_message)) => {
105+
self_ref
106+
.handler
107+
.handle_process_error(error_message, self_ref)
108+
.await?;
109+
}
110+
Ok(None) => {
111+
// end of input
112+
break;
113+
}
114+
Err(e) => {
115+
eprintln!("Error reading from std_err: {}", e);
116+
break;
117+
}
118+
}
119+
}
120+
}
121+
}
122+
}
123+
Ok::<(), McpSdkError>(())
124+
});
125+
126+
// send initialize request to the MCP server
127+
self_clone.initialize_request().await?;
128+
92129
let main_task = tokio::spawn(async move {
93130
let sender = self_clone.sender().await.read().await;
94131
let sender = sender.as_ref().ok_or(crate::error::McpSdkError::SdkError(
@@ -132,42 +169,6 @@ impl McpClient for ClientRuntime {
132169
Ok::<(), McpSdkError>(())
133170
});
134171

135-
let err_task = tokio::spawn(async move {
136-
let self_ref = &*self_clone_err;
137-
138-
if let IoStream::Readable(error_input) = error_io {
139-
let mut reader = BufReader::new(error_input).lines();
140-
loop {
141-
tokio::select! {
142-
should_break = self_ref.transport.is_shut_down() =>{
143-
if should_break {
144-
break;
145-
}
146-
}
147-
line = reader.next_line() =>{
148-
match line {
149-
Ok(Some(error_message)) => {
150-
self_ref
151-
.handler
152-
.handle_process_error(error_message, self_ref)
153-
.await?;
154-
}
155-
Ok(None) => {
156-
// end of input
157-
break;
158-
}
159-
Err(e) => {
160-
eprintln!("Error reading from std_err: {}", e);
161-
break;
162-
}
163-
}
164-
}
165-
}
166-
}
167-
}
168-
Ok::<(), McpSdkError>(())
169-
});
170-
171172
let mut lock = self.handlers.lock().await;
172173
lock.push(main_task);
173174
lock.push(err_task);

crates/rust-mcp-transport/src/mcp_stream.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ impl MCPStream {
3333
readable: Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>,
3434
writable: Mutex<Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>>,
3535
error_io: IoStream,
36+
pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
3637
timeout_msec: u64,
3738
shutdown_rx: Receiver<bool>,
3839
) -> (
@@ -44,7 +45,6 @@ impl MCPStream {
4445
R: RPCMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
4546
{
4647
let (tx, rx) = tokio::sync::broadcast::channel::<R>(CHANNEL_CAPACITY);
47-
let pending_requests = Arc::new(Mutex::new(HashMap::new()));
4848

4949
#[allow(clippy::let_underscore_future)]
5050
let _ = Self::spawn_reader(readable, tx, pending_requests.clone(), shutdown_rx);

crates/rust-mcp-transport/src/message_dispatcher.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use async_trait::async_trait;
22
use rust_mcp_schema::schema_utils::{
3-
ClientMessage, FromMessage, MCPMessage, MessageFromClient, MessageFromServer, ServerMessage,
3+
self, ClientMessage, FromMessage, MCPMessage, MessageFromClient, MessageFromServer,
4+
ServerMessage,
45
};
56
use rust_mcp_schema::{RequestId, RpcError};
67
use std::collections::HashMap;
@@ -12,7 +13,7 @@ use tokio::io::AsyncWriteExt;
1213
use tokio::sync::oneshot;
1314
use tokio::sync::Mutex;
1415

15-
use crate::error::TransportResult;
16+
use crate::error::{TransportError, TransportResult};
1617
use crate::utils::await_timeout;
1718
use crate::McpDispatch;
1819

@@ -146,9 +147,15 @@ impl McpDispatch<ServerMessage, MessageFromClient> for MessageDispatcher<ServerM
146147
writable_std.flush().await?;
147148

148149
if let Some(rx) = rx_response {
150+
// Wait for the response with timeout
149151
match await_timeout(rx, Duration::from_millis(self.timeout_msec)).await {
150152
Ok(response) => Ok(Some(response)),
151-
Err(error) => Err(error),
153+
Err(error) => match error {
154+
TransportError::OneshotRecvError(_) => {
155+
Err(schema_utils::SdkError::connection_closed().into())
156+
}
157+
_ => Err(error),
158+
},
152159
}
153160
} else {
154161
Ok(None)

crates/rust-mcp-transport/src/stdio.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use async_trait::async_trait;
22
use futures::Stream;
33
use rust_mcp_schema::schema_utils::{MCPMessage, RPCMessage};
4+
use rust_mcp_schema::RequestId;
45
use std::collections::HashMap;
56
use std::pin::Pin;
6-
use tokio::process::{Child, Command};
7+
use std::sync::Arc;
8+
use tokio::process::Command;
79
use tokio::sync::watch::Sender;
810
use tokio::sync::{watch, Mutex};
911

@@ -24,7 +26,6 @@ pub struct StdioTransport {
2426
command: Option<String>,
2527
args: Option<Vec<String>>,
2628
env: Option<HashMap<String, String>>,
27-
process: Mutex<Option<Child>>,
2829
options: TransportOptions,
2930
shutdown_tx: tokio::sync::RwLock<Option<Sender<bool>>>,
3031
is_shut_down: Mutex<bool>,
@@ -49,7 +50,6 @@ impl StdioTransport {
4950
args: None,
5051
command: None,
5152
env: None,
52-
process: Mutex::new(None),
5353
options,
5454
shutdown_tx: tokio::sync::RwLock::new(None),
5555
is_shut_down: Mutex::new(false),
@@ -81,20 +81,12 @@ impl StdioTransport {
8181
args: Some(args),
8282
command: Some(command.into()),
8383
env,
84-
process: Mutex::new(None),
8584
options,
8685
shutdown_tx: tokio::sync::RwLock::new(None),
8786
is_shut_down: Mutex::new(false),
8887
})
8988
}
9089

91-
/// Sets the subprocess handle for the transport.
92-
async fn set_process(&self, value: Child) -> TransportResult<()> {
93-
let mut process = self.process.lock().await;
94-
*process = Some(value);
95-
Ok(())
96-
}
97-
9890
/// Retrieves the command and arguments for launching the subprocess.
9991
///
10092
/// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
@@ -188,22 +180,35 @@ where
188180
.take()
189181
.ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
190182

191-
self.set_process(process).await.unwrap();
183+
let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
184+
Arc::new(Mutex::new(HashMap::new()));
185+
let pending_requests_clone = Arc::clone(&pending_requests);
186+
187+
tokio::spawn(async move {
188+
let _ = process.wait().await;
189+
// clean up pending requests to cancel waiting tasks
190+
let mut pending_requests = pending_requests.lock().await;
191+
pending_requests.clear();
192+
});
192193

193194
let (stream, sender, error_stream) = MCPStream::create(
194195
Box::pin(stdout),
195196
Mutex::new(Box::pin(stdin)),
196197
IoStream::Readable(Box::pin(stderr)),
198+
pending_requests_clone,
197199
self.options.timeout,
198200
shutdown_rx,
199201
);
200202

201203
Ok((stream, sender, error_stream))
202204
} else {
205+
let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
206+
Arc::new(Mutex::new(HashMap::new()));
203207
let (stream, sender, error_stream) = MCPStream::create(
204208
Box::pin(tokio::io::stdin()),
205209
Mutex::new(Box::pin(tokio::io::stdout())),
206210
IoStream::Writable(Box::pin(tokio::io::stderr())),
211+
pending_requests,
207212
self.options.timeout,
208213
shutdown_rx,
209214
);
@@ -234,12 +239,6 @@ where
234239
let mut lock = self.is_shut_down.lock().await;
235240
*lock = true
236241
}
237-
238-
let mut process = self.process.lock().await;
239-
if let Some(p) = process.as_mut() {
240-
p.kill().await?;
241-
p.wait().await?;
242-
}
243242
Ok(())
244243
}
245244
}

0 commit comments

Comments
 (0)