diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 2ac05880de86b..382f9a4731dd4 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,6 +9,8 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" @@ -18,26 +20,52 @@ namespace lldb_protocol::mcp { -class Server { +class MCPTransport final + : public lldb_private::JSONRPCTransport { public: - Server(std::string name, std::string version); - virtual ~Server() = default; + using LogCallback = std::function; + + MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out, + std::string client_name, LogCallback log_callback = {}) + : JSONRPCTransport(in, out), m_client_name(std::move(client_name)), + m_log_callback(log_callback) {} + virtual ~MCPTransport() = default; + + void Log(llvm::StringRef message) override { + if (m_log_callback) + m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); + } + +private: + std::string m_client_name; + LogCallback m_log_callback; +}; + +class Server : public MCPTransport::MessageHandler { +public: + Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop); + ~Server() = default; + + using NotificationHandler = std::function; void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + + llvm::Error Run(); protected: - virtual Capabilities GetCapabilities() = 0; + Capabilities GetCapabilities(); using RequestHandler = std::function(const Request &)>; - using NotificationHandler = std::function; void AddRequestHandlers(); void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); llvm::Expected> HandleData(llvm::StringRef data); @@ -52,12 +80,23 @@ class Server { llvm::Expected ResourcesListHandler(const Request &); llvm::Expected ResourcesReadHandler(const Request &); + void Received(const Request &) override; + void Received(const Response &) override; + void Received(const Notification &) override; + void OnError(llvm::Error) override; + void OnClosed() override; + + void TerminateLoop(); + std::mutex m_mutex; private: const std::string m_name; const std::string m_version; + std::unique_ptr m_transport_up; + lldb_private::MainLoop &m_loop; + llvm::StringMap> m_tools; std::vector> m_resource_providers; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c359663239dcc..57132534cf680 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -26,24 +26,10 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) -static constexpr size_t kChunkSize = 1024; static constexpr llvm::StringLiteral kName = "lldb-mcp"; static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() - : ProtocolServer(), - lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { - AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); - - AddTool( - std::make_unique("lldb_command", "Run an lldb command.")); - - AddResourceProvider(std::make_unique()); -} +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {} ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -64,57 +50,37 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } +void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { + server.AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); + server.AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + server.AddResourceProvider(std::make_unique()); +} + void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { - LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", - m_clients.size() + 1); + Log *log = GetLog(LLDBLog::Host); + std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); - auto client_up = std::make_unique(); - client_up->io_sp = io_sp; - Client *client = client_up.get(); - - Status status; - auto read_handle_up = m_loop.RegisterReadObject( - io_sp, - [this, client](MainLoopBase &loop) { - if (llvm::Error error = ReadCallback(*client)) { - LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); - client->read_handle_up.reset(); - } - }, - status); - if (status.Fail()) + auto transport_up = std::make_unique( + io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); + }); + auto instance_up = std::make_unique( + std::string(kName), std::string(kVersion), std::move(transport_up), + m_loop); + Extend(*instance_up); + llvm::Error error = instance_up->Run(); + if (error) { + LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); return; - - client_up->read_handle_up = std::move(read_handle_up); - m_clients.emplace_back(std::move(client_up)); -} - -llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { - char chunk[kChunkSize]; - size_t bytes_read = sizeof(chunk); - if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) - return status.takeError(); - client.buffer.append(chunk, bytes_read); - - for (std::string::size_type pos; - (pos = client.buffer.find('\n')) != std::string::npos;) { - llvm::Expected> message = - HandleData(StringRef(client.buffer.data(), pos)); - client.buffer = client.buffer.erase(0, pos + 1); - if (!message) - return message.takeError(); - - if (*message) { - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; - size_t num_bytes = Output.size(); - return client.io_sp->Write(Output.data(), num_bytes).takeError(); - } } - - return llvm::Error::success(); + m_instances.push_back(std::move(instance_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -158,27 +124,11 @@ llvm::Error ProtocolServerMCP::Stop() { // Stop the main loop. m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); - { - std::lock_guard guard(m_mutex); - m_listener.reset(); - m_listen_handlers.clear(); - m_clients.clear(); - } - return llvm::Error::success(); } - -lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; - // FIXME: Support sending notifications when a debugger/target are - // added/removed. - capabilities.resources.listChanged = false; - return capabilities; -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 7fe909a728b85..fc650ffe0dfa7 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -18,8 +18,7 @@ namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer, - public lldb_protocol::mcp::Server { +class ProtocolServerMCP : public ProtocolServer { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -39,26 +38,24 @@ class ProtocolServerMCP : public ProtocolServer, Socket *GetSocket() const override { return m_listener.get(); } +protected: + // This adds tools and resource providers that + // are specific to this server. Overridable by the unit tests. + virtual void Extend(lldb_protocol::mcp::Server &server) const; + private: void AcceptCallback(std::unique_ptr socket); - lldb_protocol::mcp::Capabilities GetCapabilities() override; - bool m_running = false; - MainLoop m_loop; + lldb_private::MainLoop m_loop; std::thread m_loop_thread; + std::mutex m_mutex; std::unique_ptr m_listener; - std::vector m_listen_handlers; - struct Client { - lldb::IOObjectSP io_sp; - MainLoopBase::ReadHandleUP read_handle_up; - std::string buffer; - }; - llvm::Error ReadCallback(Client &client); - std::vector> m_clients; + std::vector m_listen_handlers; + std::vector> m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a9c1482e3e378..3713e8e46c5d6 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,8 +12,11 @@ using namespace lldb_protocol::mcp; using namespace llvm; -Server::Server(std::string name, std::string version) - : m_name(std::move(name)), m_version(std::move(version)) { +Server::Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop) + : m_name(std::move(name)), m_version(std::move(version)), + m_transport_up(std::move(transport_up)), m_loop(loop) { AddRequestHandlers(); } @@ -232,3 +235,71 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { llvm::formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } + +Capabilities Server::GetCapabilities() { + lldb_protocol::mcp::Capabilities capabilities; + capabilities.tools.listChanged = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.resources.listChanged = false; + return capabilities; +} + +llvm::Error Server::Run() { + auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); + if (!handle) + return handle.takeError(); + + lldb_private::Status status = m_loop.Run(); + if (status.Fail()) + return status.takeError(); + + return llvm::Error::success(); +} + +void Server::Received(const Request &request) { + auto SendResponse = [this](const Response &response) { + if (llvm::Error error = m_transport_up->Send(response)) + m_transport_up->Log(llvm::toString(std::move(error))); + }; + + llvm::Expected response = Handle(request); + if (response) + return SendResponse(*response); + + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + Response error_response; + error_response.id = request.id; + error_response.result = std::move(protocol_error); + SendResponse(error_response); +} + +void Server::Received(const Response &response) { + m_transport_up->Log("unexpected MCP message: response"); +} + +void Server::Received(const Notification ¬ification) { + Handle(notification); +} + +void Server::OnError(llvm::Error error) { + m_transport_up->Log(llvm::toString(std::move(error))); + TerminateLoop(); +} + +void Server::OnClosed() { + m_transport_up->Log("EOF"); + TerminateLoop(); +} + +void Server::TerminateLoop() { + m_loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); +} diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 18112428950ce..83a42bfb6970c 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -39,12 +39,20 @@ using testing::_; namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { public: - using ProtocolServerMCP::AddNotificationHandler; - using ProtocolServerMCP::AddRequestHandler; - using ProtocolServerMCP::AddResourceProvider; - using ProtocolServerMCP::AddTool; using ProtocolServerMCP::GetSocket; using ProtocolServerMCP::ProtocolServerMCP; + + using ExtendCallback = + std::function; + + virtual void Extend(lldb_protocol::mcp::Server &server) const override { + if (m_extend_callback) + m_extend_callback(server); + }; + + void Extend(ExtendCallback callback) { m_extend_callback = callback; } + + ExtendCallback m_extend_callback; }; using Message = typename Transport::Message; @@ -183,8 +191,10 @@ class ProtocolServerMCPTest : public ::testing::Test { connection.protocol = Socket::SocketProtocol::ProtocolTcp; connection.name = llvm::formatv("{0}:0", k_localhost).str(); m_server_up = std::make_unique(); - m_server_up->AddTool(std::make_unique("test", "test tool")); - m_server_up->AddResourceProvider(std::make_unique()); + m_server_up->Extend([&](auto &server) { + server.AddTool(std::make_unique("test", "test tool")); + server.AddResourceProvider(std::make_unique()); + }); ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); // Connect to the server over a TCP socket. @@ -233,20 +243,10 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - ToolDefinition lldb_command_tool; - lldb_command_tool.description = "Run an lldb command."; - lldb_command_tool.name = "lldb_command"; - lldb_command_tool.inputSchema = json::Object{ - {"type", "object"}, - {"properties", - json::Object{{"arguments", json::Object{{"type", "string"}}}, - {"debugger_id", json::Object{{"type", "number"}}}}}, - {"required", json::Array{"debugger_id"}}}; Response response; response.id = "one"; response.result = json::Object{ - {"tools", - json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, + {"tools", json::Array{std::move(test_tool)}}, }; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -281,7 +281,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { } TEST_F(ProtocolServerMCPTest, ToolsCallError) { - m_server_up->AddTool(std::make_unique("error", "error tool")); + m_server_up->Extend([&](auto &server) { + server.AddTool(std::make_unique("error", "error tool")); + }); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -296,7 +298,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { - m_server_up->AddTool(std::make_unique("fail", "fail tool")); + m_server_up->Extend([&](auto &server) { + server.AddTool(std::make_unique("fail", "fail tool")); + }); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -315,14 +319,16 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) { std::condition_variable cv; std::mutex mutex; - m_server_up->AddNotificationHandler( - "notifications/initialized", [&](const Notification ¬ification) { - { - std::lock_guard lock(mutex); - handler_called = true; - } - cv.notify_all(); - }); + m_server_up->Extend([&](auto &server) { + server.AddNotificationHandler("notifications/initialized", + [&](const Notification ¬ification) { + { + std::lock_guard lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); + }); llvm::StringLiteral request = R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";