Skip to content

[mlir-lsp] Parse outgoing request callback JSON #90693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions mlir/include/mlir/Tools/lsp-server-support/Transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ using OutgoingRequest =

/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
/// client receives a response in turn. It is passed the original request's ID,
/// as well as the result JSON.
/// as well as the response result.
template <typename T>
using OutgoingRequestCallback =
std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
std::function<void(llvm::json::Value, llvm::Expected<T>)>;

/// A handler used to process the incoming transport messages.
class MessageHandler {
Expand Down Expand Up @@ -185,21 +186,37 @@ class MessageHandler {

/// Create an OutgoingRequest function that, when called, sends a request with
/// the given method via the transport. Should the outgoing request be
/// met with a response, the response callback is invoked to handle that
/// response.
template <typename T>
OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
OutgoingRequestCallback callback) {
return [&, method, callback](const T &params, llvm::json::Value id) {
/// met with a response, the result JSON is parsed and the response callback
/// is invoked.
template <typename Param, typename Result>
OutgoingRequest<Param>
outgoingRequest(llvm::StringLiteral method,
OutgoingRequestCallback<Result> callback) {
return [&, method, callback](const Param &param, llvm::json::Value id) {
auto callbackWrapper = [method, callback = std::move(callback)](
llvm::json::Value id,
llvm::Expected<llvm::json::Value> value) {
if (!value)
return callback(std::move(id), value.takeError());

std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
llvm::Expected<Result> result =
parse<Result>(*value, responseName, "response");
if (!result)
return callback(std::move(id), result.takeError());

return callback(std::move(id), *result);
};

{
std::lock_guard<std::mutex> lock(responseHandlersMutex);
responseHandlers.insert(
{debugString(id), std::make_pair(method.str(), callback)});
{debugString(id), std::make_pair(method.str(), callbackWrapper)});
}

std::lock_guard<std::mutex> transportLock(transportOutputMutex);
Logger::info("--> {0}({1})", method, id);
transport.call(method, llvm::json::Value(params), id);
transport.call(method, llvm::json::Value(param), id);
};
}

Expand All @@ -213,7 +230,8 @@ class MessageHandler {

/// A pair of (1) the original request's method name, and (2) the callback
/// function to be invoked for responses.
using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
using ResponseHandlerTy =
std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
/// A mapping from request/response ID to response handler.
llvm::StringMap<ResponseHandlerTy> responseHandlers;
/// Mutex to guard insertion into the response handler map.
Expand Down
58 changes: 45 additions & 13 deletions mlir/unittests/Tools/lsp-server-support/Transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,17 @@ TEST_F(TransportInputTest, ResponseHandlerNotFound) {
TEST_F(TransportInputTest, OutgoingRequest) {
// Make some outgoing requests.
int responseCallbackInvoked = 0;
auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
"outgoing-request",
[&responseCallbackInvoked](llvm::json::Value id,
llvm::Expected<llvm::json::Value> value) {
// Make expectations on the expected response.
EXPECT_EQ(id, 83);
ASSERT_TRUE((bool)value);
EXPECT_EQ(debugString(*value), "{\"foo\":6}");
responseCallbackInvoked += 1;
llvm::outs() << "here!!!\n";
});
auto callFn =
getMessageHandler().outgoingRequest<CompletionList, CompletionContext>(
"outgoing-request",
[&responseCallbackInvoked](llvm::json::Value id,
llvm::Expected<CompletionContext> result) {
// Make expectations on the expected response.
EXPECT_EQ(id, 83);
ASSERT_TRUE((bool)result);
EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked);
responseCallbackInvoked += 1;
});
callFn({}, 82);
callFn({}, 83);
callFn({}, 84);
Expand All @@ -164,9 +164,41 @@ TEST_F(TransportInputTest, OutgoingRequest) {
// One of the requests receives a response. The message handler handles this
// response by invoking the callback from above. Subsequent responses with the
// same ID are ignored.
writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
writeInput(
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n"
"// -----\n"
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n");
runTransport();
EXPECT_EQ(responseCallbackInvoked, 1);
}

TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) {
// Make an outgoing request that expects a failure response.
bool responseCallbackInvoked = 0;
auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>(
"outgoing-request-json-parse-failure",
[&responseCallbackInvoked](llvm::json::Value id,
llvm::Expected<Position> result) {
llvm::Error err = result.takeError();
EXPECT_EQ(id, 109);
ASSERT_TRUE((bool)err);
EXPECT_THAT(debugString(err),
HasSubstr("failed to decode "
"reply:outgoing-request-json-parse-failure(109) "
"response: missing value at (root).character"));
llvm::consumeError(std::move(err));
responseCallbackInvoked += 1;
});
callFn({}, 109);
EXPECT_EQ(responseCallbackInvoked, 0);

// The request receives multiple responses, but only the first one triggers
// the response callback. The first response has erroneous JSON that causes a
// parse failure.
writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n"
"// -----\n"
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
"{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3,"
"\"character\":2}}\n");
runTransport();
EXPECT_EQ(responseCallbackInvoked, 1);
}
Expand Down