@@ -136,32 +136,56 @@ def assert_valid_bundle(
136
136
137
137
"""
138
138
139
- # Check the number of execution plan tests
140
- assert len (bundled_config .execution_plan_tests ) == len (
141
- program .execution_plan
142
- ), "The length of execution_plan_tests in config should match the length of execution_plan in program, but get {} and {}." .format (
143
- len (bundled_config .execution_plan_tests ), len (program .execution_plan )
144
- )
139
+ program_plan_id = 0
140
+ bp_plan_id = 0
141
+
142
+ method_name_of_program = {e .name for e in program .execution_plan }
143
+ method_name_of_test_suites = {
144
+ t .method_name for t in bundled_config .execution_plan_tests
145
+ }
146
+
147
+ assert method_name_of_test_suites .issubset (
148
+ method_name_of_program
149
+ ), f"All methods in method_test_suites should be found in program.execution_plan, \
150
+ but { str (method_name_of_test_suites - method_name_of_program )} does not include."
151
+
152
+ # check if method_tesdt_suites has been sorted in ascending alphabetical order of method name.
153
+ for bp_plan_id in range (1 , len (bundled_config .execution_plan_tests )):
154
+ assert (
155
+ bundled_config .execution_plan_tests [bp_plan_id - 1 ].method_name
156
+ <= bundled_config .execution_plan_tests [bp_plan_id ].method_name
157
+ ), f"The method name of test suite should be sorted in ascending alphabetical \
158
+ order of method name, but { bp_plan_id - 1 } -th and { bp_plan_id } -th method_test_suite aren't."
145
159
146
160
# Check if the inputs' type meet Program's requirement
147
- for plan_id in range (len (program .execution_plan )):
161
+ while bp_plan_id < len (bundled_config .execution_plan_tests ):
162
+
148
163
plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
149
- plan_id
164
+ bp_plan_id
150
165
]
166
+ plan : ExecutionPlan = program .execution_plan [program_plan_id ]
151
167
152
- plan : ExecutionPlan = program .execution_plan [plan_id ]
168
+ # User does not provide testcases for current plan, skip it
169
+ if plan_test .method_name > plan .name :
170
+ program_plan_id += 1
171
+ continue
172
+
173
+ # Check if the method name in user provided test matches the one in the original program
174
+ assert (
175
+ plan_test .method_name == plan .name
176
+ ), f"BundledConfig has testcases for method { plan_test .method_name } , but can not find it in the given program. All method names in the program are { ', ' .join ([p .name for p in program .execution_plan ])} ."
153
177
154
178
# Check if the type of Program's input is supported
155
179
for index in range (len (plan .inputs )):
156
180
assert (
157
- type (get_program_input (program , plan_id , index ))
181
+ type (get_program_input (program , program_plan_id , index ))
158
182
in supported_program_type_table
159
183
), "The type of program's input isn't supported."
160
184
161
185
# Check if the type of Program's output is supported
162
186
for index in range (len (plan .outputs )):
163
187
assert (
164
- type (get_program_output (program , plan_id , index )) == Tensor
188
+ type (get_program_output (program , program_plan_id , index )) == Tensor
165
189
), "Only supports program with output in Tensor type."
166
190
167
191
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -181,14 +205,14 @@ def assert_valid_bundle(
181
205
assert (
182
206
type (cur_plan_test_inputs [j ])
183
207
== supported_program_type_table [
184
- type (get_program_input (program , plan_id , j ))
208
+ type (get_program_input (program , program_plan_id , j ))
185
209
]
186
210
), "The type {}-th input in {}-th test set of {}-th execution plan does not meet Program's requirement: expected {} but get {}" .format (
187
211
j ,
188
212
i ,
189
- plan_id ,
213
+ program_plan_id ,
190
214
supported_program_type_table [
191
- type (get_program_input (program , plan_id , j ))
215
+ type (get_program_input (program , program_plan_id , j ))
192
216
],
193
217
type (cur_plan_test_inputs [j ]),
194
218
)
@@ -198,10 +222,10 @@ def assert_valid_bundle(
198
222
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
199
223
# has no attribute `dtype`.
200
224
assert cur_plan_test_inputs [j ].dtype == get_input_dtype (
201
- program , plan_id , j
225
+ program , program_plan_id , j
202
226
), "The input tensor {} dtype shall be {}, but now is {}" .format (
203
227
cur_plan_test_inputs [j ],
204
- get_input_dtype (program , plan_id , j ),
228
+ get_input_dtype (program , program_plan_id , j ),
205
229
cur_plan_test_inputs [j ].dtype ,
206
230
)
207
231
elif type (cur_plan_test_inputs [j ]) in (
@@ -210,9 +234,9 @@ def assert_valid_bundle(
210
234
float ,
211
235
):
212
236
assert type (cur_plan_test_inputs [j ]) == get_input_type (
213
- program , plan_id , j
237
+ program , program_plan_id , j
214
238
), "The input primitive dtype shall be {}, but now is {}" .format (
215
- get_input_type (program , plan_id , j ),
239
+ get_input_type (program , program_plan_id , j ),
216
240
type (cur_plan_test_inputs [j ]),
217
241
)
218
242
@@ -221,13 +245,16 @@ def assert_valid_bundle(
221
245
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
222
246
# has no attribute `dtype`.
223
247
assert cur_plan_test_expected_outputs [j ].dtype == get_output_dtype (
224
- program , plan_id , j
248
+ program , program_plan_id , j
225
249
), "The label tensor {} dtype shall be {}, but now is {}" .format (
226
250
cur_plan_test_expected_outputs [j ],
227
- get_output_dtype (program , plan_id , j ),
251
+ get_output_dtype (program , program_plan_id , j ),
228
252
cur_plan_test_expected_outputs [j ].dtype ,
229
253
)
230
254
255
+ program_plan_id += 1
256
+ bp_plan_id += 1
257
+
231
258
232
259
def create_bundled_program (
233
260
program : Program ,
@@ -245,10 +272,7 @@ def create_bundled_program(
245
272
execution_plan_tests : List [BundledExecutionPlanTest ] = []
246
273
247
274
# Emit data and metadata of bundled tensor
248
- for plan_id in range (len (program .execution_plan )):
249
- plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
250
- plan_id
251
- ]
275
+ for plan_test in bundled_config .execution_plan_tests :
252
276
test_sets : List [BundledIOSet ] = []
253
277
254
278
# emit I/O sets for each execution plan test
@@ -283,7 +307,11 @@ def create_bundled_program(
283
307
)
284
308
285
309
# emit the whole execution plan test
286
- execution_plan_tests .append (BundledExecutionPlanTest (test_sets = test_sets ))
310
+ execution_plan_tests .append (
311
+ BundledExecutionPlanTest (
312
+ method_name = plan_test .method_name , test_sets = test_sets
313
+ )
314
+ )
287
315
288
316
program_bytes : bytes = _serialize_pte_binary (program )
289
317
0 commit comments