Skip to content

Commit e8abdfc

Browse files
authored
[lldb] Make MCP server instance global (#145616)
Rather than having one MCP server per debugger, make the MCP server global and pass a debugger id along with tool invocations that require one. This PR also adds a second tool to list the available debuggers with their targets so the model can decide which debugger instance to use.
1 parent 2db0289 commit e8abdfc

File tree

13 files changed

+180
-136
lines changed

13 files changed

+180
-136
lines changed

lldb/include/lldb/Core/Debugger.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -602,10 +602,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
602602
void FlushProcessOutput(Process &process, bool flush_stdout,
603603
bool flush_stderr);
604604

605-
void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
606-
void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
607-
lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const;
608-
609605
SourceManager::SourceFileCache &GetSourceFileCache() {
610606
return m_source_file_cache;
611607
}
@@ -776,8 +772,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
776772
mutable std::mutex m_progress_reports_mutex;
777773
/// @}
778774

779-
llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers;
780-
781775
std::mutex m_destroy_callback_mutex;
782776
lldb::callback_token_t m_destroy_callback_next_token = 0;
783777
struct DestroyCallbackInfo {

lldb/include/lldb/Core/ProtocolServer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface {
2020
ProtocolServer() = default;
2121
virtual ~ProtocolServer() = default;
2222

23-
static lldb::ProtocolServerSP Create(llvm::StringRef name,
24-
Debugger &debugger);
23+
static ProtocolServer *GetOrCreate(llvm::StringRef name);
24+
25+
static std::vector<llvm::StringRef> GetSupportedProtocols();
2526

2627
struct Connection {
2728
Socket::SocketProtocol protocol;

lldb/include/lldb/lldb-forward.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP;
391391
typedef std::shared_ptr<lldb_private::Process> ProcessSP;
392392
typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP;
393393
typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP;
394-
typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP;
394+
typedef std::unique_ptr<lldb_private::ProtocolServer> ProtocolServerUP;
395395
typedef std::weak_ptr<lldb_private::Process> ProcessWP;
396396
typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP;
397397
typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP;

lldb/include/lldb/lldb-private-interfaces.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force,
8181
typedef lldb::ProcessSP (*ProcessCreateInstance)(
8282
lldb::TargetSP target_sp, lldb::ListenerSP listener_sp,
8383
const FileSpec *crash_file_path, bool can_connect);
84-
typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)(
85-
Debugger &debugger);
84+
typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)();
8685
typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)(
8786
Target &target);
8887
typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)(

lldb/source/Commands/CommandObjectProtocolServer.cpp

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,6 @@ using namespace lldb_private;
2323
#define LLDB_OPTIONS_mcp
2424
#include "CommandOptions.inc"
2525

26-
static std::vector<llvm::StringRef> GetSupportedProtocols() {
27-
std::vector<llvm::StringRef> supported_protocols;
28-
size_t i = 0;
29-
30-
for (llvm::StringRef protocol_name =
31-
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
32-
!protocol_name.empty();
33-
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
34-
supported_protocols.push_back(protocol_name);
35-
}
36-
37-
return supported_protocols;
38-
}
39-
4026
class CommandObjectProtocolServerStart : public CommandObjectParsed {
4127
public:
4228
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
@@ -57,12 +43,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
5743
}
5844

5945
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
60-
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
61-
if (llvm::find(supported_protocols, protocol) ==
62-
supported_protocols.end()) {
46+
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
47+
if (!server) {
6348
result.AppendErrorWithFormatv(
6449
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
65-
llvm::join(GetSupportedProtocols(), ", "));
50+
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
6651
return;
6752
}
6853

@@ -72,10 +57,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
7257
}
7358
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);
7459

75-
ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
76-
if (!server_sp)
77-
server_sp = ProtocolServer::Create(protocol, GetDebugger());
78-
7960
const char *connection_error =
8061
"unsupported connection specifier, expected 'accept:///path' or "
8162
"'listen://[host]:port', got '{0}'.";
@@ -98,14 +79,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
9879
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
9980
uri->port.value_or(0));
10081

101-
if (llvm::Error error = server_sp->Start(connection)) {
82+
if (llvm::Error error = server->Start(connection)) {
10283
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
10384
return;
10485
}
10586

106-
GetDebugger().AddProtocolServer(server_sp);
107-
108-
if (Socket *socket = server_sp->GetSocket()) {
87+
if (Socket *socket = server->GetSocket()) {
10988
std::string address =
11089
llvm::join(socket->GetListeningConnectionURI(), ", ");
11190
result.AppendMessageWithFormatv(
@@ -134,30 +113,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed {
134113
}
135114

136115
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
137-
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
138-
if (llvm::find(supported_protocols, protocol) ==
139-
supported_protocols.end()) {
116+
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
117+
if (!server) {
140118
result.AppendErrorWithFormatv(
141119
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
142-
llvm::join(GetSupportedProtocols(), ", "));
120+
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
143121
return;
144122
}
145123

146-
Debugger &debugger = GetDebugger();
147-
148-
ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
149-
if (!server_sp) {
150-
result.AppendError(
151-
llvm::formatv("no {0} protocol server running", protocol).str());
152-
return;
153-
}
154-
155-
if (llvm::Error error = server_sp->Stop()) {
124+
if (llvm::Error error = server->Stop()) {
156125
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
157126
return;
158127
}
159-
160-
debugger.RemoveProtocolServer(server_sp);
161128
}
162129
};
163130

lldb/source/Core/Debugger.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
23762376
"Debugger::GetThreadPool called before Debugger::Initialize");
23772377
return *g_thread_pool;
23782378
}
2379-
2380-
void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
2381-
assert(protocol_server_sp &&
2382-
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
2383-
m_protocol_servers.push_back(protocol_server_sp);
2384-
}
2385-
2386-
void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
2387-
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
2388-
if (it != m_protocol_servers.end())
2389-
m_protocol_servers.erase(it);
2390-
}
2391-
2392-
lldb::ProtocolServerSP
2393-
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
2394-
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
2395-
if (!protocol_server_sp)
2396-
continue;
2397-
if (protocol_server_sp->GetPluginName() == protocol)
2398-
return protocol_server_sp;
2399-
}
2400-
return nullptr;
2401-
}

lldb/source/Core/ProtocolServer.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,36 @@
1212
using namespace lldb_private;
1313
using namespace lldb;
1414

15-
ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
16-
Debugger &debugger) {
15+
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
16+
static std::mutex g_mutex;
17+
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;
18+
19+
std::lock_guard<std::mutex> guard(g_mutex);
20+
21+
auto it = g_protocol_server_instances.find(name);
22+
if (it != g_protocol_server_instances.end())
23+
return it->second.get();
24+
1725
if (ProtocolServerCreateInstance create_callback =
18-
PluginManager::GetProtocolCreateCallbackForPluginName(name))
19-
return create_callback(debugger);
26+
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
27+
auto pair =
28+
g_protocol_server_instances.try_emplace(name, create_callback());
29+
return pair.first->second.get();
30+
}
31+
2032
return nullptr;
2133
}
34+
35+
std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
36+
std::vector<llvm::StringRef> supported_protocols;
37+
size_t i = 0;
38+
39+
for (llvm::StringRef protocol_name =
40+
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
41+
!protocol_name.empty();
42+
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
43+
supported_protocols.push_back(protocol_name);
44+
}
45+
46+
return supported_protocols;
47+
}

lldb/source/Plugins/Protocol/MCP/Protocol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>;
123123
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
124124
llvm::json::Value toJSON(const Message &);
125125

126+
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
127+
126128
} // namespace lldb_private::mcp::protocol
127129

128130
#endif

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
2424

2525
static constexpr size_t kChunkSize = 1024;
2626

27-
ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
28-
: ProtocolServer(), m_debugger(debugger) {
27+
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
2928
AddRequestHandler("initialize",
3029
std::bind(&ProtocolServerMCP::InitializeHandler, this,
3130
std::placeholders::_1));
@@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
3938
"notifications/initialized", [](const protocol::Notification &) {
4039
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
4140
});
42-
AddTool(std::make_unique<LLDBCommandTool>(
43-
"lldb_command", "Run an lldb command.", m_debugger));
41+
AddTool(
42+
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
43+
AddTool(std::make_unique<DebuggerListTool>(
44+
"lldb_debugger_list", "List debugger instances with their debugger_id."));
4445
}
4546

4647
ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
@@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
5455
PluginManager::UnregisterPlugin(CreateInstance);
5556
}
5657

57-
lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
58-
return std::make_shared<ProtocolServerMCP>(debugger);
58+
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
59+
return std::make_unique<ProtocolServerMCP>();
5960
}
6061

6162
llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
@@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
145146
std::lock_guard<std::mutex> guard(m_server_mutex);
146147

147148
if (m_running)
148-
return llvm::createStringError("server already running");
149+
return llvm::createStringError("the MCP server is already running");
149150

150151
Status status;
151152
m_listener = Socket::Create(connection.protocol, status);
@@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
162163
if (llvm::Error error = handles.takeError())
163164
return error;
164165

166+
m_running = true;
165167
m_listen_handlers = std::move(*handles);
166168
m_loop_thread = std::thread([=] {
167-
llvm::set_thread_name(
168-
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
169+
llvm::set_thread_name("protocol-server.mcp");
169170
m_loop.Run();
170171
});
171172

@@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
175176
llvm::Error ProtocolServerMCP::Stop() {
176177
{
177178
std::lock_guard<std::mutex> guard(m_server_mutex);
179+
if (!m_running)
180+
return createStringError("the MCP sever is not running");
178181
m_running = false;
179182
}
180183

@@ -311,11 +314,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
311314
if (it == m_tools.end())
312315
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
313316

314-
const json::Value *args = param_obj->get("arguments");
315-
if (!args)
316-
return llvm::createStringError("no tool arguments");
317+
protocol::ToolArguments tool_args;
318+
if (const json::Value *args = param_obj->get("arguments"))
319+
tool_args = *args;
317320

318-
llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
321+
llvm::Expected<protocol::TextResult> text_result =
322+
it->second->Call(tool_args);
319323
if (!text_result)
320324
return text_result.takeError();
321325

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace lldb_private::mcp {
2121

2222
class ProtocolServerMCP : public ProtocolServer {
2323
public:
24-
ProtocolServerMCP(Debugger &debugger);
24+
ProtocolServerMCP();
2525
virtual ~ProtocolServerMCP() override;
2626

2727
virtual llvm::Error Start(ProtocolServer::Connection connection) override;
@@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer {
3333
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
3434
static llvm::StringRef GetPluginDescriptionStatic();
3535

36-
static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
36+
static lldb::ProtocolServerUP CreateInstance();
3737

3838
llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }
3939

@@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer {
7171
llvm::StringLiteral kName = "lldb-mcp";
7272
llvm::StringLiteral kVersion = "0.1.0";
7373

74-
Debugger &m_debugger;
75-
7674
bool m_running = false;
7775

7876
MainLoop m_loop;

0 commit comments

Comments
 (0)