Skip to content

Commit 74278c6

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
support multimethod by method name runtime (#394)
Summary: Pull Request resolved: #394 Update runtime apis to support multimethod based on method name instead of method index Since the following diffs will reformat the AOT apis, this diff doesn't pay much effort on documentation stuff. Reviewed By: tarun292 Differential Revision: D49335926 fbshipit-source-id: 9052e2cb27081f9e5196c904aa9c099a5e884811
1 parent 7108f37 commit 74278c6

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

bundled_program/tests/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import executorch.exir as exir
1313
import torch
1414
from executorch.bundled_program.config import BundledConfig
15-
from executorch.exir import CaptureConfig
1615
from executorch.exir.schema import Program
1716

1817
# @manual=fbsource//third-party/pypi/typing-extensions:typing-extensions

sdk/runners/executor_runner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ DEFINE_string(
7575

7676
DEFINE_int32(num_threads, 1, "Number of threads to use.");
7777

78+
DEFINE_string(
79+
method_name,
80+
"forward",
81+
"Name of method to run. Only used by bundled program mode.");
82+
7883
DEFINE_int32(
7984
testset_idx,
8085
0,
@@ -347,7 +352,7 @@ int main(int argc, char** argv) {
347352
*method,
348353
program_data.bundled_program_data(),
349354
&bundled_input_allocator,
350-
method_index,
355+
method_name,
351356
FLAGS_testset_idx);
352357
ET_CHECK_MSG(
353358
status == Error::Ok,
@@ -376,7 +381,7 @@ int main(int argc, char** argv) {
376381
*method,
377382
program_data.bundled_program_data(),
378383
&bundled_input_allocator,
379-
method_index,
384+
method_name,
380385
FLAGS_testset_idx,
381386
FLAGS_rtol,
382387
FLAGS_atol);

util/bundled_program_verification.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,28 +166,45 @@ bool tensors_are_close(
166166
}
167167
}
168168

169+
Result<executorch_flatbuffer::BundledExecutionPlanTest*> get_method_test(
170+
const executorch_flatbuffer::BundledProgram* bundled_program,
171+
const char* method_name) {
172+
auto method_tests = bundled_program->execution_plan_tests();
173+
for (size_t i = 0; i < method_tests->size(); i++) {
174+
auto m_test = method_tests->GetMutableObject(i);
175+
if (std::strcmp(m_test->method_name()->c_str(), method_name) == 0) {
176+
return m_test;
177+
}
178+
}
179+
ET_LOG(Error, "No method named '%s' in given bundled program", method_name);
180+
return Error::InvalidArgument;
181+
}
182+
169183
} // namespace
170184

171185
// Load testset_idx-th bundled data into the Method
172186
__ET_NODISCARD Error LoadBundledInput(
173187
Method& method,
174188
serialized_bundled_program* bundled_program_ptr,
175189
MemoryAllocator* memory_allocator,
176-
size_t method_idx,
190+
const char* method_name,
177191
size_t testset_idx) {
178192
ET_CHECK_OR_RETURN_ERROR(
179193
executorch_flatbuffer::BundledProgramBufferHasIdentifier(
180194
bundled_program_ptr),
181195
NotSupported,
182196
"The input buffer should be a bundled program.");
183197

198+
auto method_test = get_method_test(
199+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
200+
method_name);
201+
202+
if (!method_test.ok()) {
203+
return method_test.error();
204+
}
205+
184206
auto bundled_inputs =
185-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
186-
->execution_plan_tests()
187-
->Get(method_idx)
188-
->test_sets()
189-
->Get(testset_idx)
190-
->inputs();
207+
method_test.get()->test_sets()->Get(testset_idx)->inputs();
191208

192209
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
193210
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -263,7 +280,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
263280
Method& method,
264281
serialized_bundled_program* bundled_program_ptr,
265282
MemoryAllocator* memory_allocator,
266-
size_t method_idx,
283+
const char* method_name,
267284
size_t testset_idx,
268285
double rtol,
269286
double atol) {
@@ -273,13 +290,16 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
273290
NotSupported,
274291
"The input buffer should be a bundled program.");
275292

293+
auto method_test = get_method_test(
294+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
295+
method_name);
296+
297+
if (!method_test.ok()) {
298+
return method_test.error();
299+
}
300+
276301
auto bundled_expected_outputs =
277-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
278-
->execution_plan_tests()
279-
->Get(method_idx)
280-
->test_sets()
281-
->Get(testset_idx)
282-
->expected_outputs();
302+
method_test.get()->test_sets()->Get(testset_idx)->expected_outputs();
283303

284304
for (size_t output_idx = 0; output_idx < method.outputs_size();
285305
output_idx++) {

util/bundled_program_verification.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using serialized_bundled_program = const void;
2626
*
2727
* @param[in] method The Method to verify.
2828
* @param[in] bundled_program_ptr The bundled program contains expected output.
29-
* @param[in] method_idx The index of the Method being verified.
29+
* @param[in] method_name The name of the Method being verified.
3030
* @param[in] testset_idx The index of input needs to be set into given Method.
3131
*
3232
* @returns Return Error::Ok if load successfully, or the error happens during
@@ -36,7 +36,7 @@ __ET_NODISCARD Error LoadBundledInput(
3636
Method& method,
3737
serialized_bundled_program* bundled_program_ptr,
3838
MemoryAllocator* memory_allocator,
39-
size_t method_idx,
39+
const char* method_name,
4040
size_t testset_idx);
4141

4242
/**
@@ -45,7 +45,7 @@ __ET_NODISCARD Error LoadBundledInput(
4545
*
4646
* @param[in] method The Method to extract outputs from.
4747
* @param[in] bundled_program_ptr The bundled program contains expected output.
48-
* @param[in] method_idx The index of the Method being verified.
48+
* @param[in] method_name The name of the Method being verified.
4949
* @param[in] testset_idx The index of expected output needs to be compared.
5050
* @param[in] rtol Relative tolerance used for data comparsion.
5151
* @param[in] atol Absolute tolerance used for data comparsion.
@@ -57,7 +57,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
5757
Method& method,
5858
serialized_bundled_program* bundled_program_ptr,
5959
MemoryAllocator* memory_allocator,
60-
size_t method_idx,
60+
const char* method_name,
6161
size_t testset_idx,
6262
double rtol = 1e-5,
6363
double atol = 1e-8);

0 commit comments

Comments
 (0)