7
7
# pyre-strict
8
8
9
9
from dataclasses import dataclass
10
- from typing import Any , get_args , List , Union
10
+ from typing import Any , get_args , List , Optional , Sequence , Union
11
11
12
12
import torch
13
13
from torch .utils ._pytree import tree_flatten
16
16
17
17
"""
18
18
The data types currently supported for element to be bundled. It should be
19
- consistent with the types in bundled_program.schema.BundledValue .
19
+ consistent with the types in bundled_program.schema.Value .
20
20
"""
21
21
ConfigValue : TypeAlias = Union [
22
22
torch .Tensor ,
28
28
"""
29
29
The data type of the input for method single execution.
30
30
"""
31
- MethodInputType : TypeAlias = List [ConfigValue ]
31
+ MethodInputType : TypeAlias = Sequence [ConfigValue ]
32
32
33
33
"""
34
34
The data type of the output for method single execution.
35
35
"""
36
- MethodOutputType : TypeAlias = List [torch .Tensor ]
36
+ MethodOutputType : TypeAlias = Sequence [torch .Tensor ]
37
37
38
38
"""
39
- All supported types for input/expected output of test set .
39
+ All supported types for input/expected output of MethodTestCase .
40
40
41
41
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
42
42
"""
45
45
DataContainer : TypeAlias = Union [list , tuple , dict ]
46
46
47
47
48
- @dataclass
49
- class ConfigIOSet :
50
- """Type of data BundledConfig stored for each validation set."""
51
-
52
- inputs : List [ConfigValue ]
53
- expected_outputs : List [ConfigValue ]
54
-
55
-
56
- @dataclass
57
- class ConfigExecutionPlanTest :
58
- """All info related to verify execution plan"""
59
-
60
- method_name : str
61
- test_sets : List [ConfigIOSet ]
62
-
63
-
64
- class BundledConfig :
65
- """All information needed to verify a model.
66
-
67
- Public Attributes:
68
- execution_plan_tests: inputs, expected outputs, and other info for each execution plan verification.
69
- attachments: Other info need to be attached.
70
- """
48
+ class MethodTestCase :
49
+ """Test case with inputs and expected outputs
50
+ The expected_outputs are optional and only required if the user wants to verify model outputs after execution."""
71
51
72
52
def __init__ (
73
53
self ,
74
- method_names : List [str ],
75
- inputs : List [List [MethodInputType ]],
76
- expected_outputs : List [List [MethodOutputType ]],
54
+ inputs : MethodInputType ,
55
+ expected_outputs : Optional [MethodOutputType ] = None ,
77
56
) -> None :
78
- """Contruct the config given inputs and expected outputs
57
+ """Single test case for verifying specific method
79
58
80
59
Args:
81
- method_names: All method names need to be verified in program.
82
- inputs: All sets of input need to be test on for all methods. Each list
83
- of `inputs` is all sets which will be run on the method in the
84
- program with corresponding method name. Each set of any `inputs` element should
85
- contain all inputs required by eager_model with the same inference function
86
- as corresponding execution plan for one-time execution.
60
+ input: All inputs required by eager_model with specific inference method for one-time execution.
87
61
88
62
It is worth mentioning that, although both bundled program and ET runtime apis support setting input
89
63
other than torch.tensor type, only the input in torch.tensor type will be actually updated in
90
64
the method, and the rest of the inputs will just do a sanity check if they match the default value in method.
91
65
92
- expected_outputs: Expected outputs for inputs sharing same index. The size of
93
- expected_outputs should be the same as the size of inputs and provided method_names.
66
+ expected_output: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.
94
67
95
68
Returns:
96
69
self
97
70
"""
98
- BundledConfig ._check_io_type (inputs )
99
- BundledConfig ._check_io_type (expected_outputs )
100
-
101
- for m_name in method_names :
102
- assert isinstance (m_name , str )
103
-
104
- assert len (method_names ) == len (inputs ) == len (expected_outputs ), (
105
- "length of method_names, inputs and expected_outputs should match,"
106
- + " but got {}, {} and {}" .format (
107
- len (method_names ), len (inputs ), len (expected_outputs )
108
- )
109
- )
110
-
111
- self .execution_plan_tests : List [
112
- ConfigExecutionPlanTest
113
- ] = BundledConfig ._gen_execution_plan_tests (
114
- method_names , inputs , expected_outputs
115
- )
116
-
117
- @staticmethod
118
- # TODO(T138930448): Give pyre-ignore commands appropriate warning type and comments.
119
- # pyre-ignore
120
- def _tree_flatten (unflatten_data : Any ) -> List [ConfigValue ]:
71
+ # TODO(gasoonjia): Update type check logic.
72
+ # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
73
+ self .inputs : List [ConfigValue ] = self ._flatten_and_sanity_check (inputs )
74
+ self .expected_outputs : List [ConfigValue ] = []
75
+ if expected_outputs is not None :
76
+ # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
77
+ self .expected_outputs = self ._flatten_and_sanity_check (expected_outputs )
78
+
79
+ def _flatten_and_sanity_check (
80
+ self , unflatten_data : DataContainer
81
+ ) -> List [ConfigValue ]:
121
82
"""Flat the given data and check its legality
122
83
123
84
Args:
@@ -126,6 +87,7 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
126
87
Returns:
127
88
flatten_data: Flatten data with legal type.
128
89
"""
90
+
129
91
flatten_data , _ = tree_flatten (unflatten_data )
130
92
131
93
for data in flatten_data :
@@ -142,68 +104,15 @@ def _tree_flatten(unflatten_data: Any) -> List[ConfigValue]:
142
104
143
105
return flatten_data
144
106
145
- @staticmethod
146
- # pyre-ignore
147
- def _check_io_type (test_data_program : List [List [Any ]]) -> None :
148
- """Check the type of each set of inputs or exepcted_outputs
149
-
150
- Each test set of inputs or expected_outputs will be put into the config
151
- should be one of the sub-type in DataContainer.
152
-
153
- Args:
154
- test_data_program: inputs or expected_outputs to be put into the config
155
- to verify the whole program.
156
107
157
- """
158
- for test_data_execution_plan in test_data_program :
159
- for test_set in test_data_execution_plan :
160
- assert isinstance (test_set , get_args (DataContainer ))
161
-
162
- @staticmethod
163
- def _gen_execution_plan_tests (
164
- method_names : List [str ],
165
- inputs : List [List [MethodInputType ]],
166
- expected_outputs : List [List [MethodOutputType ]],
167
- ) -> List [ConfigExecutionPlanTest ]:
168
- """Generate execution plan test given inputs, expected outputs for verifying each execution plan"""
169
-
170
- execution_plan_tests : List [ConfigExecutionPlanTest ] = []
171
-
172
- for (
173
- m_name ,
174
- inputs_per_plan_test ,
175
- expect_outputs_per_plan_test ,
176
- ) in zip (method_names , inputs , expected_outputs ):
177
- test_sets : List [ConfigIOSet ] = []
178
-
179
- # transfer I/O sets into ConfigIOSet for each execution plan
180
- assert len (inputs_per_plan_test ) == len (expect_outputs_per_plan_test ), (
181
- "The number of input and expected output for identical execution plan should be the same,"
182
- + " but got {} and {}" .format (
183
- len (inputs_per_plan_test ), len (expect_outputs_per_plan_test )
184
- )
185
- )
186
- for unflatten_input , unflatten_expected_output in zip (
187
- inputs_per_plan_test , expect_outputs_per_plan_test
188
- ):
189
- flatten_inputs = BundledConfig ._tree_flatten (unflatten_input )
190
- flatten_expected_output = BundledConfig ._tree_flatten (
191
- unflatten_expected_output
192
- )
193
- test_sets .append (
194
- ConfigIOSet (
195
- inputs = flatten_inputs , expected_outputs = flatten_expected_output
196
- )
197
- )
198
-
199
- execution_plan_tests .append (
200
- ConfigExecutionPlanTest (
201
- method_name = m_name ,
202
- test_sets = test_sets ,
203
- )
204
- )
108
+ @dataclass
109
+ class MethodTestSuite :
110
+ """All test info related to verify method
205
111
206
- # sort the execution plan tests by method name to in line with core program emitter.
207
- execution_plan_tests .sort (key = lambda x : x .method_name )
112
+ Attributes:
113
+ method_name: Name of the method to be verified.
114
+ test_cases: All test cases for verifying the method.
115
+ """
208
116
209
- return execution_plan_tests
117
+ method_name : str
118
+ test_cases : Sequence [MethodTestCase ]
0 commit comments