@@ -136,32 +136,38 @@ def assert_valid_bundle(
136
136
137
137
"""
138
138
139
- # Check the number of execution plan tests
140
- assert len (bundled_config .execution_plan_tests ) == len (
141
- program .execution_plan
142
- ), "The length of execution_plan_tests in config should match the length of execution_plan in program, but get {} and {}." .format (
143
- len (bundled_config .execution_plan_tests ), len (program .execution_plan )
144
- )
139
+ program_plan_id = 0
140
+ bp_plan_id = 0
145
141
146
142
# Check if the inputs' type meet Program's requirement
147
- for plan_id in range (len (program .execution_plan )):
143
+ while bp_plan_id < len (bundled_config .execution_plan_tests ):
144
+
148
145
plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
149
- plan_id
146
+ bp_plan_id
150
147
]
148
+ plan : ExecutionPlan = program .execution_plan [program_plan_id ]
151
149
152
- plan : ExecutionPlan = program .execution_plan [plan_id ]
150
+ # User does not provide testcases for current plan, skip it
151
+ if plan_test .method_name < plan .name :
152
+ program_plan_id += 1
153
+ continue
154
+
155
+ # Check if the method name in user provided test matches the one in the original program
156
+ assert (
157
+ plan_test .method_name == plan .name
158
+ ), f"BundledConfig has testcases for method { plan_test .method_name } , but can not find it in the given program. All method names in the program are { ', ' .join ([p .name for p in program .execution_plan ])} ."
153
159
154
160
# Check if the type of Program's input is supported
155
161
for index in range (len (plan .inputs )):
156
162
assert (
157
- type (get_program_input (program , plan_id , index ))
163
+ type (get_program_input (program , program_plan_id , index ))
158
164
in supported_program_type_table
159
165
), "The type of program's input isn't supported."
160
166
161
167
# Check if the type of Program's output is supported
162
168
for index in range (len (plan .outputs )):
163
169
assert (
164
- type (get_program_output (program , plan_id , index )) == Tensor
170
+ type (get_program_output (program , program_plan_id , index )) == Tensor
165
171
), "Only supports program with output in Tensor type."
166
172
167
173
# Check if the I/O sets of each execution plan test match program's requirement.
@@ -181,14 +187,14 @@ def assert_valid_bundle(
181
187
assert (
182
188
type (cur_plan_test_inputs [j ])
183
189
== supported_program_type_table [
184
- type (get_program_input (program , plan_id , j ))
190
+ type (get_program_input (program , program_plan_id , j ))
185
191
]
186
192
), "The type {}-th input in {}-th test set of {}-th execution plan does not meet Program's requirement: expected {} but get {}" .format (
187
193
j ,
188
194
i ,
189
- plan_id ,
195
+ program_plan_id ,
190
196
supported_program_type_table [
191
- type (get_program_input (program , plan_id , j ))
197
+ type (get_program_input (program , program_plan_id , j ))
192
198
],
193
199
type (cur_plan_test_inputs [j ]),
194
200
)
@@ -198,10 +204,10 @@ def assert_valid_bundle(
198
204
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
199
205
# has no attribute `dtype`.
200
206
assert cur_plan_test_inputs [j ].dtype == get_input_dtype (
201
- program , plan_id , j
207
+ program , program_plan_id , j
202
208
), "The input tensor {} dtype shall be {}, but now is {}" .format (
203
209
cur_plan_test_inputs [j ],
204
- get_input_dtype (program , plan_id , j ),
210
+ get_input_dtype (program , program_plan_id , j ),
205
211
cur_plan_test_inputs [j ].dtype ,
206
212
)
207
213
elif type (cur_plan_test_inputs [j ]) in (
@@ -210,9 +216,9 @@ def assert_valid_bundle(
210
216
float ,
211
217
):
212
218
assert type (cur_plan_test_inputs [j ]) == get_input_type (
213
- program , plan_id , j
219
+ program , program_plan_id , j
214
220
), "The input primitive dtype shall be {}, but now is {}" .format (
215
- get_input_type (program , plan_id , j ),
221
+ get_input_type (program , program_plan_id , j ),
216
222
type (cur_plan_test_inputs [j ]),
217
223
)
218
224
@@ -221,13 +227,16 @@ def assert_valid_bundle(
221
227
# pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
222
228
# has no attribute `dtype`.
223
229
assert cur_plan_test_expected_outputs [j ].dtype == get_output_dtype (
224
- program , plan_id , j
230
+ program , program_plan_id , j
225
231
), "The label tensor {} dtype shall be {}, but now is {}" .format (
226
232
cur_plan_test_expected_outputs [j ],
227
- get_output_dtype (program , plan_id , j ),
233
+ get_output_dtype (program , program_plan_id , j ),
228
234
cur_plan_test_expected_outputs [j ].dtype ,
229
235
)
230
236
237
+ program_plan_id += 1
238
+ bp_plan_id += 1
239
+
231
240
232
241
def create_bundled_program (
233
242
program : Program ,
@@ -245,10 +254,7 @@ def create_bundled_program(
245
254
execution_plan_tests : List [BundledExecutionPlanTest ] = []
246
255
247
256
# Emit data and metadata of bundled tensor
248
- for plan_id in range (len (program .execution_plan )):
249
- plan_test : ConfigExecutionPlanTest = bundled_config .execution_plan_tests [
250
- plan_id
251
- ]
257
+ for plan_test in bundled_config .execution_plan_tests :
252
258
test_sets : List [BundledIOSet ] = []
253
259
254
260
# emit I/O sets for each execution plan test
@@ -283,7 +289,11 @@ def create_bundled_program(
283
289
)
284
290
285
291
# emit the whole execution plan test
286
- execution_plan_tests .append (BundledExecutionPlanTest (test_sets = test_sets ))
292
+ execution_plan_tests .append (
293
+ BundledExecutionPlanTest (
294
+ method_name = plan_test .method_name , test_sets = test_sets
295
+ )
296
+ )
287
297
288
298
program_bytes : bytes = _serialize_pte_binary (program )
289
299
0 commit comments