Skip to content

Commit afaecbe

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name on AOT (#341)
Summary: Update AOT api and schema to support multimethod by method name instead of method idx. Since the following diffs will reformat the AOT apis, this diff doesn't pay much effort on documentation stuff. Differential Revision: D49246787
1 parent d1a009e commit afaecbe

File tree

9 files changed

+172
-78
lines changed

9 files changed

+172
-78
lines changed

bundled_program/config.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class ConfigIOSet:
4747
class ConfigExecutionPlanTest:
4848
"""All info related to verify execution plan"""
4949

50+
method_name: str
5051
test_sets: List[ConfigIOSet]
5152

5253

@@ -60,6 +61,7 @@ class BundledConfig:
6061

6162
def __init__(
6263
self,
64+
method_names: List[str],
6365
# pyre-ignore
6466
inputs: List[List[Any]],
6567
# pyre-ignore
@@ -68,34 +70,34 @@ def __init__(
6870
"""Contruct the config given inputs and expected outputs
6971
7072
Args:
71-
inputs: All sets of input need to be test on for all execution plans. Each list
72-
of `inputs` is all sets which will be run on the execution plan in the
73-
program sharing same index. Each set of any `inputs` element should
73+
method_names: All method names need to be verified in program. Each method name
74+
inputs: All sets of input need to be test on for all methods. Each list
75+
of `inputs` is all sets which will be run on the method in the
76+
program with corresponding method name. Each set of any `inputs` element should
7477
contain all inputs required by eager_model with the same inference function
7578
as corresponding execution plan for one-time execution.
7679
77-
Please note that currently we do not have any consensus about the mapping rule
78-
between inference name in eager_model and execution plan id in executorch
79-
program. Hence, user should take care of the data order in `inputs`: each list
80-
of `inputs` is all sets which will be run on the execution plan with same index,
81-
not the inference function with same index in the result of get_inference_name.
82-
Same as the `expected_outputs` and `metadatas` below.
83-
84-
It shouldn't be a problem if there's only one inferenece function per model.
85-
8680
expected_outputs: Expected outputs for inputs sharing same index. The size of
87-
expected_outputs should be the same as the size of inputs.
81+
expected_outputs should be the same as the size of inputs and provided method_names.
8882
"""
8983
BundledConfig._check_io_type(inputs)
9084
BundledConfig._check_io_type(expected_outputs)
91-
assert len(inputs) == len(expected_outputs), (
92-
"length of inputs and expected_outputs should match,"
93-
+ " but got {} and {}".format(len(inputs), len(expected_outputs))
85+
86+
for m_name in method_names:
87+
assert isinstance(m_name, str)
88+
89+
assert len(method_names) == len(inputs) == len(expected_outputs), (
90+
"length of method_names, inputs and expected_outputs should match,"
91+
+ " but got {}, {} and {}".format(
92+
len(method_names), len(inputs), len(expected_outputs)
93+
)
9494
)
9595

9696
self.execution_plan_tests: List[
9797
ConfigExecutionPlanTest
98-
] = BundledConfig._gen_execution_plan_tests(inputs, expected_outputs)
98+
] = BundledConfig._gen_execution_plan_tests(
99+
method_names, inputs, expected_outputs
100+
)
99101

100102
@staticmethod
101103
# TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
@@ -145,6 +147,7 @@ def _check_io_type(test_data_program: List[List[Any]]) -> None:
145147

146148
@staticmethod
147149
def _gen_execution_plan_tests(
150+
method_names: List[str],
148151
# pyre-ignore
149152
inputs: List[List[Any]],
150153
# pyre-ignore
@@ -155,9 +158,10 @@ def _gen_execution_plan_tests(
155158
execution_plan_tests: List[ConfigExecutionPlanTest] = []
156159

157160
for (
161+
m_name,
158162
inputs_per_plan_test,
159163
expect_outputs_per_plan_test,
160-
) in zip(inputs, expected_outputs):
164+
) in zip(method_names, inputs, expected_outputs):
161165
test_sets: List[ConfigIOSet] = []
162166

163167
# transfer I/O sets into ConfigIOSet for each execution plan
@@ -182,7 +186,12 @@ def _gen_execution_plan_tests(
182186

183187
execution_plan_tests.append(
184188
ConfigExecutionPlanTest(
189+
method_name=m_name,
185190
test_sets=test_sets,
186191
)
187192
)
193+
194+
# sort the execution plan tests by method name to in line with core program emitter.
195+
execution_plan_tests.sort(key=lambda x: x.method_name)
196+
188197
return execution_plan_tests

bundled_program/core.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -136,32 +136,38 @@ def assert_valid_bundle(
136136
137137
"""
138138

139-
# Check the number of execution plan tests
140-
assert len(bundled_config.execution_plan_tests) == len(
141-
program.execution_plan
142-
), "The length of execution_plan_tests in config should match the length of execution_plan in program, but get {} and {}.".format(
143-
len(bundled_config.execution_plan_tests), len(program.execution_plan)
144-
)
139+
program_plan_id = 0
140+
bp_plan_id = 0
145141

146142
# Check if the inputs' type meet Program's requirement
147-
for plan_id in range(len(program.execution_plan)):
143+
while bp_plan_id < len(bundled_config.execution_plan_tests):
144+
148145
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
149-
plan_id
146+
bp_plan_id
150147
]
148+
plan: ExecutionPlan = program.execution_plan[program_plan_id]
151149

152-
plan: ExecutionPlan = program.execution_plan[plan_id]
150+
# User does not provide testcases for current plan, skip it
151+
if plan_test.method_name < plan.name:
152+
program_plan_id += 1
153+
continue
154+
155+
# Check if the method name in user provided test matches the one in the original program
156+
assert (
157+
plan_test.method_name == plan.name
158+
), f"BundledConfig has testcases for method {plan_test.method_name}, but can not find it in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}."
153159

154160
# Check if the type of Program's input is supported
155161
for index in range(len(plan.inputs)):
156162
assert (
157-
type(get_program_input(program, plan_id, index))
163+
type(get_program_input(program, program_plan_id, index))
158164
in supported_program_type_table
159165
), "The type of program's input isn't supported."
160166

161167
# Check if the type of Program's output is supported
162168
for index in range(len(plan.outputs)):
163169
assert (
164-
type(get_program_output(program, plan_id, index)) == Tensor
170+
type(get_program_output(program, program_plan_id, index)) == Tensor
165171
), "Only supports program with output in Tensor type."
166172

167173
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -181,14 +187,14 @@ def assert_valid_bundle(
181187
assert (
182188
type(cur_plan_test_inputs[j])
183189
== supported_program_type_table[
184-
type(get_program_input(program, plan_id, j))
190+
type(get_program_input(program, program_plan_id, j))
185191
]
186192
), "The type {}-th input in {}-th test set of {}-th execution plan does not meet Program's requirement: expected {} but get {}".format(
187193
j,
188194
i,
189-
plan_id,
195+
program_plan_id,
190196
supported_program_type_table[
191-
type(get_program_input(program, plan_id, j))
197+
type(get_program_input(program, program_plan_id, j))
192198
],
193199
type(cur_plan_test_inputs[j]),
194200
)
@@ -198,10 +204,10 @@ def assert_valid_bundle(
198204
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
199205
# has no attribute `dtype`.
200206
assert cur_plan_test_inputs[j].dtype == get_input_dtype(
201-
program, plan_id, j
207+
program, program_plan_id, j
202208
), "The input tensor {} dtype shall be {}, but now is {}".format(
203209
cur_plan_test_inputs[j],
204-
get_input_dtype(program, plan_id, j),
210+
get_input_dtype(program, program_plan_id, j),
205211
cur_plan_test_inputs[j].dtype,
206212
)
207213
elif type(cur_plan_test_inputs[j]) in (
@@ -210,9 +216,9 @@ def assert_valid_bundle(
210216
float,
211217
):
212218
assert type(cur_plan_test_inputs[j]) == get_input_type(
213-
program, plan_id, j
219+
program, program_plan_id, j
214220
), "The input primitive dtype shall be {}, but now is {}".format(
215-
get_input_type(program, plan_id, j),
221+
get_input_type(program, program_plan_id, j),
216222
type(cur_plan_test_inputs[j]),
217223
)
218224

@@ -221,13 +227,16 @@ def assert_valid_bundle(
221227
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
222228
# has no attribute `dtype`.
223229
assert cur_plan_test_expected_outputs[j].dtype == get_output_dtype(
224-
program, plan_id, j
230+
program, program_plan_id, j
225231
), "The label tensor {} dtype shall be {}, but now is {}".format(
226232
cur_plan_test_expected_outputs[j],
227-
get_output_dtype(program, plan_id, j),
233+
get_output_dtype(program, program_plan_id, j),
228234
cur_plan_test_expected_outputs[j].dtype,
229235
)
230236

237+
program_plan_id += 1
238+
bp_plan_id += 1
239+
231240

232241
def create_bundled_program(
233242
program: Program,
@@ -245,10 +254,7 @@ def create_bundled_program(
245254
execution_plan_tests: List[BundledExecutionPlanTest] = []
246255

247256
# Emit data and metadata of bundled tensor
248-
for plan_id in range(len(program.execution_plan)):
249-
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
250-
plan_id
251-
]
257+
for plan_test in bundled_config.execution_plan_tests:
252258
test_sets: List[BundledIOSet] = []
253259

254260
# emit I/O sets for each execution plan test
@@ -283,7 +289,11 @@ def create_bundled_program(
283289
)
284290

285291
# emit the whole execution plan test
286-
execution_plan_tests.append(BundledExecutionPlanTest(test_sets=test_sets))
292+
execution_plan_tests.append(
293+
BundledExecutionPlanTest(
294+
method_name=plan_test.method_name, test_sets=test_sets
295+
)
296+
)
287297

288298
program_bytes: bytes = _serialize_pte_binary(program)
289299

bundled_program/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ class BundledIOSet:
7373
class BundledExecutionPlanTest:
7474
"""Context for testing and verifying an exceution plan."""
7575

76+
# The name of the method to test; e.g., "forward" for the forward() method
77+
# of an nn.Module. This name match a method defined by the ExecuTorch
78+
# program.
79+
method_name: str
80+
7681
# Sets of input/outputs to test with.
7782
test_sets: List[BundledIOSet]
7883

0 commit comments

Comments
 (0)