Skip to content

Commit beb27bf

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name runtime
Differential Revision: https://internalfb.com/D49335926 fbshipit-source-id: 2eb55cd9ee40d82a4ac33aaa10950c9044192c80
1 parent 4e5601a commit beb27bf

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

bundled_program/tests/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import executorch.exir as exir
1313
import torch
1414
from executorch.bundled_program.config import BundledConfig
15-
from executorch.exir import CaptureConfig
1615
from executorch.exir.schema import Program
1716

1817
# @manual=fbsource//third-party/pypi/typing-extensions:typing-extensions

sdk/runners/executor_runner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ DEFINE_string(
7575

7676
DEFINE_int32(num_threads, 1, "Number of threads to use.");
7777

78+
DEFINE_string(
79+
method_name,
80+
"forward",
81+
"Name of method to run. Only used by bundled program mode.");
82+
7883
DEFINE_int32(
7984
testset_idx,
8085
0,
@@ -347,7 +352,7 @@ int main(int argc, char** argv) {
347352
*method,
348353
program_data.bundled_program_data(),
349354
&bundled_input_allocator,
350-
method_index,
355+
method_name,
351356
FLAGS_testset_idx);
352357
ET_CHECK_MSG(
353358
status == Error::Ok,
@@ -376,7 +381,7 @@ int main(int argc, char** argv) {
376381
*method,
377382
program_data.bundled_program_data(),
378383
&bundled_input_allocator,
379-
method_index,
384+
method_name,
380385
FLAGS_testset_idx,
381386
FLAGS_rtol,
382387
FLAGS_atol);

util/bundled_program_verification.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,28 +166,45 @@ bool tensors_are_close(
166166
}
167167
}
168168

169+
Result<executorch_flatbuffer::BundledExecutionPlanTest*> get_method_test(
170+
const executorch_flatbuffer::BundledProgram* bundled_program,
171+
const char* method_name) {
172+
auto method_tests = bundled_program->execution_plan_tests();
173+
for (size_t i = 0; i < method_tests->size(); i++) {
174+
auto m_test = method_tests->GetMutableObject(i);
175+
if (std::strcmp(m_test->method_name()->c_str(), method_name) == 0) {
176+
return m_test;
177+
}
178+
}
179+
ET_LOG(Error, "No method named '%s' in given bundled program", method_name);
180+
return Error::InvalidArgument;
181+
}
182+
169183
} // namespace
170184

171185
// Load testset_idx-th bundled data into the Method
172186
__ET_NODISCARD Error LoadBundledInput(
173187
Method& method,
174188
serialized_bundled_program* bundled_program_ptr,
175189
MemoryAllocator* memory_allocator,
176-
size_t method_idx,
190+
const char* method_name,
177191
size_t testset_idx) {
178192
ET_CHECK_OR_RETURN_ERROR(
179193
executorch_flatbuffer::BundledProgramBufferHasIdentifier(
180194
bundled_program_ptr),
181195
NotSupported,
182196
"The input buffer should be a bundled program.");
183197

198+
auto method_test = get_method_test(
199+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
200+
method_name);
201+
202+
if (!method_test.ok()) {
203+
return method_test.error();
204+
}
205+
184206
auto bundled_inputs =
185-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
186-
->execution_plan_tests()
187-
->Get(method_idx)
188-
->test_sets()
189-
->Get(testset_idx)
190-
->inputs();
207+
method_test.get()->test_sets()->Get(testset_idx)->inputs();
191208

192209
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
193210
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -263,7 +280,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
263280
Method& method,
264281
serialized_bundled_program* bundled_program_ptr,
265282
MemoryAllocator* memory_allocator,
266-
size_t method_idx,
283+
const char* method_name,
267284
size_t testset_idx,
268285
double rtol,
269286
double atol) {
@@ -273,13 +290,16 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
273290
NotSupported,
274291
"The input buffer should be a bundled program.");
275292

293+
auto method_test = get_method_test(
294+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
295+
method_name);
296+
297+
if (!method_test.ok()) {
298+
return method_test.error();
299+
}
300+
276301
auto bundled_expected_outputs =
277-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
278-
->execution_plan_tests()
279-
->Get(method_idx)
280-
->test_sets()
281-
->Get(testset_idx)
282-
->expected_outputs();
302+
method_test.get()->test_sets()->Get(testset_idx)->expected_outputs();
283303

284304
for (size_t output_idx = 0; output_idx < method.outputs_size();
285305
output_idx++) {

util/bundled_program_verification.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using serialized_bundled_program = const void;
2626
*
2727
* @param[in] method The Method to verify.
2828
* @param[in] bundled_program_ptr The bundled program contains expected output.
29-
* @param[in] method_idx The index of the Method being verified.
29+
* @param[in] method_name The name of the Method being verified.
3030
* @param[in] testset_idx The index of input needs to be set into given Method.
3131
*
3232
* @returns Return Error::Ok if load successfully, or the error happens during
@@ -36,7 +36,7 @@ __ET_NODISCARD Error LoadBundledInput(
3636
Method& method,
3737
serialized_bundled_program* bundled_program_ptr,
3838
MemoryAllocator* memory_allocator,
39-
size_t method_idx,
39+
const char* method_name,
4040
size_t testset_idx);
4141

4242
/**
@@ -45,7 +45,7 @@ __ET_NODISCARD Error LoadBundledInput(
4545
*
4646
* @param[in] method The Method to extract outputs from.
4747
* @param[in] bundled_program_ptr The bundled program contains expected output.
48-
* @param[in] method_idx The index of the Method being verified.
48+
* @param[in] method_name The name of the Method being verified.
4949
* @param[in] testset_idx The index of expected output needs to be compared.
5050
* @param[in] rtol Relative tolerance used for data comparsion.
5151
* @param[in] atol Absolute tolerance used for data comparsion.
@@ -57,7 +57,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
5757
Method& method,
5858
serialized_bundled_program* bundled_program_ptr,
5959
MemoryAllocator* memory_allocator,
60-
size_t method_idx,
60+
const char* method_name,
6161
size_t testset_idx,
6262
double rtol = 1e-5,
6363
double atol = 1e-8);

0 commit comments

Comments
 (0)