Skip to content

Commit acb3bab

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name runtime
Differential Revision: https://internalfb.com/D49335926 fbshipit-source-id: fb391752c992aa00fb8966748aea899ee8c14ecc
1 parent f984c1d commit acb3bab

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

bundled_program/tests/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import executorch.exir as exir
1313
import torch
1414
from executorch.bundled_program.config import BundledConfig
15-
from executorch.exir import CaptureConfig
1615
from executorch.exir.schema import Program
1716

1817
# @manual=fbsource//third-party/pypi/typing-extensions:typing-extensions

sdk/runners/executor_runner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ DEFINE_string(
7676

7777
DEFINE_int32(num_threads, 1, "Number of threads to use.");
7878

79+
DEFINE_string(
80+
method_name,
81+
"forward",
82+
"Name of method to run. Only used by bundled program mode.");
83+
7984
DEFINE_int32(
8085
testset_idx,
8186
0,
@@ -362,7 +367,7 @@ int main(int argc, char** argv) {
362367
*method,
363368
program_data.bundled_program_data(),
364369
&bundled_input_allocator,
365-
method_index,
370+
method_name,
366371
FLAGS_testset_idx);
367372
ET_CHECK_MSG(
368373
status == Error::Ok,
@@ -391,7 +396,7 @@ int main(int argc, char** argv) {
391396
*method,
392397
program_data.bundled_program_data(),
393398
&bundled_input_allocator,
394-
method_index,
399+
method_name,
395400
FLAGS_testset_idx,
396401
FLAGS_rtol,
397402
FLAGS_atol);

util/bundled_program_verification.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,28 +166,45 @@ bool tensors_are_close(
166166
}
167167
}
168168

169+
Result<executorch_flatbuffer::BundledExecutionPlanTest*> get_method_test(
170+
const executorch_flatbuffer::BundledProgram* bundled_program,
171+
const char* method_name) {
172+
auto method_tests = bundled_program->execution_plan_tests();
173+
for (size_t i = 0; i < method_tests->size(); i++) {
174+
auto m_test = method_tests->GetMutableObject(i);
175+
if (std::strcmp(m_test->method_name()->c_str(), method_name) == 0) {
176+
return m_test;
177+
}
178+
}
179+
ET_LOG(Error, "No method named '%s' in given bundled program", method_name);
180+
return Error::InvalidArgument;
181+
}
182+
169183
} // namespace
170184

171185
// Load testset_idx-th bundled data into the Method
172186
__ET_NODISCARD Error LoadBundledInput(
173187
Method& method,
174188
serialized_bundled_program* bundled_program_ptr,
175189
MemoryAllocator* memory_allocator,
176-
size_t method_idx,
190+
const char* method_name,
177191
size_t testset_idx) {
178192
ET_CHECK_OR_RETURN_ERROR(
179193
executorch_flatbuffer::BundledProgramBufferHasIdentifier(
180194
bundled_program_ptr),
181195
NotSupported,
182196
"The input buffer should be a bundled program.");
183197

198+
auto method_test = get_method_test(
199+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
200+
method_name);
201+
202+
if (!method_test.ok()) {
203+
return method_test.error();
204+
}
205+
184206
auto bundled_inputs =
185-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
186-
->execution_plan_tests()
187-
->Get(method_idx)
188-
->test_sets()
189-
->Get(testset_idx)
190-
->inputs();
207+
method_test.get()->test_sets()->Get(testset_idx)->inputs();
191208

192209
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
193210
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -263,7 +280,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
263280
Method& method,
264281
serialized_bundled_program* bundled_program_ptr,
265282
MemoryAllocator* memory_allocator,
266-
size_t method_idx,
283+
const char* method_name,
267284
size_t testset_idx,
268285
double rtol,
269286
double atol) {
@@ -273,13 +290,16 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
273290
NotSupported,
274291
"The input buffer should be a bundled program.");
275292

293+
auto method_test = get_method_test(
294+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
295+
method_name);
296+
297+
if (!method_test.ok()) {
298+
return method_test.error();
299+
}
300+
276301
auto bundled_expected_outputs =
277-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
278-
->execution_plan_tests()
279-
->Get(method_idx)
280-
->test_sets()
281-
->Get(testset_idx)
282-
->expected_outputs();
302+
method_test.get()->test_sets()->Get(testset_idx)->expected_outputs();
283303

284304
for (size_t output_idx = 0; output_idx < method.outputs_size();
285305
output_idx++) {

util/bundled_program_verification.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using serialized_bundled_program = const void;
2626
*
2727
* @param[in] method The Method to verify.
2828
* @param[in] bundled_program_ptr The bundled program contains expected output.
29-
* @param[in] method_idx The index of the Method being verified.
29+
* @param[in] method_name The name of the Method being verified.
3030
* @param[in] testset_idx The index of input needs to be set into given Method.
3131
*
3232
* @returns Return Error::Ok if load successfully, or the error happens during
@@ -36,7 +36,7 @@ __ET_NODISCARD Error LoadBundledInput(
3636
Method& method,
3737
serialized_bundled_program* bundled_program_ptr,
3838
MemoryAllocator* memory_allocator,
39-
size_t method_idx,
39+
const char* method_name,
4040
size_t testset_idx);
4141

4242
/**
@@ -45,7 +45,7 @@ __ET_NODISCARD Error LoadBundledInput(
4545
*
4646
* @param[in] method The Method to extract outputs from.
4747
* @param[in] bundled_program_ptr The bundled program contains expected output.
48-
* @param[in] method_idx The index of the Method being verified.
48+
* @param[in] method_name The name of the Method being verified.
4949
* @param[in] testset_idx The index of expected output needs to be compared.
5050
* @param[in] rtol Relative tolerance used for data comparsion.
5151
* @param[in] atol Absolute tolerance used for data comparsion.
@@ -57,7 +57,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
5757
Method& method,
5858
serialized_bundled_program* bundled_program_ptr,
5959
MemoryAllocator* memory_allocator,
60-
size_t method_idx,
60+
const char* method_name,
6161
size_t testset_idx,
6262
double rtol = 1e-5,
6363
double atol = 1e-8);

0 commit comments

Comments
 (0)