Skip to content

Commit dab0d71

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
structural api and schema update (#1009)
Summary: Pull Request resolved: #1009 Update the AOT API for bundled program to make it more user friendly. We created two new classes for holding testcases in a more structural way: MethodTestCase: contains inputs and expected outputs for a single test for method. MethodTestSuite: Collection of all test cases for a program method. Detailed design can be found here: https://docs.google.com/document/d/170WJ81dPuNQsvsjqH_wQjafbu2kxtEG1sgX-z6OqCEI/edit#heading=h.69fgezruwl6n Reviewed By: tarun292 Differential Revision: D50422563 fbshipit-source-id: 6bef1d29f4f2d5d56eacbb6744283066d40724e8
1 parent b32f5a3 commit dab0d71

File tree

18 files changed

+561
-607
lines changed

18 files changed

+561
-607
lines changed

backends/apple/mps/test/test_mps.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
TestMPS,
2626
)
2727

28-
from executorch.bundled_program.config import BundledConfig
28+
from executorch.bundled_program.config import MethodTestCase, MethodTestSuite
2929
from executorch.bundled_program.core import create_bundled_program
3030
from executorch.bundled_program.serialize import (
3131
serialize_from_bundled_program_to_flatbuffer,
3232
)
33+
3334
from executorch.exir import ExirExportedProgram
3435
from executorch.exir.backend.backend_api import to_backend
3536
from executorch.exir.tests.models import (
@@ -129,21 +130,20 @@ def forward(self, *args):
129130
f" -> Number of execution plans: {len(executorch_program.program.execution_plan)}"
130131
)
131132

132-
bundled_inputs = [
133-
[m_inputs] for _ in range(len(executorch_program.program.execution_plan))
134-
]
135-
logging.info(" -> Bundled inputs generated successfully")
136-
137-
output = m(*m_inputs)
138-
expected_outputs = [
139-
[[output]] for _ in range(len(executorch_program.program.execution_plan))
133+
method_test_suites = [
134+
MethodTestSuite(
135+
method_name="forward",
136+
test_cases=[
137+
MethodTestCase(inputs=m_inputs, expected_outputs=model(*m_inputs))
138+
],
139+
)
140140
]
141-
logging.info(" -> Bundled outputs generated successfully")
142141

143-
bundled_config = BundledConfig(["forward"], bundled_inputs, expected_outputs)
144-
logging.info(" -> Bundled config generated successfully")
142+
logging.info(" -> Test suites generated successfully")
145143

146-
bundled_program = create_bundled_program(executorch_program.program, bundled_config)
144+
bundled_program = create_bundled_program(
145+
executorch_program.program, method_test_suites
146+
)
147147
logging.info(" -> Bundled program generated successfully")
148148

149149
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(

backends/apple/mps/test/test_mps_utils.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
15-
from executorch.bundled_program.config import BundledConfig
15+
from executorch.bundled_program.config import MethodTestCase, MethodTestSuite
1616
from executorch.bundled_program.core import create_bundled_program
1717
from executorch.bundled_program.serialize import (
1818
serialize_from_bundled_program_to_flatbuffer,
@@ -189,28 +189,23 @@ def forward(self, *args):
189189
logging.info(
190190
" -> Number of execution plans: {len(executorch_program.program.execution_plan)}"
191191
)
192-
bundled_inputs = [
193-
[sample_inputs]
194-
for _ in range(len(executorch_program.program.execution_plan))
195-
]
196-
logging.info(" -> Bundled inputs generated successfully")
197192

198-
output = module(*sample_inputs)
199-
expected_outputs = [
200-
[[output]] for _ in range(len(executorch_program.program.execution_plan))
193+
method_test_suites = [
194+
MethodTestSuite(method_name="forward", test_cases=[
195+
MethodTestCase(input=sample_inputs, expected_outputs=module(*sample_inputs))
196+
])
201197
]
202-
logging.info(" -> Bundled outputs generated successfully")
203198

204-
method_names = ["forward"]
205-
bundled_config = BundledConfig(method_names, bundled_inputs, expected_outputs)
199+
logging.info(" -> Test suites generated successfully")
200+
206201
bundled_program = create_bundled_program(
207-
executorch_program.program, bundled_config
202+
executorch_program.program, method_test_suites
208203
)
209204
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
210205
bundled_program
211206
)
212207

213-
filename = f"{func_name}.pte"
208+
filename = f"{func_name}.bpte"
214209
logging.info(f"Step 5: Saving bundled program to {filename}...")
215210
with open(filename, "wb") as file:
216211
file.write(bundled_program_buffer)

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import unittest
88
from random import randint
9-
from typing import Any, Tuple
9+
from typing import Any, List, Tuple
1010

1111
import torch
1212
import torch.nn.functional as F
@@ -26,7 +26,7 @@
2626
# import the xnnpack backend implementation
2727
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
2828

29-
from executorch.bundled_program.config import BundledConfig
29+
from executorch.bundled_program.config import MethodTestCase, MethodTestSuite
3030
from executorch.bundled_program.core import create_bundled_program
3131
from executorch.bundled_program.serialize import (
3232
serialize_from_bundled_program_to_flatbuffer,
@@ -101,14 +101,22 @@ def save_bundled_program(representative_inputs, program, ref_output, output_path
101101
niter = 1
102102

103103
print("generating bundled program inputs / outputs")
104-
inputs_list = [list(representative_inputs) for _ in range(niter)]
105-
expected_outputs_list = [
106-
[[ref_output] for x in inputs_list],
104+
105+
method_test_cases: List[MethodTestCase] = []
106+
for _ in range(niter):
107+
method_test_cases.append(
108+
MethodTestCase(
109+
inputs=representative_inputs,
110+
expected_outputs=ref_output,
111+
)
112+
)
113+
114+
method_test_suites = [
115+
MethodTestSuite(method_name="forward", method_test_cases=method_test_cases)
107116
]
108-
bundled_config = BundledConfig([inputs_list], expected_outputs_list)
109117

110118
print("creating bundled program...")
111-
bundled_program = create_bundled_program(program, bundled_config)
119+
bundled_program = create_bundled_program(program, method_test_suites)
112120

113121
print("serializing bundled program...")
114122
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(

bundled_program/config.py

Lines changed: 34 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Any, get_args, List, Union
10+
from typing import Any, get_args, List, Optional, Sequence, Union
1111

1212
import torch
1313
from torch.utils._pytree import tree_flatten
@@ -16,7 +16,7 @@
1616

1717
"""
1818
The data types currently supported for element to be bundled. It should be
19-
consistent with the types in bundled_program.schema.BundledValue.
19+
consistent with the types in bundled_program.schema.Value.
2020
"""
2121
ConfigValue: TypeAlias = Union[
2222
torch.Tensor,
@@ -28,15 +28,15 @@
2828
"""
2929
The data type of the input for method single execution.
3030
"""
31-
MethodInputType: TypeAlias = List[ConfigValue]
31+
MethodInputType: TypeAlias = Sequence[ConfigValue]
3232

3333
"""
3434
The data type of the output for method single execution.
3535
"""
36-
MethodOutputType: TypeAlias = List[torch.Tensor]
36+
MethodOutputType: TypeAlias = Sequence[torch.Tensor]
3737

3838
"""
39-
All supported types for input/expected output of test set.
39+
All supported types for input/expected output of MethodTestCase.
4040
4141
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
4242
"""
@@ -45,79 +45,40 @@
4545
DataContainer: TypeAlias = Union[list, tuple, dict]
4646

4747

48-
@dataclass
49-
class ConfigIOSet:
50-
"""Type of data BundledConfig stored for each validation set."""
51-
52-
inputs: List[ConfigValue]
53-
expected_outputs: List[ConfigValue]
54-
55-
56-
@dataclass
57-
class ConfigExecutionPlanTest:
58-
"""All info related to verify execution plan"""
59-
60-
method_name: str
61-
test_sets: List[ConfigIOSet]
62-
63-
64-
class BundledConfig:
65-
"""All information needed to verify a model.
66-
67-
Public Attributes:
68-
execution_plan_tests: inputs, expected outputs, and other info for each execution plan verification.
69-
attachments: Other info need to be attached.
70-
"""
48+
class MethodTestCase:
49+
"""Test case with inputs and expected outputs
50+
The expected_outputs are optional and only required if the user wants to verify model outputs after execution."""
7151

7252
def __init__(
7353
self,
74-
method_names: List[str],
75-
inputs: List[List[MethodInputType]],
76-
expected_outputs: List[List[MethodOutputType]],
54+
inputs: MethodInputType,
55+
expected_outputs: Optional[MethodOutputType] = None,
7756
) -> None:
78-
"""Contruct the config given inputs and expected outputs
57+
"""Single test case for verifying specific method
7958
8059
Args:
81-
method_names: All method names need to be verified in program.
82-
inputs: All sets of input need to be test on for all methods. Each list
83-
of `inputs` is all sets which will be run on the method in the
84-
program with corresponding method name. Each set of any `inputs` element should
85-
contain all inputs required by eager_model with the same inference function
86-
as corresponding execution plan for one-time execution.
60+
input: All inputs required by eager_model with specific inference method for one-time execution.
8761
8862
It is worth mentioning that, although both bundled program and ET runtime apis support setting input
8963
other than torch.tensor type, only the input in torch.tensor type will be actually updated in
9064
the method, and the rest of the inputs will just do a sanity check if they match the default value in method.
9165
92-
expected_outputs: Expected outputs for inputs sharing same index. The size of
93-
expected_outputs should be the same as the size of inputs and provided method_names.
66+
expected_output: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.
9467
9568
Returns:
9669
self
9770
"""
98-
BundledConfig._check_io_type(inputs)
99-
BundledConfig._check_io_type(expected_outputs)
100-
101-
for m_name in method_names:
102-
assert isinstance(m_name, str)
103-
104-
assert len(method_names) == len(inputs) == len(expected_outputs), (
105-
"length of method_names, inputs and expected_outputs should match,"
106-
+ " but got {}, {} and {}".format(
107-
len(method_names), len(inputs), len(expected_outputs)
108-
)
109-
)
110-
111-
self.execution_plan_tests: List[
112-
ConfigExecutionPlanTest
113-
] = BundledConfig._gen_execution_plan_tests(
114-
method_names, inputs, expected_outputs
115-
)
116-
117-
@staticmethod
118-
# TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
119-
# pyre-ignore
120-
def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
71+
# TODO(gasoonjia): Update type check logic.
72+
# pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
73+
self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
74+
self.expected_outputs: List[ConfigValue] = []
75+
if expected_outputs is not None:
76+
# pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
77+
self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)
78+
79+
def _flatten_and_sanity_check(
80+
self, unflatten_data: DataContainer
81+
) -> List[ConfigValue]:
12182
"""Flat the given data and check its legality
12283
12384
Args:
@@ -126,6 +87,7 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
12687
Returns:
12788
flatten_data: Flatten data with legal type.
12889
"""
90+
12991
flatten_data, _ = tree_flatten(unflatten_data)
13092

13193
for data in flatten_data:
@@ -142,68 +104,15 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
142104

143105
return flatten_data
144106

145-
@staticmethod
146-
# pyre-ignore
147-
def _check_io_type(test_data_program: List[List[Any]]) -> None:
148-
"""Check the type of each set of inputs or exepcted_outputs
149-
150-
Each test set of inputs or expected_outputs will be put into the config
151-
should be one of the sub-type in DataContainer.
152-
153-
Args:
154-
test_data_program: inputs or expected_outputs to be put into the config
155-
to verify the whole program.
156107

157-
"""
158-
for test_data_execution_plan in test_data_program:
159-
for test_set in test_data_execution_plan:
160-
assert isinstance(test_set, get_args(DataContainer))
161-
162-
@staticmethod
163-
def _gen_execution_plan_tests(
164-
method_names: List[str],
165-
inputs: List[List[MethodInputType]],
166-
expected_outputs: List[List[MethodOutputType]],
167-
) -> List[ConfigExecutionPlanTest]:
168-
"""Generate execution plan test given inputs, expected outputs for verifying each execution plan"""
169-
170-
execution_plan_tests: List[ConfigExecutionPlanTest] = []
171-
172-
for (
173-
m_name,
174-
inputs_per_plan_test,
175-
expect_outputs_per_plan_test,
176-
) in zip(method_names, inputs, expected_outputs):
177-
test_sets: List[ConfigIOSet] = []
178-
179-
# transfer I/O sets into ConfigIOSet for each execution plan
180-
assert len(inputs_per_plan_test) == len(expect_outputs_per_plan_test), (
181-
"The number of input and expected output for identical execution plan should be the same,"
182-
+ " but got {} and {}".format(
183-
len(inputs_per_plan_test), len(expect_outputs_per_plan_test)
184-
)
185-
)
186-
for unflatten_input, unflatten_expected_output in zip(
187-
inputs_per_plan_test, expect_outputs_per_plan_test
188-
):
189-
flatten_inputs = BundledConfig._tree_flatten(unflatten_input)
190-
flatten_expected_output = BundledConfig._tree_flatten(
191-
unflatten_expected_output
192-
)
193-
test_sets.append(
194-
ConfigIOSet(
195-
inputs=flatten_inputs, expected_outputs=flatten_expected_output
196-
)
197-
)
198-
199-
execution_plan_tests.append(
200-
ConfigExecutionPlanTest(
201-
method_name=m_name,
202-
test_sets=test_sets,
203-
)
204-
)
108+
@dataclass
109+
class MethodTestSuite:
110+
"""All test info related to verify method
205111
206-
# sort the execution plan tests by method name to in line with core program emitter.
207-
execution_plan_tests.sort(key=lambda x: x.method_name)
112+
Attributes:
113+
method_name: Name of the method to be verified.
114+
test_cases: All test cases for verifying the method.
115+
"""
208116

209-
return execution_plan_tests
117+
method_name: str
118+
test_cases: Sequence[MethodTestCase]

0 commit comments

Comments
 (0)