@@ -166,28 +166,45 @@ bool tensors_are_close(
166
166
}
167
167
}
168
168
169
+ Result<executorch_flatbuffer::BundledExecutionPlanTest*> get_method_test (
170
+ const executorch_flatbuffer::BundledProgram* bundled_program,
171
+ const char * method_name) {
172
+ auto method_tests = bundled_program->execution_plan_tests ();
173
+ for (size_t i = 0 ; i < method_tests->size (); i++) {
174
+ auto m_test = method_tests->GetMutableObject (i);
175
+ if (std::strcmp (m_test->method_name ()->c_str (), method_name) == 0 ) {
176
+ return m_test;
177
+ }
178
+ }
179
+ ET_LOG (Error, " No method named '%s' in given bundled program" , method_name);
180
+ return Error::InvalidArgument;
181
+ }
182
+
169
183
} // namespace
170
184
171
185
// Load testset_idx-th bundled data into the Method
172
186
__ET_NODISCARD Error LoadBundledInput (
173
187
Method& method,
174
188
serialized_bundled_program* bundled_program_ptr,
175
189
MemoryAllocator* memory_allocator,
176
- size_t method_idx ,
190
+ const char * method_name ,
177
191
size_t testset_idx) {
178
192
ET_CHECK_OR_RETURN_ERROR (
179
193
executorch_flatbuffer::BundledProgramBufferHasIdentifier (
180
194
bundled_program_ptr),
181
195
NotSupported,
182
196
" The input buffer should be a bundled program." );
183
197
198
+ auto method_test = get_method_test (
199
+ executorch_flatbuffer::GetBundledProgram (bundled_program_ptr),
200
+ method_name);
201
+
202
+ if (!method_test.ok ()) {
203
+ return method_test.error ();
204
+ }
205
+
184
206
auto bundled_inputs =
185
- executorch_flatbuffer::GetBundledProgram (bundled_program_ptr)
186
- ->execution_plan_tests ()
187
- ->Get (method_idx)
188
- ->test_sets ()
189
- ->Get (testset_idx)
190
- ->inputs ();
207
+ method_test.get ()->test_sets ()->Get (testset_idx)->inputs ();
191
208
192
209
for (size_t input_idx = 0 ; input_idx < method.inputs_size (); input_idx++) {
193
210
auto bundled_input = bundled_inputs->GetMutableObject (input_idx);
@@ -263,7 +280,7 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
263
280
Method& method,
264
281
serialized_bundled_program* bundled_program_ptr,
265
282
MemoryAllocator* memory_allocator,
266
- size_t method_idx ,
283
+ const char * method_name ,
267
284
size_t testset_idx,
268
285
double rtol,
269
286
double atol) {
@@ -273,13 +290,16 @@ __ET_NODISCARD Error VerifyResultWithBundledExpectedOutput(
273
290
NotSupported,
274
291
" The input buffer should be a bundled program." );
275
292
293
+ auto method_test = get_method_test (
294
+ executorch_flatbuffer::GetBundledProgram (bundled_program_ptr),
295
+ method_name);
296
+
297
+ if (!method_test.ok ()) {
298
+ return method_test.error ();
299
+ }
300
+
276
301
auto bundled_expected_outputs =
277
- executorch_flatbuffer::GetBundledProgram (bundled_program_ptr)
278
- ->execution_plan_tests ()
279
- ->Get (method_idx)
280
- ->test_sets ()
281
- ->Get (testset_idx)
282
- ->expected_outputs ();
302
+ method_test.get ()->test_sets ()->Get (testset_idx)->expected_outputs ();
283
303
284
304
for (size_t output_idx = 0 ; output_idx < method.outputs_size ();
285
305
output_idx++) {
0 commit comments