Skip to content

Commit 4aeddf1

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
structural api and schema update (#410)
Summary: Pull Request resolved: #410 Update the AOT API for bundled program to make it more user friendly. We created two new classes for holding testcases in a more structural way: MethodTestCase: contains inputs and expected outputs for a single test for method. MethodTestSuite: Collection of all test cases for a program method. Detailed design can be found here: https://docs.google.com/document/d/170WJ81dPuNQsvsjqH_wQjafbu2kxtEG1sgX-z6OqCEI/edit#heading=h.69fgezruwl6n Differential Revision: D49406613 fbshipit-source-id: 3741c769ea84105b1e64c0f1e6503b6903b91057
1 parent 9dfbd56 commit 4aeddf1

File tree

14 files changed

+290
-357
lines changed

14 files changed

+290
-357
lines changed

backends/xnnpack/test/TARGETS

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ python_unittest(
127127
]),
128128
deps = [
129129
"//caffe2:torch",
130-
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
131130
"//executorch/backends/xnnpack/test/tester:tester",
132-
"//executorch/exir:lib",
133131
"//pytorch/vision:torchvision",
134132
],
135133
)

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import unittest
88
from random import randint
9-
from typing import Any, Tuple
9+
from typing import Any, List, Tuple
1010

1111
import torch
1212
import torch.nn.functional as F
@@ -26,7 +26,7 @@
2626
# import the xnnpack backend implementation
2727
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
2828

29-
from executorch.bundled_program.config import BundledConfig
29+
from executorch.bundled_program.config import MethodTestCase, MethodTestSuite
3030
from executorch.bundled_program.core import create_bundled_program
3131
from executorch.bundled_program.serialize import (
3232
serialize_from_bundled_program_to_flatbuffer,
@@ -101,14 +101,22 @@ def save_bundled_program(representative_inputs, program, ref_output, output_path
101101
niter = 1
102102

103103
print("generating bundled program inputs / outputs")
104-
inputs_list = [list(representative_inputs) for _ in range(niter)]
105-
expected_outputs_list = [
106-
[[ref_output] for x in inputs_list],
104+
105+
method_test_cases: List[MethodTestCase] = []
106+
for _ in range(niter):
107+
method_test_cases.append(
108+
MethodTestCase(
109+
inputs=list(representative_inputs),
110+
expected_outputs=[ref_output],
111+
)
112+
)
113+
114+
method_test_suites = [
115+
MethodTestSuite(method_name="forward", method_test_cases=method_test_cases)
107116
]
108-
bundled_config = BundledConfig([inputs_list], expected_outputs_list)
109117

110118
print("creating bundled program...")
111-
bundled_program = create_bundled_program(program, bundled_config)
119+
bundled_program = create_bundled_program(program, method_test_suites)
112120

113121
print("serializing bundled program...")
114122
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(

bundled_program/config.py

Lines changed: 25 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Any, get_args, List, Union
10+
from typing import get_args, List, Optional, Sequence, Union
1111

1212
import torch
1313
from executorch.extension.pytree import tree_flatten
@@ -16,7 +16,7 @@
1616

1717
"""
1818
The data types currently supported for element to be bundled. It should be
19-
consistent with the types in bundled_program.schema.BundledValue.
19+
consistent with the types in bundled_program.schema.Value.
2020
"""
2121
ConfigValue: TypeAlias = Union[
2222
torch.Tensor,
@@ -26,7 +26,7 @@
2626
]
2727

2828
"""
29-
All supported types for input/expected output of test set.
29+
All supported types for input/expected output of MethodTestCase.
3030
3131
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
3232
"""
@@ -35,74 +35,21 @@
3535
DataContainer: TypeAlias = Union[list, tuple, dict]
3636

3737

38-
@dataclass
39-
class ConfigIOSet:
40-
"""Type of data BundledConfig stored for each validation set."""
41-
42-
inputs: List[ConfigValue]
43-
expected_outputs: List[ConfigValue]
44-
45-
46-
@dataclass
47-
class ConfigExecutionPlanTest:
48-
"""All info related to verify execution plan"""
49-
50-
method_name: str
51-
test_sets: List[ConfigIOSet]
52-
53-
54-
class BundledConfig:
55-
"""All information needed to verify a model.
56-
57-
Public Attributes:
58-
execution_plan_tests: inputs, expected outputs, and other info for each execution plan verification.
59-
attachments: Other info need to be attached.
60-
"""
38+
class MethodTestCase:
39+
"""Test case with inputs and expected outputs
40+
The expected_outputs could be None if user only want to user the test case for profiling."""
6141

6242
def __init__(
63-
self,
64-
method_names: List[str],
65-
# pyre-ignore
66-
inputs: List[List[Any]],
67-
# pyre-ignore
68-
expected_outputs: List[List[Any]],
43+
self, inputs: DataContainer, expected_outputs: Optional[DataContainer] = None
6944
) -> None:
70-
"""Contruct the config given inputs and expected outputs
71-
72-
Args:
73-
method_names: All method names need to be verified in program.
74-
inputs: All sets of input need to be test on for all methods. Each list
75-
of `inputs` is all sets which will be run on the method in the
76-
program with corresponding method name. Each set of any `inputs` element should
77-
contain all inputs required by eager_model with the same inference function
78-
as corresponding execution plan for one-time execution.
79-
80-
expected_outputs: Expected outputs for inputs sharing same index. The size of
81-
expected_outputs should be the same as the size of inputs and provided method_names.
82-
"""
83-
BundledConfig._check_io_type(inputs)
84-
BundledConfig._check_io_type(expected_outputs)
85-
86-
for m_name in method_names:
87-
assert isinstance(m_name, str)
88-
89-
assert len(method_names) == len(inputs) == len(expected_outputs), (
90-
"length of method_names, inputs and expected_outputs should match,"
91-
+ " but got {}, {} and {}".format(
92-
len(method_names), len(inputs), len(expected_outputs)
93-
)
94-
)
95-
96-
self.execution_plan_tests: List[
97-
ConfigExecutionPlanTest
98-
] = BundledConfig._gen_execution_plan_tests(
99-
method_names, inputs, expected_outputs
100-
)
101-
102-
@staticmethod
103-
# TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
104-
# pyre-ignore
105-
def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
45+
self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
46+
self.expected_outputs: List[ConfigValue] = []
47+
if expected_outputs:
48+
self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)
49+
50+
def _flatten_and_sanity_check(
51+
self, unflatten_data: DataContainer
52+
) -> List[ConfigValue]:
10653
"""Flat the given data and check its legality
10754
10855
Args:
@@ -111,6 +58,11 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
11158
Returns:
11259
flatten_data: Flatten data with legal type.
11360
"""
61+
62+
assert isinstance(
63+
unflatten_data, get_args(DataContainer)
64+
), f"The input or expected output of MethodTestCase should be in list, tuple or dict, but got {type(unflatten_data)}."
65+
11466
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
11567
flatten_data, _ = tree_flatten(unflatten_data)
11668

@@ -128,70 +80,10 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
12880

12981
return flatten_data
13082

131-
@staticmethod
132-
# pyre-ignore
133-
def _check_io_type(test_data_program: List[List[Any]]) -> None:
134-
"""Check the type of each set of inputs or exepcted_outputs
135-
136-
Each test set of inputs or expected_outputs will be put into the config
137-
should be one of the sub-type in DataContainer.
138-
139-
Args:
140-
test_data_program: inputs or expected_outputs to be put into the config
141-
to verify the whole program.
14283

143-
"""
144-
for test_data_execution_plan in test_data_program:
145-
for test_set in test_data_execution_plan:
146-
assert isinstance(test_set, get_args(DataContainer))
147-
148-
@staticmethod
149-
def _gen_execution_plan_tests(
150-
method_names: List[str],
151-
# pyre-ignore
152-
inputs: List[List[Any]],
153-
# pyre-ignore
154-
expected_outputs: List[List[Any]],
155-
) -> List[ConfigExecutionPlanTest]:
156-
"""Generate execution plan test given inputs, expected outputs for verifying each execution plan"""
157-
158-
execution_plan_tests: List[ConfigExecutionPlanTest] = []
159-
160-
for (
161-
m_name,
162-
inputs_per_plan_test,
163-
expect_outputs_per_plan_test,
164-
) in zip(method_names, inputs, expected_outputs):
165-
test_sets: List[ConfigIOSet] = []
166-
167-
# transfer I/O sets into ConfigIOSet for each execution plan
168-
assert len(inputs_per_plan_test) == len(expect_outputs_per_plan_test), (
169-
"The number of input and expected output for identical execution plan should be the same,"
170-
+ " but got {} and {}".format(
171-
len(inputs_per_plan_test), len(expect_outputs_per_plan_test)
172-
)
173-
)
174-
for unflatten_input, unflatten_expected_output in zip(
175-
inputs_per_plan_test, expect_outputs_per_plan_test
176-
):
177-
flatten_inputs = BundledConfig._tree_flatten(unflatten_input)
178-
flatten_expected_output = BundledConfig._tree_flatten(
179-
unflatten_expected_output
180-
)
181-
test_sets.append(
182-
ConfigIOSet(
183-
inputs=flatten_inputs, expected_outputs=flatten_expected_output
184-
)
185-
)
186-
187-
execution_plan_tests.append(
188-
ConfigExecutionPlanTest(
189-
method_name=m_name,
190-
test_sets=test_sets,
191-
)
192-
)
193-
194-
# sort the execution plan tests by method name to in line with core program emitter.
195-
execution_plan_tests.sort(key=lambda x: x.method_name)
84+
@dataclass
85+
class MethodTestSuite:
86+
"""All info related to verify method"""
19687

197-
return execution_plan_tests
88+
method_name: str
89+
test_cases: Sequence[MethodTestCase]

0 commit comments

Comments
 (0)