Skip to content

Commit 286ecdd

Browse files
Songhao Jiafacebook-github-bot
authored andcommitted
support multimethod by method name on AOT
Summary: Update AOT api and schema to support multimethod by method name instead of method idx. Differential Revision: D49246787
1 parent 3482830 commit 286ecdd

File tree

9 files changed

+167
-65
lines changed

9 files changed

+167
-65
lines changed

bundled_program/config.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class ConfigIOSet:
4747
class ConfigExecutionPlanTest:
4848
"""All info related to verify execution plan"""
4949

50+
method_name: str
5051
test_sets: List[ConfigIOSet]
5152

5253

@@ -60,6 +61,7 @@ class BundledConfig:
6061

6162
def __init__(
6263
self,
64+
method_names: List[str],
6365
# pyre-ignore
6466
inputs: List[List[Any]],
6567
# pyre-ignore
@@ -88,14 +90,22 @@ def __init__(
8890
"""
8991
BundledConfig._check_io_type(inputs)
9092
BundledConfig._check_io_type(expected_outputs)
91-
assert len(inputs) == len(expected_outputs), (
92-
"length of inputs and expected_outputs should match,"
93-
+ " but got {} and {}".format(len(inputs), len(expected_outputs))
93+
94+
for m_name in method_names:
95+
assert isinstance(m_name, str)
96+
97+
assert len(method_names) == len(inputs) == len(expected_outputs), (
98+
"length of method_names, inputs and expected_outputs should match,"
99+
+ " but got {}, {} and {}".format(
100+
len(method_names), len(inputs), len(expected_outputs)
101+
)
94102
)
95103

96104
self.execution_plan_tests: List[
97105
ConfigExecutionPlanTest
98-
] = BundledConfig._gen_execution_plan_tests(inputs, expected_outputs)
106+
] = BundledConfig._gen_execution_plan_tests(
107+
method_names, inputs, expected_outputs
108+
)
99109

100110
@staticmethod
101111
# TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
@@ -145,6 +155,7 @@ def _check_io_type(test_data_program: List[List[Any]]) -> None:
145155

146156
@staticmethod
147157
def _gen_execution_plan_tests(
158+
method_names: List[str],
148159
# pyre-ignore
149160
inputs: List[List[Any]],
150161
# pyre-ignore
@@ -155,9 +166,10 @@ def _gen_execution_plan_tests(
155166
execution_plan_tests: List[ConfigExecutionPlanTest] = []
156167

157168
for (
169+
m_name,
158170
inputs_per_plan_test,
159171
expect_outputs_per_plan_test,
160-
) in zip(inputs, expected_outputs):
172+
) in zip(method_names, inputs, expected_outputs):
161173
test_sets: List[ConfigIOSet] = []
162174

163175
# transfer I/O sets into ConfigIOSet for each execution plan
@@ -182,7 +194,12 @@ def _gen_execution_plan_tests(
182194

183195
execution_plan_tests.append(
184196
ConfigExecutionPlanTest(
197+
method_name=m_name,
185198
test_sets=test_sets,
186199
)
187200
)
201+
202+
# sort the execution plan tests by method name to in line with core program emitter.
203+
execution_plan_tests.sort(key=lambda x: x.method_name)
204+
188205
return execution_plan_tests

bundled_program/core.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -136,32 +136,38 @@ def assert_valid_bundle(
136136
137137
"""
138138

139-
# Check the number of execution plan tests
140-
assert len(bundled_config.execution_plan_tests) == len(
141-
program.execution_plan
142-
), "The length of execution_plan_tests in config should match the length of execution_plan in program, but get {} and {}.".format(
143-
len(bundled_config.execution_plan_tests), len(program.execution_plan)
144-
)
139+
program_plan_id = 0
140+
bp_plan_id = 0
145141

146142
# Check if the inputs' type meet Program's requirement
147-
for plan_id in range(len(program.execution_plan)):
143+
while bp_plan_id < len(bundled_config.execution_plan_tests):
144+
148145
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
149-
plan_id
146+
bp_plan_id
150147
]
148+
plan: ExecutionPlan = program.execution_plan[program_plan_id]
151149

152-
plan: ExecutionPlan = program.execution_plan[plan_id]
150+
# User does not provide testcases for current plan, skip it
151+
if plan_test.method_name < plan.name:
152+
program_plan_id += 1
153+
continue
154+
155+
# Check if the method name in user provided test matches the one in the original program
156+
assert (
157+
plan_test.method_name == plan.name
158+
), f"BundledConfig has testcases for method {plan_test.method_name}, but can not find it in the given program. All method names in the program are {', '.join([p.name for p in program.execution_plan])}."
153159

154160
# Check if the type of Program's input is supported
155161
for index in range(len(plan.inputs)):
156162
assert (
157-
type(get_program_input(program, plan_id, index))
163+
type(get_program_input(program, program_plan_id, index))
158164
in supported_program_type_table
159165
), "The type of program's input isn't supported."
160166

161167
# Check if the type of Program's output is supported
162168
for index in range(len(plan.outputs)):
163169
assert (
164-
type(get_program_output(program, plan_id, index)) == Tensor
170+
type(get_program_output(program, program_plan_id, index)) == Tensor
165171
), "Only supports program with output in Tensor type."
166172

167173
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -181,14 +187,14 @@ def assert_valid_bundle(
181187
assert (
182188
type(cur_plan_test_inputs[j])
183189
== supported_program_type_table[
184-
type(get_program_input(program, plan_id, j))
190+
type(get_program_input(program, program_plan_id, j))
185191
]
186192
), "The type {}-th input in {}-th test set of {}-th execution plan does not meet Program's requirement: expected {} but get {}".format(
187193
j,
188194
i,
189-
plan_id,
195+
program_plan_id,
190196
supported_program_type_table[
191-
type(get_program_input(program, plan_id, j))
197+
type(get_program_input(program, program_plan_id, j))
192198
],
193199
type(cur_plan_test_inputs[j]),
194200
)
@@ -198,10 +204,10 @@ def assert_valid_bundle(
198204
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
199205
# has no attribute `dtype`.
200206
assert cur_plan_test_inputs[j].dtype == get_input_dtype(
201-
program, plan_id, j
207+
program, program_plan_id, j
202208
), "The input tensor {} dtype shall be {}, but now is {}".format(
203209
cur_plan_test_inputs[j],
204-
get_input_dtype(program, plan_id, j),
210+
get_input_dtype(program, program_plan_id, j),
205211
cur_plan_test_inputs[j].dtype,
206212
)
207213
elif type(cur_plan_test_inputs[j]) in (
@@ -210,9 +216,9 @@ def assert_valid_bundle(
210216
float,
211217
):
212218
assert type(cur_plan_test_inputs[j]) == get_input_type(
213-
program, plan_id, j
219+
program, program_plan_id, j
214220
), "The input primitive dtype shall be {}, but now is {}".format(
215-
get_input_type(program, plan_id, j),
221+
get_input_type(program, program_plan_id, j),
216222
type(cur_plan_test_inputs[j]),
217223
)
218224

@@ -221,13 +227,16 @@ def assert_valid_bundle(
221227
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
222228
# has no attribute `dtype`.
223229
assert cur_plan_test_expected_outputs[j].dtype == get_output_dtype(
224-
program, plan_id, j
230+
program, program_plan_id, j
225231
), "The label tensor {} dtype shall be {}, but now is {}".format(
226232
cur_plan_test_expected_outputs[j],
227-
get_output_dtype(program, plan_id, j),
233+
get_output_dtype(program, program_plan_id, j),
228234
cur_plan_test_expected_outputs[j].dtype,
229235
)
230236

237+
program_plan_id += 1
238+
bp_plan_id += 1
239+
231240

232241
def create_bundled_program(
233242
program: Program,
@@ -245,10 +254,7 @@ def create_bundled_program(
245254
execution_plan_tests: List[BundledExecutionPlanTest] = []
246255

247256
# Emit data and metadata of bundled tensor
248-
for plan_id in range(len(program.execution_plan)):
249-
plan_test: ConfigExecutionPlanTest = bundled_config.execution_plan_tests[
250-
plan_id
251-
]
257+
for plan_test in bundled_config.execution_plan_tests:
252258
test_sets: List[BundledIOSet] = []
253259

254260
# emit I/O sets for each execution plan test
@@ -283,7 +289,11 @@ def create_bundled_program(
283289
)
284290

285291
# emit the whole execution plan test
286-
execution_plan_tests.append(BundledExecutionPlanTest(test_sets=test_sets))
292+
execution_plan_tests.append(
293+
BundledExecutionPlanTest(
294+
method_name=plan_test.method_name, test_sets=test_sets
295+
)
296+
)
287297

288298
program_bytes: bytes = _serialize_pte_binary(program)
289299

bundled_program/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ class BundledIOSet:
7373
class BundledExecutionPlanTest:
7474
"""Context for testing and verifying an exceution plan."""
7575

76+
# The name of the method to test; e.g., "forward" for the forward() method
77+
# of an nn.Module. This name match a method defined by the ExecuTorch
78+
# program.
79+
method_name: str
80+
7681
# Sets of input/outputs to test with.
7782
test_sets: List[BundledIOSet]
7883

bundled_program/tests/common.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8+
import random
9+
import string
810
from typing import List, Tuple, Union
911

1012
import executorch.exir as exir
@@ -43,15 +45,16 @@
4345
OutputValues: TypeAlias = List[torch.Tensor]
4446

4547

46-
class MISOModel(torch.nn.Module):
47-
"""An example model with Multi-Input Single-Output"""
48+
class SampleModel(torch.nn.Module):
49+
"""An example model with multi-methods. Each method has multiple input and single output"""
4850

4951
def __init__(self) -> None:
5052
super().__init__()
5153
self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32)
5254
self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32)
55+
self.method_names = ["encode", "decode"]
5356

54-
def forward(
57+
def encode(
5558
self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
5659
) -> torch.Tensor:
5760
z = x.clone()
@@ -61,6 +64,13 @@ def forward(
6164
torch.add(y, q, out=y)
6265
return y
6366

67+
def decode(
68+
self, x: torch.Tensor, q: torch.Tensor, a: int = DEFAULT_INT_INPUT
69+
) -> torch.Tensor:
70+
y = x * q
71+
torch.add(y, self.b, alpha=a, out=y)
72+
return y
73+
6474

6575
def get_rand_input_values(
6676
n_tensors: int,
@@ -96,6 +106,15 @@ def get_rand_output_values(
96106
]
97107

98108

109+
def get_rand_method_names(n_execution_plan_tests: int) -> List[str]:
110+
unique_strings = set()
111+
while len(unique_strings) < n_execution_plan_tests:
112+
rand_str = "".join(random.choices(string.ascii_letters, k=5))
113+
if rand_str not in unique_strings:
114+
unique_strings.add(rand_str)
115+
return list(unique_strings)
116+
117+
99118
# TODO(T143955558): make n_int and metadatas as its input;
100119
def get_random_config(
101120
n_model_inputs: int,
@@ -105,7 +124,12 @@ def get_random_config(
105124
dtype: torch.dtype,
106125
n_sets_per_plan_test: int,
107126
n_execution_plan_tests: int,
108-
) -> Tuple[List[List[InputValues]], List[List[OutputValues]], BundledConfig,]:
127+
) -> Tuple[
128+
List[str],
129+
List[List[InputValues]],
130+
List[List[OutputValues]],
131+
BundledConfig,
132+
]:
109133
"""Helper function to generate config filled with random inputs and expected outputs.
110134
111135
The return type of rand inputs is a List[List[InputValues]]. The inner list of
@@ -116,6 +140,8 @@ def get_random_config(
116140
117141
"""
118142

143+
rand_method_names = get_rand_method_names(n_execution_plan_tests)
144+
119145
rand_inputs = get_rand_input_values(
120146
n_tensors=n_model_inputs,
121147
sizes=model_input_sizes,
@@ -134,67 +160,67 @@ def get_random_config(
134160
)
135161

136162
return (
163+
rand_method_names,
137164
rand_inputs,
138165
rand_expected_outputs,
139-
BundledConfig(rand_inputs, rand_expected_outputs),
166+
BundledConfig(rand_method_names, rand_inputs, rand_expected_outputs),
140167
)
141168

142169

143170
def get_random_config_with_eager_model(
144171
eager_model: torch.nn.Module,
172+
method_names: List[str],
145173
n_model_inputs: int,
146174
model_input_sizes: List[List[int]],
147175
dtype: torch.dtype,
148176
n_sets_per_plan_test: int,
149-
n_execution_plan_tests: int,
150177
) -> Tuple[List[List[InputValues]], BundledConfig]:
151178
"""Generate config filled with random inputs for each inference method given eager model
152179
153180
The details of return type is the same as get_random_config_with_rand_io_lists.
154-
155-
NOTE: Right now we do not support multiple inference methods per eager model. To simulate
156-
generating exepected output for different inference functions, we infer the same method
157-
multiple times.
158-
159-
TODO(T143752810): Update the hacky method after we support multiple inference methods.
160181
"""
161182
inputs = get_rand_input_values(
162183
n_tensors=n_model_inputs,
163184
sizes=model_input_sizes,
164185
n_int=1,
165186
dtype=dtype,
166187
n_sets_per_plan_test=n_sets_per_plan_test,
167-
n_execution_plan_tests=n_execution_plan_tests,
188+
n_execution_plan_tests=len(method_names),
168189
)
169190

170191
expected_outputs = [
171-
[[eager_model(*x)] for x in inputs[i]] for i in range(n_execution_plan_tests)
192+
[[getattr(eager_model, m_name)(*x)] for x in inputs[i]]
193+
for i, m_name in enumerate(method_names)
172194
]
173195

174-
return inputs, BundledConfig(inputs, expected_outputs)
196+
return inputs, BundledConfig(method_names, inputs, expected_outputs)
175197

176198

177199
def get_common_program() -> Tuple[Program, BundledConfig]:
178200
"""Helper function to generate a sample BundledProgram with its config."""
179-
eager_model = MISOModel()
201+
eager_model = SampleModel()
180202
# Trace to FX Graph.
181-
capture_input = (
182-
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
183-
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
184-
DEFAULT_INT_INPUT,
185-
)
203+
capture_inputs = {
204+
m_name: (
205+
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
206+
(torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
207+
DEFAULT_INT_INPUT,
208+
)
209+
for m_name in eager_model.method_names
210+
}
211+
186212
program = (
187-
exir.capture(eager_model, capture_input, CaptureConfig())
213+
exir.capture_multiple(eager_model, capture_inputs)
188214
.to_edge()
189215
.to_executorch()
190216
.program
191217
)
192218
_, bundled_config = get_random_config_with_eager_model(
193219
eager_model=eager_model,
220+
method_names=eager_model.method_names,
194221
n_model_inputs=2,
195222
model_input_sizes=[[2, 2], [2, 2]],
196223
dtype=torch.int32,
197224
n_sets_per_plan_test=10,
198-
n_execution_plan_tests=len(program.execution_plan),
199225
)
200226
return program, bundled_config

0 commit comments

Comments
 (0)