Skip to content

Commit ef18b26

Browse files
committed
Don't use JSONTransport in ProtocolServerMCP
1 parent 56c23f6 commit ef18b26

File tree

3 files changed

+65
-40
lines changed

3 files changed

+65
-40
lines changed

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ using namespace llvm;
2222

2323
LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
2424

25+
static constexpr size_t kChunkSize = 1024;
26+
2527
ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
2628
: ProtocolServer(), m_debugger(debugger) {
2729
AddRequestHandler("initialize",
@@ -91,21 +93,52 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
9193
m_clients.size() + 1);
9294

9395
lldb::IOObjectSP io_sp = std::move(socket);
94-
auto transport_sp = std::make_shared<JSONRPCTransport>(io_sp, io_sp);
96+
auto client_up = std::make_unique<Client>();
97+
client_up->io_sp = io_sp;
98+
Client *client = client_up.get();
9599

96100
Status status;
97101
auto read_handle_up = m_loop.RegisterReadObject(
98102
io_sp,
99-
[=](MainLoopBase &) {
100-
if (llvm::Error err = HandleData(*transport_sp)) {
101-
LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(err), "{0}");
103+
[this, client](MainLoopBase &loop) {
104+
if (Error error = ReadCallback(*client)) {
105+
LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}");
106+
client->read_handle_up.reset();
102107
}
103108
},
104109
status);
105110
if (status.Fail())
106111
return;
107112

108-
m_clients.emplace_back(io_sp, std::move(read_handle_up));
113+
client_up->read_handle_up = std::move(read_handle_up);
114+
m_clients.emplace_back(std::move(client_up));
115+
}
116+
117+
llvm::Error ProtocolServerMCP::ReadCallback(Client &client) {
118+
char chunk[kChunkSize];
119+
size_t bytes_read = sizeof(chunk);
120+
if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail())
121+
return status.takeError();
122+
client.buffer.append(chunk, bytes_read);
123+
124+
for (std::string::size_type pos;
125+
(pos = client.buffer.find('\n')) != std::string::npos;) {
126+
llvm::Expected<std::optional<protocol::Message>> message =
127+
HandleData(StringRef(client.buffer.data(), pos));
128+
client.buffer = client.buffer.erase(0, pos + 1);
129+
if (!message)
130+
return message.takeError();
131+
132+
if (*message) {
133+
std::string Output;
134+
llvm::raw_string_ostream OS(Output);
135+
OS << llvm::formatv("{0}", toJSON(**message)) << '\n';
136+
size_t num_bytes = Output.size();
137+
return client.io_sp->Write(Output.data(), num_bytes).takeError();
138+
}
139+
}
140+
141+
return llvm::Error::success();
109142
}
110143

111144
llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
@@ -165,58 +198,43 @@ llvm::Error ProtocolServerMCP::Stop() {
165198
return llvm::Error::success();
166199
}
167200

168-
llvm::Error ProtocolServerMCP::HandleData(JSONTransport &transport) {
169-
llvm::Expected<protocol::Message> maybe_message =
170-
transport.Read<protocol::Message>(std::chrono::seconds(1));
171-
if (maybe_message.errorIsA<TransportEOFError>() ||
172-
maybe_message.errorIsA<TransportInvalidError>() ||
173-
maybe_message.errorIsA<TransportTimeoutError>()) {
174-
consumeError(maybe_message.takeError());
175-
return llvm::Error::success();
176-
}
177-
178-
if (llvm::Error err = maybe_message.takeError())
179-
return err;
201+
llvm::Expected<std::optional<protocol::Message>>
202+
ProtocolServerMCP::HandleData(llvm::StringRef data) {
203+
auto message = llvm::json::parse<protocol::Message>(/*JSON=*/data);
204+
if (!message)
205+
return message.takeError();
180206

181-
protocol::Message &message = *maybe_message;
182207
if (const protocol::Request *request =
183-
std::get_if<protocol::Request>(&message)) {
184-
llvm::Expected<protocol::Response> maybe_response = Handle(*request);
208+
std::get_if<protocol::Request>(&(*message))) {
209+
llvm::Expected<protocol::Response> response = Handle(*request);
185210

186-
// Handle failures.
187-
if (!maybe_response) {
211+
// Handle failures by converting them into an Error message.
212+
if (!response) {
188213
protocol::Error protocol_error;
189214
llvm::handleAllErrors(
190-
maybe_response.takeError(),
215+
response.takeError(),
191216
[&](const MCPError &err) { protocol_error = err.toProtcolError(); },
192217
[&](const llvm::ErrorInfoBase &err) {
193218
protocol_error.error.code = -1;
194219
protocol_error.error.message = err.message();
195220
});
196221
protocol_error.id = request->id;
197-
if (llvm::Error err = transport.Write(protocol_error))
198-
return err;
199-
200-
return llvm::Error::success();
222+
return protocol_error;
201223
}
202224

203-
// Handle success.
204-
if (llvm::Error err = transport.Write(*maybe_response))
205-
return err;
206-
207-
return llvm::Error::success();
225+
return *response;
208226
}
209227

210228
if (const protocol::Notification *notification =
211-
std::get_if<protocol::Notification>(&message)) {
229+
std::get_if<protocol::Notification>(&(*message))) {
212230
Handle(*notification);
213-
return llvm::Error::success();
231+
return std::nullopt;
214232
}
215233

216-
if (std::get_if<protocol::Error>(&message))
234+
if (std::get_if<protocol::Error>(&(*message)))
217235
return llvm::createStringError("unexpected MCP message: error");
218236

219-
if (std::get_if<protocol::Response>(&message))
237+
if (std::get_if<protocol::Response>(&(*message)))
220238
return llvm::createStringError("unexpected MCP message: response");
221239

222240
llvm_unreachable("all message types handled");

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "Protocol.h"
1313
#include "Tool.h"
1414
#include "lldb/Core/ProtocolServer.h"
15-
#include "lldb/Host/JSONTransport.h"
1615
#include "lldb/Host/MainLoop.h"
1716
#include "lldb/Host/Socket.h"
1817
#include "llvm/ADT/StringMap.h"
@@ -54,7 +53,8 @@ class ProtocolServerMCP : public ProtocolServer {
5453
private:
5554
void AcceptCallback(std::unique_ptr<Socket> socket);
5655

57-
llvm::Error HandleData(JSONTransport &transport);
56+
llvm::Expected<std::optional<protocol::Message>>
57+
HandleData(llvm::StringRef data);
5858

5959
llvm::Expected<protocol::Response> Handle(protocol::Request request);
6060
void Handle(protocol::Notification notification);
@@ -80,8 +80,14 @@ class ProtocolServerMCP : public ProtocolServer {
8080

8181
std::unique_ptr<Socket> m_listener;
8282
std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;
83-
std::vector<std::pair<lldb::IOObjectSP, MainLoopBase::ReadHandleUP>>
84-
m_clients;
83+
84+
struct Client {
85+
lldb::IOObjectSP io_sp;
86+
MainLoopBase::ReadHandleUP read_handle_up;
87+
std::string buffer;
88+
};
89+
llvm::Error ReadCallback(Client &client);
90+
std::vector<std::unique_ptr<Client>> m_clients;
8591

8692
std::mutex m_server_mutex;
8793
llvm::StringMap<std::unique_ptr<Tool>> m_tools;

lldb/unittests/Protocol/ProtocolMCPServerTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "lldb/Core/ProtocolServer.h"
1414
#include "lldb/Host/FileSystem.h"
1515
#include "lldb/Host/HostInfo.h"
16+
#include "lldb/Host/JSONTransport.h"
1617
#include "lldb/Host/Socket.h"
1718
#include "llvm/Testing/Support/Error.h"
1819
#include "gtest/gtest.h"

0 commit comments

Comments
 (0)