Skip to content

Commit 9effe70

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name runtime
Differential Revision: https://internalfb.com/D49335926 fbshipit-source-id: ff6b4e7688e3f605b7536d923c031ca3300254b0
1 parent f1bcd66 commit 9effe70

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

util/bundled_program_verification.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,28 +166,45 @@ bool tensors_are_close(
166166
}
167167
}
168168

169+
Result<executorch_flatbuffer::BundledExecutionPlanTest*> get_method_test(
170+
const executorch_flatbuffer::BundledProgram* bundled_program,
171+
const char* method_name) {
172+
auto method_tests = bundled_program->execution_plan_tests();
173+
for (size_t i = 0; i < method_tests->size(); i++) {
174+
auto m_test = method_tests->GetMutableObject(i);
175+
if (std::strcmp(m_test->method_name()->c_str(), method_name) == 0) {
176+
return m_test;
177+
}
178+
}
179+
ET_LOG(Error, "No method named '%s' in given bundled program", method_name);
180+
return Error::InvalidArgument;
181+
}
182+
169183
} // namespace
170184

171185
// Load testset_idx-th bundled data into the Method
172186
__ET_NODISCARD Error LoadBundledInput(
173187
Method& method,
174188
serialized_bundled_program* bundled_program_ptr,
175189
MemoryAllocator* memory_allocator,
176-
size_t method_idx,
190+
const char* method_name,
177191
size_t testset_idx) {
178192
ET_CHECK_OR_RETURN_ERROR(
179193
executorch_flatbuffer::BundledProgramBufferHasIdentifier(
180194
bundled_program_ptr),
181195
NotSupported,
182196
"The input buffer should be a bundled program.");
183197

198+
auto method_test = get_method_test(
199+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
200+
method_name);
201+
202+
if (!method_test.ok()) {
203+
return method_test.error();
204+
}
205+
184206
auto bundled_inputs =
185-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
186-
->execution_plan_tests()
187-
->Get(method_idx)
188-
->test_sets()
189-
->Get(testset_idx)
190-
->inputs();
207+
method_test.get()->test_sets()->Get(testset_idx)->inputs();
191208

192209
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
193210
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -263,7 +280,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
263280
Method& method,
264281
serialized_bundled_program* bundled_program_ptr,
265282
MemoryAllocator* memory_allocator,
266-
size_t method_idx,
283+
const char* method_name,
267284
size_t testset_idx,
268285
double rtol,
269286
double atol) {
@@ -273,13 +290,16 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
273290
NotSupported,
274291
"The input buffer should be a bundled program.");
275292

293+
auto method_test = get_method_test(
294+
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr),
295+
method_name);
296+
297+
if (!method_test.ok()) {
298+
return method_test.error();
299+
}
300+
276301
auto bundled_expected_outputs =
277-
executorch_flatbuffer::GetBundledProgram(bundled_program_ptr)
278-
->execution_plan_tests()
279-
->Get(method_idx)
280-
->test_sets()
281-
->Get(testset_idx)
282-
->expected_outputs();
302+
method_test.get()->test_sets()->Get(testset_idx)->expected_outputs();
283303

284304
for (size_t output_idx = 0; output_idx < method.outputs_size();
285305
output_idx++) {

util/bundled_program_verification.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using serialized_bundled_program = const void;
2626
*
2727
* @param[in] method The Method to verify.
2828
* @param[in] bundled_program_ptr The bundled program contains expected output.
29-
* @param[in] method_idx The index of the Method being verified.
29+
* @param[in] method_name The name of the Method being verified.
3030
* @param[in] testset_idx The index of input needs to be set into given Method.
3131
*
3232
* @returns Return Error::Ok if load successfully, or the error happens during
@@ -36,7 +36,7 @@ __ET_NODISCARD Error LoadBundledInput(
3636
Method& method,
3737
serialized_bundled_program* bundled_program_ptr,
3838
MemoryAllocator* memory_allocator,
39-
size_t method_idx,
39+
const char* method_name,
4040
size_t testset_idx);
4141

4242
/**
@@ -45,7 +45,7 @@ __ET_NODISCARD Error LoadBundledInput(
4545
*
4646
* @param[in] method The Method to extract outputs from.
4747
* @param[in] bundled_program_ptr The bundled program contains expected output.
48-
* @param[in] method_idx The index of the Method being verified.
48+
* @param[in] method_name The name of the Method being verified.
4949
* @param[in] testset_idx The index of expected output needs to be compared.
5050
* @param[in] rtol Relative tolerance used for data comparsion.
5151
* @param[in] atol Absolute tolerance used for data comparsion.
@@ -57,7 +57,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
5757
Method& method,
5858
serialized_bundled_program* bundled_program_ptr,
5959
MemoryAllocator* memory_allocator,
60-
size_t method_idx,
60+
const char* method_name,
6161
size_t testset_idx,
6262
double rtol = 1e-5,
6363
double atol = 1e-8);

0 commit comments

Comments
 (0)