Skip to content

Commit e7b32be

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name on AOT (#341)
Summary: Update AOT api and schema to support multimethod by method name instead of method idx. Since the following diffs will reformat the AOT apis, this diff doesn't pay much effort on documentation stuff. Reviewed By: tarun292 Differential Revision: D49246787
1 parent 96b83ce commit e7b32be

File tree

9 files changed

+190
-78
lines changed

9 files changed

+190
-78
lines changed

bundled_program/config.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class ConfigIOSet:
4747
class ConfigExecutionPlanTest:
4848
"""All info related to verify execution plan"""
4949

50+
method_name: str
5051
test_sets: List[ConfigIOSet]
5152

5253

@@ -60,6 +61,7 @@ class BundledConfig:
6061

6162
def __init__(
6263
self,
64+
method_names: List[str],
6365
# pyre-ignore
6466
inputs: List[List[Any]],
6567
# pyre-ignore
@@ -68,34 +70,34 @@ def __init__(
6870
"""Contruct the config given inputs and expected outputs
6971
7072
Args:
71-
inputs: All sets of input need to be test on for all execution plans. Each list
72-
of `inputs` is all sets which will be run on the execution plan in the
73-
program sharing same index. Each set of any `inputs` element should
73+
method_names: All method names need to be verified in program.
74+
inputs: All sets of input need to be test on for all methods. Each list
75+
of `inputs` is all sets which will be run on the method in the
76+
program with corresponding method name. Each set of any `inputs` element should
7477
contain all inputs required by eager_model with the same inference function
7578
as corresponding execution plan for one-time execution.
7679
77-
Please note that currently we do not have any consensus about the mapping rule
78-
between inference name in eager_model and execution plan id in executorch
79-
program. Hence, user should take care of the data order in `inputs`: each list
80-
of `inputs` is all sets which will be run on the execution plan with same index,
81-
not the inference function with same index in the result of get_inference_name.
82-
Same as the `expected_outputs` and `metadatas` below.
83-
84-
It shouldn't be a problem if there's only one inferenece function per model.
85-
8680
expected_outputs: Expected outputs for inputs sharing same index. The size of
87-
expected_outputs should be the same as the size of inputs.
81+
expected_outputs should be the same as the size of inputs and provided method_names.
8882
"""
8983
BundledConfig._check_io_type(inputs)
9084
BundledConfig._check_io_type(expected_outputs)
91-
assert len(inputs) == len(expected_outputs), (
92-
"length of inputs and expected_outputs should match,"
93-
+ " but got {} and {}".format(len(inputs), len(expected_outputs))
85+
86+
for m_name in method_names:
87+
assert isinstance(m_name, str)
88+
89+
assert len(method_names) == len(inputs) == len(expected_outputs), (
90+
"length of method_names, inputs and expected_outputs should match,"
91+
+ " but got {}, {} and {}".format(
92+
len(method_names), len(inputs), len(expected_outputs)
93+
)
9494
)
9595

9696
self.execution_plan_tests: List[
9797
ConfigExecutionPlanTest
98-
] = BundledConfig._gen_execution_plan_tests(inputs, expected_outputs)
98+
] = BundledConfig._gen_execution_plan_tests(
99+
method_names, inputs, expected_outputs
100+
)
99101

100102
@staticmethod
101103
# TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
@@ -145,6 +147,7 @@ def _check_io_type(test_data_program: List[List[Any]]) -> None:
145147

146148
@staticmethod
147149
def _gen_execution_plan_tests(
150+
method_names: List[str],
148151
# pyre-ignore
149152
inputs: List[List[Any]],
150153
# pyre-ignore
@@ -155,9 +158,10 @@ def _gen_execution_plan_tests(
155158
execution_plan_tests: List[ConfigExecutionPlanTest] = []
156159

157160
for (
161+
m_name,
158162
inputs_per_plan_test,
159163
expect_outputs_per_plan_test,
160-
) in zip(inputs, expected_outputs):
164+
) in zip(method_names, inputs, expected_outputs):
161165
test_sets: List[ConfigIOSet] = []
162166

163167
# transfer I/O sets into ConfigIOSet for each execution plan
@@ -182,7 +186,12 @@ def _gen_execution_plan_tests(
182186

183187
execution_plan_tests.append(
184188
ConfigExecutionPlanTest(
189+
method_name=m_name,
185190
test_sets=test_sets,
186191
)
187192
)
193+
194+
# sort the execution plan tests by method name to in line with core program emitter.
195+
execution_plan_tests.sort(key=lambda x: x.method_name)
196+
188197
return execution_plan_tests

bundled_program/core.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -136,32 +136,56 @@ def assert_valid_bundle(
136136
137137
"""
138138

139-
# Check the number of execution plan tests
140-
assert len(bundled_config.execution_plan_tests) == len(
141-
program.execution_plan
142-
), "The length of execution_plan_tests in config should match the length of execution_plan in program, but get {} and {}.".format(
143-
len(bundled_config.execution_plan_tests), len(program.execution_plan)
144-
)
139+
program_plan_id = 0
140+
bp_plan_id = 0
141+
142+
method_name_of_program = {e.name for e in program.execution_plan}
143+
method_name_of_test_suites = {
144+
t.method_name for t in bundled_config.execution_plan_tests
145+
}
146+
147+
assert method_name_of_test_suites.issubset(
148+
method_name_of_program
149+
), f"All methods in method_test_suites should be found in program.execution_plan, \
150+
but {str(method_name_of_test_suites - method_name_of_program)} does not include."
151+
152+
# check if method_tesdt_suites has been sorted in ascending alphabetical order of method name.
153+
for bp_plan_id in range(1, len(bundled_config.execution_plan_tests)):
154+
assert (
155+
bundled_config.execution_plan_tests[bp_plan_id - 1].method_name
156+
<= bundled_config.execution_plan_tests[bp_plan_id].method_name
157+
), f"The method name of test suite should be sorted in ascending alphabetical \
158+
order of method name, but {bp_plan_id-1}-th and {bp_plan_id}-th method_test_suite aren't."
145159

146160
# Check if the inputs' type meet Program's requirement
147-
for plan_id in range(len(program.execution_plan)):
161+
while bp_plan_id < len(bundled_config.execution_plan_tests):
162+
148163
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
149-
plan_id
164+
bp_plan_id
150165
]
166+
plan: ExecutionPlan = program.execution_plan[program_plan_id]
151167

152-
plan: ExecutionPlan = program.execution_plan[plan_id]
168+
# User does not provide testcases for current plan, skip it
169+
if plan_test.method_name > plan.name:
170+
program_plan_id += 1
171+
continue
172+
173+
# Check if the method name in user provided test matches the one in the original program
174+
assert (
175+
plan_test.method_name == plan.name
176+
), f"BundledConfig has testcases for method {plan_test.method_name}, but can not find it in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}."
153177

154178
# Check if the type of Program's input is supported
155179
for index in range(len(plan.inputs)):
156180
assert (
157-
type(get_program_input(program, plan_id, index))
181+
type(get_program_input(program, program_plan_id, index))
158182
in supported_program_type_table
159183
), "The type of program's input isn't supported."
160184

161185
# Check if the type of Program's output is supported
162186
for index in range(len(plan.outputs)):
163187
assert (
164-
type(get_program_output(program, plan_id, index)) == Tensor
188+
type(get_program_output(program, program_plan_id, index)) == Tensor
165189
), "Only supports program with output in Tensor type."
166190

167191
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -181,14 +205,14 @@ def assert_valid_bundle(
181205
assert (
182206
type(cur_plan_test_inputs[j])
183207
== supported_program_type_table[
184-
type(get_program_input(program, plan_id, j))
208+
type(get_program_input(program, program_plan_id, j))
185209
]
186210
), "The type {}-th input in {}-th test set of {}-th execution plan does not meet Program's requirement: expected {} but get {}".format(
187211
j,
188212
i,
189-
plan_id,
213+
program_plan_id,
190214
supported_program_type_table[
191-
type(get_program_input(program, plan_id, j))
215+
type(get_program_input(program, program_plan_id, j))
192216
],
193217
type(cur_plan_test_inputs[j]),
194218
)
@@ -198,10 +222,10 @@ def assert_valid_bundle(
198222
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
199223
# has no attribute `dtype`.
200224
assert cur_plan_test_inputs[j].dtype == get_input_dtype(
201-
program, plan_id, j
225+
program, program_plan_id, j
202226
), "The input tensor {} dtype shall be {}, but now is {}".format(
203227
cur_plan_test_inputs[j],
204-
get_input_dtype(program, plan_id, j),
228+
get_input_dtype(program, program_plan_id, j),
205229
cur_plan_test_inputs[j].dtype,
206230
)
207231
elif type(cur_plan_test_inputs[j]) in (
@@ -210,9 +234,9 @@ def assert_valid_bundle(
210234
float,
211235
):
212236
assert type(cur_plan_test_inputs[j]) == get_input_type(
213-
program, plan_id, j
237+
program, program_plan_id, j
214238
), "The input primitive dtype shall be {}, but now is {}".format(
215-
get_input_type(program, plan_id, j),
239+
get_input_type(program, program_plan_id, j),
216240
type(cur_plan_test_inputs[j]),
217241
)
218242

@@ -221,13 +245,16 @@ def assert_valid_bundle(
221245
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
222246
# has no attribute `dtype`.
223247
assert cur_plan_test_expected_outputs[j].dtype == get_output_dtype(
224-
program, plan_id, j
248+
program, program_plan_id, j
225249
), "The label tensor {} dtype shall be {}, but now is {}".format(
226250
cur_plan_test_expected_outputs[j],
227-
get_output_dtype(program, plan_id, j),
251+
get_output_dtype(program, program_plan_id, j),
228252
cur_plan_test_expected_outputs[j].dtype,
229253
)
230254

255+
program_plan_id += 1
256+
bp_plan_id += 1
257+
231258

232259
def create_bundled_program(
233260
program: Program,
@@ -245,10 +272,7 @@ def create_bundled_program(
245272
execution_plan_tests: List[BundledExecutionPlanTest] = []
246273

247274
# Emit data and metadata of bundled tensor
248-
for plan_id in range(len(program.execution_plan)):
249-
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
250-
plan_id
251-
]
275+
for plan_test in bundled_config.execution_plan_tests:
252276
test_sets: List[BundledIOSet] = []
253277

254278
# emit I/O sets for each execution plan test
@@ -283,7 +307,11 @@ def create_bundled_program(
283307
)
284308

285309
# emit the whole execution plan test
286-
execution_plan_tests.append(BundledExecutionPlanTest(test_sets=test_sets))
310+
execution_plan_tests.append(
311+
BundledExecutionPlanTest(
312+
method_name=plan_test.method_name, test_sets=test_sets
313+
)
314+
)
287315

288316
program_bytes: bytes = _serialize_pte_binary(program)
289317

bundled_program/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ class BundledIOSet:
7373
class BundledExecutionPlanTest:
7474
"""Context for testing and verifying an exceution plan."""
7575

76+
# The name of the method to test; e.g., "forward" for the forward() method
77+
# of an nn.Module. This name match a method defined by the ExecuTorch
78+
# program.
79+
method_name: str
80+
7681
# Sets of input/outputs to test with.
7782
test_sets: List[BundledIOSet]
7883

0 commit comments

Comments
 (0)