@@ -293,33 +293,68 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
293
293
false_const_val->setType (c10::BoolType::get ());
294
294
torch::jit::IValue neg_one (-1 );
295
295
auto neg_one_const_val = g->insertConstant (neg_one);
296
- auto dict_node = g->createDict (ins_key_val->type (), x->type (), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
296
+ auto dict_node = g->createDict (
297
+ ins_key_val->type (),
298
+ x->type (),
299
+ torch::jit::ArrayRef<torch::jit::Value*>(),
300
+ torch::jit::ArrayRef<torch::jit::Value*>());
297
301
g->insertNode (dict_node);
298
- auto set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x}, 0 );
302
+ auto set_node = g->create (
303
+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
304
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x},
305
+ 0 );
299
306
g->insertNode (set_node);
300
- auto get_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
307
+ auto get_node = g->create (
308
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
309
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
310
+ 1 );
301
311
g->insertNode (get_node);
302
- auto lt_node = g->create (torch::jit::Symbol::fromQualString (" aten::lt" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y}, 1 );
312
+ auto lt_node = g->create (
313
+ torch::jit::Symbol::fromQualString (" aten::lt" ),
314
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y},
315
+ 1 );
303
316
g->insertNode (lt_node);
304
- auto list_node = g->createList (at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
317
+ auto list_node = g->createList (
318
+ at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
305
319
g->insertNode (list_node);
306
- auto dtype_node = g->create (torch::jit::Symbol::fromQualString (" prim::dtype" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
320
+ auto dtype_node = g->create (
321
+ torch::jit::Symbol::fromQualString (" prim::dtype" ),
322
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
323
+ 1 );
307
324
dtype_node->output ()->setType (neg_one_const_val->type ());
308
325
g->insertNode (dtype_node);
309
- auto device_node = g->create (torch::jit::Symbol::fromQualString (" prim::device" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
326
+ auto device_node = g->create (
327
+ torch::jit::Symbol::fromQualString (" prim::device" ),
328
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
329
+ 1 );
310
330
device_node->output ()->setType (c10::DeviceObjType::get ());
311
331
g->insertNode (device_node);
312
- auto tensor_node = g->create (torch::jit::Symbol::fromQualString (" aten::tensor" ), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val}, 1 );
332
+ auto tensor_node = g->create (
333
+ torch::jit::Symbol::fromQualString (" aten::tensor" ),
334
+ torch::jit::ArrayRef<torch::jit::Value*>{
335
+ neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val},
336
+ 1 );
313
337
g->insertNode (tensor_node);
314
- auto index_put_node = g->create (torch::jit::Symbol::fromQualString (" aten::index_put_" ),
315
- torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), list_node->output (), tensor_node->output (), false_const_val}, 1 );
338
+ auto index_put_node = g->create (
339
+ torch::jit::Symbol::fromQualString (" aten::index_put_" ),
340
+ torch::jit::ArrayRef<torch::jit::Value*>{
341
+ get_node->output (), list_node->output (), tensor_node->output (), false_const_val},
342
+ 1 );
316
343
g->insertNode (index_put_node);
317
- auto out_set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ),
318
- torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()}, 0 );
344
+ auto out_set_node = g->create (
345
+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
346
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()},
347
+ 0 );
319
348
g->insertNode (out_set_node);
320
- auto get_ins_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
349
+ auto get_ins_node = g->create (
350
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
351
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
352
+ 1 );
321
353
g->insertNode (get_ins_node);
322
- auto get_outs_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val}, 1 );
354
+ auto get_outs_node = g->create (
355
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
356
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val},
357
+ 1 );
323
358
g->insertNode (get_outs_node);
324
359
g->registerOutput (get_ins_node->output ());
325
360
g->registerOutput (get_outs_node->output ());
@@ -337,10 +372,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
337
372
input_types.insert ({g->inputs ()[i], {at::kFloat }});
338
373
}
339
374
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
340
- auto segmented_blocks =
341
- torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
375
+ auto segmented_blocks = torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
342
376
343
- int torch_block_cnt = 0 , trt_block_cnt = 0 ;
377
+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
344
378
for (const auto & segmented_block : segmented_blocks) {
345
379
if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
346
380
++trt_block_cnt;
@@ -353,12 +387,12 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
353
387
bool input_dict = false ;
354
388
auto dict_type = dict_node->output ()->type ();
355
389
for (auto in : segmented_block.raw_inputs ()) {
356
- if (in->type ()->isSubtypeOf (dict_type)){
390
+ if (in->type ()->isSubtypeOf (dict_type)) {
357
391
input_dict = true ;
358
392
}
359
393
}
360
394
for (auto out : segmented_block.raw_outputs ()) {
361
- if (out->type ()->isSubtypeOf (dict_type)){
395
+ if (out->type ()->isSubtypeOf (dict_type)) {
362
396
output_dict = true ;
363
397
}
364
398
}
0 commit comments