Skip to content

[mlir-lsp] Parse outgoing request callback JSON #90693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2024
Merged

Conversation

modocache
Copy link
Member

Rather than force callbacks for outgoing requests to parse the result JSON themselves (of type llvm::Expected<llvm::json::Value>), allow users to specify the result type, which
MessageHandler::outgoingRequest will parse for them. This eliminates boilerplate for users sending outgoing requests.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 1, 2024
@llvmbot
Copy link
Member

llvmbot commented May 1, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Brian Gesiak (modocache)

Changes

Rather than force callbacks for outgoing requests to parse the result JSON themselves (of type llvm::Expected&lt;llvm::json::Value&gt;), allow users to specify the result type, which
MessageHandler::outgoingRequest will parse for them. This eliminates boilerplate for users sending outgoing requests.


Full diff: https://github.com/llvm/llvm-project/pull/90693.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Tools/lsp-server-support/Transport.h (+31-11)
  • (modified) mlir/unittests/Tools/lsp-server-support/Transport.cpp (+45-13)
diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
index 047d174234df8d..b2979be60eacc8 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
@@ -109,9 +109,10 @@ using OutgoingRequest =
 
 /// An `OutgoingRequestCallback` is invoked when an outgoing request to the
 /// client receives a response in turn. It is passed the original request's ID,
-/// as well as the result JSON.
+/// as well as the response result.
+template <typename T>
 using OutgoingRequestCallback =
-    std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
+    std::function<void(llvm::json::Value, llvm::Expected<T>)>;
 
 /// A handler used to process the incoming transport messages.
 class MessageHandler {
@@ -185,21 +186,37 @@ class MessageHandler {
 
   /// Create an OutgoingRequest function that, when called, sends a request with
   /// the given method via the transport. Should the outgoing request be
-  /// met with a response, the response callback is invoked to handle that
-  /// response.
-  template <typename T>
-  OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
-                                     OutgoingRequestCallback callback) {
-    return [&, method, callback](const T &params, llvm::json::Value id) {
+  /// met with a response, the result JSON is parsed and the response callback
+  /// is invoked.
+  template <typename Param, typename Result>
+  OutgoingRequest<Param>
+  outgoingRequest(llvm::StringLiteral method,
+                  OutgoingRequestCallback<Result> callback) {
+    return [&, method, callback](const Param &param, llvm::json::Value id) {
+      auto callbackWrapper = [method, callback = std::move(callback)](
+                                 llvm::json::Value id,
+                                 llvm::Expected<llvm::json::Value> value) {
+        if (!value)
+          return callback(std::move(id), value.takeError());
+
+        std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
+        llvm::Expected<Result> result =
+            parse<Result>(*value, responseName, "response");
+        if (!result)
+          return callback(std::move(id), result.takeError());
+
+        return callback(std::move(id), *result);
+      };
+
       {
         std::lock_guard<std::mutex> lock(responseHandlersMutex);
         responseHandlers.insert(
-            {debugString(id), std::make_pair(method.str(), callback)});
+            {debugString(id), std::make_pair(method.str(), callbackWrapper)});
       }
 
       std::lock_guard<std::mutex> transportLock(transportOutputMutex);
       Logger::info("--> {0}({1})", method, id);
-      transport.call(method, llvm::json::Value(params), id);
+      transport.call(method, llvm::json::Value(param), id);
     };
   }
 
@@ -213,7 +230,10 @@ class MessageHandler {
 
   /// A pair of (1) the original request's method name, and (2) the callback
   /// function to be invoked for responses.
-  using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
+  using ResponseHandlerTy =
+      std::pair<std::string,
+                std::function<void(llvm::json::Value,
+                                   llvm::Expected<llvm::json::Value>)>>;
   /// A mapping from request/response ID to response handler.
   llvm::StringMap<ResponseHandlerTy> responseHandlers;
   /// Mutex to guard insertion into the response handler map.
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
index fee21840595232..0303c1cba8bc87 100644
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
@@ -144,17 +144,17 @@ TEST_F(TransportInputTest, ResponseHandlerNotFound) {
 TEST_F(TransportInputTest, OutgoingRequest) {
   // Make some outgoing requests.
   int responseCallbackInvoked = 0;
-  auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
-      "outgoing-request",
-      [&responseCallbackInvoked](llvm::json::Value id,
-                                 llvm::Expected<llvm::json::Value> value) {
-        // Make expectations on the expected response.
-        EXPECT_EQ(id, 83);
-        ASSERT_TRUE((bool)value);
-        EXPECT_EQ(debugString(*value), "{\"foo\":6}");
-        responseCallbackInvoked += 1;
-        llvm::outs() << "here!!!\n";
-      });
+  auto callFn =
+      getMessageHandler().outgoingRequest<CompletionList, CompletionContext>(
+          "outgoing-request",
+          [&responseCallbackInvoked](llvm::json::Value id,
+                                     llvm::Expected<CompletionContext> result) {
+            // Make expectations on the expected response.
+            EXPECT_EQ(id, 83);
+            ASSERT_TRUE((bool)result);
+            EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked);
+            responseCallbackInvoked += 1;
+          });
   callFn({}, 82);
   callFn({}, 83);
   callFn({}, 84);
@@ -164,9 +164,41 @@ TEST_F(TransportInputTest, OutgoingRequest) {
   // One of the requests receives a response. The message handler handles this
   // response by invoking the callback from above. Subsequent responses with the
   // same ID are ignored.
-  writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
+  writeInput(
+      "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n"
+      "// -----\n"
+      "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n");
+  runTransport();
+  EXPECT_EQ(responseCallbackInvoked, 1);
+}
+
+TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) {
+  // Make an outgoing request that expects a failure response.
+  bool responseCallbackInvoked = 0;
+  auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>(
+      "outgoing-request-json-parse-failure",
+      [&responseCallbackInvoked](llvm::json::Value id,
+                                 llvm::Expected<Position> result) {
+        llvm::Error err = result.takeError();
+        EXPECT_EQ(id, 109);
+        ASSERT_TRUE((bool)err);
+        EXPECT_THAT(debugString(err),
+                    HasSubstr("failed to decode "
+                              "reply:outgoing-request-json-parse-failure(109) "
+                              "response: missing value at (root).character"));
+        llvm::consumeError(std::move(err));
+        responseCallbackInvoked += 1;
+      });
+  callFn({}, 109);
+  EXPECT_EQ(responseCallbackInvoked, 0);
+
+  // The request receives multiple responses, but only the first one triggers
+  // the response callback. The first response has erroneous JSON that causes a
+  // parse failure.
+  writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n"
              "// -----\n"
-             "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
+             "{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3,"
+             "\"character\":2}}\n");
   runTransport();
   EXPECT_EQ(responseCallbackInvoked, 1);
 }

Rather than force callbacks for outgoing requests to parse the result
JSON themselves (of type `llvm::Expected<llvm::json::Value>`), allow
users to specify the result type, which
`MessageHandler::outgoingRequest` will parse for them. This eliminates
boilerplate for users sending outgoing requests.
@modocache modocache merged commit 11bda17 into llvm:main May 2, 2024
@modocache modocache deleted the lsp-1 branch May 2, 2024 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants