Skip to content

Commit 11bda17

Browse files
authored
[mlir-lsp] Parse outgoing request callback JSON (#90693)
Rather than force callbacks for outgoing requests to parse the result JSON themselves (of type `llvm::Expected<llvm::json::Value>`), allow users to specify the result type, which `MessageHandler::outgoingRequest` will parse for them. This eliminates boilerplate for users sending outgoing requests.
1 parent a2e1f54 commit 11bda17

File tree

2 files changed

+74
-24
lines changed

2 files changed

+74
-24
lines changed

mlir/include/mlir/Tools/lsp-server-support/Transport.h

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ using OutgoingRequest =
109109

110110
/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
111111
/// client receives a response in turn. It is passed the original request's ID,
112-
/// as well as the result JSON.
112+
/// as well as the response result.
113+
template <typename T>
113114
using OutgoingRequestCallback =
114-
std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
115+
std::function<void(llvm::json::Value, llvm::Expected<T>)>;
115116

116117
/// A handler used to process the incoming transport messages.
117118
class MessageHandler {
@@ -185,21 +186,37 @@ class MessageHandler {
185186

186187
/// Create an OutgoingRequest function that, when called, sends a request with
187188
/// the given method via the transport. Should the outgoing request be
188-
/// met with a response, the response callback is invoked to handle that
189-
/// response.
190-
template <typename T>
191-
OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
192-
OutgoingRequestCallback callback) {
193-
return [&, method, callback](const T &params, llvm::json::Value id) {
189+
/// met with a response, the result JSON is parsed and the response callback
190+
/// is invoked.
191+
template <typename Param, typename Result>
192+
OutgoingRequest<Param>
193+
outgoingRequest(llvm::StringLiteral method,
194+
OutgoingRequestCallback<Result> callback) {
195+
return [&, method, callback](const Param &param, llvm::json::Value id) {
196+
auto callbackWrapper = [method, callback = std::move(callback)](
197+
llvm::json::Value id,
198+
llvm::Expected<llvm::json::Value> value) {
199+
if (!value)
200+
return callback(std::move(id), value.takeError());
201+
202+
std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
203+
llvm::Expected<Result> result =
204+
parse<Result>(*value, responseName, "response");
205+
if (!result)
206+
return callback(std::move(id), result.takeError());
207+
208+
return callback(std::move(id), *result);
209+
};
210+
194211
{
195212
std::lock_guard<std::mutex> lock(responseHandlersMutex);
196213
responseHandlers.insert(
197-
{debugString(id), std::make_pair(method.str(), callback)});
214+
{debugString(id), std::make_pair(method.str(), callbackWrapper)});
198215
}
199216

200217
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
201218
Logger::info("--> {0}({1})", method, id);
202-
transport.call(method, llvm::json::Value(params), id);
219+
transport.call(method, llvm::json::Value(param), id);
203220
};
204221
}
205222

@@ -213,7 +230,8 @@ class MessageHandler {
213230

214231
/// A pair of (1) the original request's method name, and (2) the callback
215232
/// function to be invoked for responses.
216-
using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
233+
using ResponseHandlerTy =
234+
std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
217235
/// A mapping from request/response ID to response handler.
218236
llvm::StringMap<ResponseHandlerTy> responseHandlers;
219237
/// Mutex to guard insertion into the response handler map.

mlir/unittests/Tools/lsp-server-support/Transport.cpp

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,17 @@ TEST_F(TransportInputTest, ResponseHandlerNotFound) {
144144
TEST_F(TransportInputTest, OutgoingRequest) {
145145
// Make some outgoing requests.
146146
int responseCallbackInvoked = 0;
147-
auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
148-
"outgoing-request",
149-
[&responseCallbackInvoked](llvm::json::Value id,
150-
llvm::Expected<llvm::json::Value> value) {
151-
// Make expectations on the expected response.
152-
EXPECT_EQ(id, 83);
153-
ASSERT_TRUE((bool)value);
154-
EXPECT_EQ(debugString(*value), "{\"foo\":6}");
155-
responseCallbackInvoked += 1;
156-
llvm::outs() << "here!!!\n";
157-
});
147+
auto callFn =
148+
getMessageHandler().outgoingRequest<CompletionList, CompletionContext>(
149+
"outgoing-request",
150+
[&responseCallbackInvoked](llvm::json::Value id,
151+
llvm::Expected<CompletionContext> result) {
152+
// Make expectations on the expected response.
153+
EXPECT_EQ(id, 83);
154+
ASSERT_TRUE((bool)result);
155+
EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked);
156+
responseCallbackInvoked += 1;
157+
});
158158
callFn({}, 82);
159159
callFn({}, 83);
160160
callFn({}, 84);
@@ -164,9 +164,41 @@ TEST_F(TransportInputTest, OutgoingRequest) {
164164
// One of the requests receives a response. The message handler handles this
165165
// response by invoking the callback from above. Subsequent responses with the
166166
// same ID are ignored.
167-
writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
167+
writeInput(
168+
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n"
169+
"// -----\n"
170+
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n");
171+
runTransport();
172+
EXPECT_EQ(responseCallbackInvoked, 1);
173+
}
174+
175+
TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) {
176+
// Make an outgoing request that expects a failure response.
177+
bool responseCallbackInvoked = 0;
178+
auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>(
179+
"outgoing-request-json-parse-failure",
180+
[&responseCallbackInvoked](llvm::json::Value id,
181+
llvm::Expected<Position> result) {
182+
llvm::Error err = result.takeError();
183+
EXPECT_EQ(id, 109);
184+
ASSERT_TRUE((bool)err);
185+
EXPECT_THAT(debugString(err),
186+
HasSubstr("failed to decode "
187+
"reply:outgoing-request-json-parse-failure(109) "
188+
"response: missing value at (root).character"));
189+
llvm::consumeError(std::move(err));
190+
responseCallbackInvoked += 1;
191+
});
192+
callFn({}, 109);
193+
EXPECT_EQ(responseCallbackInvoked, 0);
194+
195+
// The request receives multiple responses, but only the first one triggers
196+
// the response callback. The first response has erroneous JSON that causes a
197+
// parse failure.
198+
writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n"
168199
"// -----\n"
169-
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
200+
"{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3,"
201+
"\"character\":2}}\n");
170202
runTransport();
171203
EXPECT_EQ(responseCallbackInvoked, 1);
172204
}

0 commit comments

Comments
 (0)