7
7
# pyre-strict
8
8
9
9
from dataclasses import dataclass
10
- from typing import Any , get_args , List , Union
10
+ from typing import get_args , List , Optional , Sequence , Union
11
11
12
12
import torch
13
13
from executorch .extension .pytree import tree_flatten
16
16
17
17
"""
18
18
The data types currently supported for element to be bundled. It should be
19
- consistent with the types in bundled_program.schema.BundledValue .
19
+ consistent with the types in bundled_program.schema.Value .
20
20
"""
21
21
ConfigValue : TypeAlias = Union [
22
22
torch .Tensor ,
26
26
]
27
27
28
28
"""
29
- All supported types for input/expected output of test set .
29
+ All supported types for input/expected output of MethodTestCase .
30
30
31
31
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
32
32
"""
35
35
DataContainer : TypeAlias = Union [list , tuple , dict ]
36
36
37
37
38
- @dataclass
39
- class ConfigIOSet :
40
- """Type of data BundledConfig stored for each validation set."""
41
-
42
- inputs : List [ConfigValue ]
43
- expected_outputs : List [ConfigValue ]
44
-
45
-
46
- @dataclass
47
- class ConfigExecutionPlanTest :
48
- """All info related to verify execution plan"""
49
-
50
- method_name : str
51
- test_sets : List [ConfigIOSet ]
52
-
53
-
54
- class BundledConfig :
55
- """All information needed to verify a model.
56
-
57
- Public Attributes:
58
- execution_plan_tests: inputs, expected outputs, and other info for each execution plan verification.
59
- attachments: Other info need to be attached.
60
- """
38
+ class MethodTestCase :
39
+ """Test case with inputs and expected outputs
40
+ The expected_outputs could be None if user only want to user the test case for profiling."""
61
41
62
42
def __init__ (
63
- self ,
64
- method_names : List [str ],
65
- # pyre-ignore
66
- inputs : List [List [Any ]],
67
- # pyre-ignore
68
- expected_outputs : List [List [Any ]],
43
+ self , inputs : DataContainer , expected_outputs : Optional [DataContainer ] = None
69
44
) -> None :
70
- """Contruct the config given inputs and expected outputs
71
-
72
- Args:
73
- method_names: All method names need to be verified in program.
74
- inputs: All sets of input need to be test on for all methods. Each list
75
- of `inputs` is all sets which will be run on the method in the
76
- program with corresponding method name. Each set of any `inputs` element should
77
- contain all inputs required by eager_model with the same inference function
78
- as corresponding execution plan for one-time execution.
79
-
80
- expected_outputs: Expected outputs for inputs sharing same index. The size of
81
- expected_outputs should be the same as the size of inputs and provided method_names.
82
- """
83
- BundledConfig ._check_io_type (inputs )
84
- BundledConfig ._check_io_type (expected_outputs )
85
-
86
- for m_name in method_names :
87
- assert isinstance (m_name , str )
88
-
89
- assert len (method_names ) == len (inputs ) == len (expected_outputs ), (
90
- "length of method_names, inputs and expected_outputs should match,"
91
- + " but got {}, {} and {}" .format (
92
- len (method_names ), len (inputs ), len (expected_outputs )
93
- )
94
- )
95
-
96
- self .execution_plan_tests : List [
97
- ConfigExecutionPlanTest
98
- ] = BundledConfig ._gen_execution_plan_tests (
99
- method_names , inputs , expected_outputs
100
- )
101
-
102
- @staticmethod
103
- # TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
104
- # pyre-ignore
105
- def _tree_flatten (unflatten_data : Any ) -> List [ConfigValue ]:
45
+ self .inputs : List [ConfigValue ] = self ._flatten_and_sanity_check (inputs )
46
+ self .expected_outputs : List [ConfigValue ] = []
47
+ if expected_outputs :
48
+ self .expected_outputs = self ._flatten_and_sanity_check (expected_outputs )
49
+
50
+ def _flatten_and_sanity_check (
51
+ self , unflatten_data : DataContainer
52
+ ) -> List [ConfigValue ]:
106
53
"""Flat the given data and check its legality
107
54
108
55
Args:
@@ -111,6 +58,11 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
111
58
Returns:
112
59
flatten_data: Flatten data with legal type.
113
60
"""
61
+
62
+ assert isinstance (
63
+ unflatten_data , get_args (DataContainer )
64
+ ), f"The input or expected output of MethodTestCase should be in list, tuple or dict, but got { type (unflatten_data )} ."
65
+
114
66
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
115
67
flatten_data , _ = tree_flatten (unflatten_data )
116
68
@@ -128,70 +80,10 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
128
80
129
81
return flatten_data
130
82
131
- @staticmethod
132
- # pyre-ignore
133
- def _check_io_type (test_data_program : List [List [Any ]]) -> None :
134
- """Check the type of each set of inputs or exepcted_outputs
135
-
136
- Each test set of inputs or expected_outputs will be put into the config
137
- should be one of the sub-type in DataContainer.
138
-
139
- Args:
140
- test_data_program: inputs or expected_outputs to be put into the config
141
- to verify the whole program.
142
83
143
- """
144
- for test_data_execution_plan in test_data_program :
145
- for test_set in test_data_execution_plan :
146
- assert isinstance (test_set , get_args (DataContainer ))
147
-
148
- @staticmethod
149
- def _gen_execution_plan_tests (
150
- method_names : List [str ],
151
- # pyre-ignore
152
- inputs : List [List [Any ]],
153
- # pyre-ignore
154
- expected_outputs : List [List [Any ]],
155
- ) -> List [ConfigExecutionPlanTest ]:
156
- """Generate execution plan test given inputs, expected outputs for verifying each execution plan"""
157
-
158
- execution_plan_tests : List [ConfigExecutionPlanTest ] = []
159
-
160
- for (
161
- m_name ,
162
- inputs_per_plan_test ,
163
- expect_outputs_per_plan_test ,
164
- ) in zip (method_names , inputs , expected_outputs ):
165
- test_sets : List [ConfigIOSet ] = []
166
-
167
- # transfer I/O sets into ConfigIOSet for each execution plan
168
- assert len (inputs_per_plan_test ) == len (expect_outputs_per_plan_test ), (
169
- "The number of input and expected output for identical execution plan should be the same,"
170
- + " but got {} and {}" .format (
171
- len (inputs_per_plan_test ), len (expect_outputs_per_plan_test )
172
- )
173
- )
174
- for unflatten_input , unflatten_expected_output in zip (
175
- inputs_per_plan_test , expect_outputs_per_plan_test
176
- ):
177
- flatten_inputs = BundledConfig ._tree_flatten (unflatten_input )
178
- flatten_expected_output = BundledConfig ._tree_flatten (
179
- unflatten_expected_output
180
- )
181
- test_sets .append (
182
- ConfigIOSet (
183
- inputs = flatten_inputs , expected_outputs = flatten_expected_output
184
- )
185
- )
186
-
187
- execution_plan_tests .append (
188
- ConfigExecutionPlanTest (
189
- method_name = m_name ,
190
- test_sets = test_sets ,
191
- )
192
- )
193
-
194
- # sort the execution plan tests by method name to in line with core program emitter.
195
- execution_plan_tests .sort (key = lambda x : x .method_name )
84
+ @dataclass
85
+ class MethodTestSuite :
86
+ """All info related to verify method"""
196
87
197
- return execution_plan_tests
88
+ method_name : str
89
+ test_cases : Sequence [MethodTestCase ]
0 commit comments