Skip to content

Commit b3d9a21

Browse files
committed
Linting
Signed-off-by: Michael Feliz <[email protected]>
1 parent 58df59a commit b3d9a21

File tree

2 files changed

+63
-25
lines changed

2 files changed

+63
-25
lines changed

core/partitioning/partitioning.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
137137
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
138138
}
139139

140-
PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, const std::vector<torch::jit::Value*> &inputs_to_resolve){
140+
PartitionedGraph segmentBlocksWithSpecifiedInputs(
141+
SegmentedBlock& seg_block,
142+
const std::vector<torch::jit::Value*>& inputs_to_resolve) {
141143
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
142144
PartitionedGraph new_seg_blocks;
143145
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
@@ -251,8 +253,9 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
251253
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
252254
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
253255

254-
std::map<int, std::vector<torch::jit::Value*>> torch_values_to_fix; //Only need to resolve values generated by tensorrt
255-
std::set<int> tensorrt_blocks_to_fix; //Need to resolve ALL non-tensor inputs
256+
std::map<int, std::vector<torch::jit::Value*>>
257+
torch_values_to_fix; // Only need to resolve values generated by tensorrt
258+
std::set<int> tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs
256259

257260
// update blocks_list
258261
std::unordered_set<int> updated_segments;
@@ -269,13 +272,14 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
269272
tensorrt_blocks_to_fix.insert(i);
270273
}
271274
}
272-
for(auto torch_block_pair : torch_values_to_fix){
273-
auto to_inject_blocks = segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
275+
for (auto torch_block_pair : torch_values_to_fix) {
276+
auto to_inject_blocks =
277+
segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
274278
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
275279
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
276280
}
277281

278-
for(auto i : tensorrt_blocks_to_fix){
282+
for (auto i : tensorrt_blocks_to_fix) {
279283
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
280284
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
281285
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -293,33 +293,68 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
293293
false_const_val->setType(c10::BoolType::get());
294294
torch::jit::IValue neg_one(-1);
295295
auto neg_one_const_val = g->insertConstant(neg_one);
296-
auto dict_node = g->createDict(ins_key_val->type(), x->type(), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
296+
auto dict_node = g->createDict(
297+
ins_key_val->type(),
298+
x->type(),
299+
torch::jit::ArrayRef<torch::jit::Value*>(),
300+
torch::jit::ArrayRef<torch::jit::Value*>());
297301
g->insertNode(dict_node);
298-
auto set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x}, 0);
302+
auto set_node = g->create(
303+
torch::jit::Symbol::fromQualString("aten::_set_item"),
304+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x},
305+
0);
299306
g->insertNode(set_node);
300-
auto get_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
307+
auto get_node = g->create(
308+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
309+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
310+
1);
301311
g->insertNode(get_node);
302-
auto lt_node = g->create(torch::jit::Symbol::fromQualString("aten::lt"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y}, 1);
312+
auto lt_node = g->create(
313+
torch::jit::Symbol::fromQualString("aten::lt"),
314+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y},
315+
1);
303316
g->insertNode(lt_node);
304-
auto list_node = g->createList(at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
317+
auto list_node = g->createList(
318+
at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
305319
g->insertNode(list_node);
306-
auto dtype_node = g->create(torch::jit::Symbol::fromQualString("prim::dtype"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
320+
auto dtype_node = g->create(
321+
torch::jit::Symbol::fromQualString("prim::dtype"),
322+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
323+
1);
307324
dtype_node->output()->setType(neg_one_const_val->type());
308325
g->insertNode(dtype_node);
309-
auto device_node = g->create(torch::jit::Symbol::fromQualString("prim::device"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
326+
auto device_node = g->create(
327+
torch::jit::Symbol::fromQualString("prim::device"),
328+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
329+
1);
310330
device_node->output()->setType(c10::DeviceObjType::get());
311331
g->insertNode(device_node);
312-
auto tensor_node = g->create(torch::jit::Symbol::fromQualString("aten::tensor"), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, 1);
332+
auto tensor_node = g->create(
333+
torch::jit::Symbol::fromQualString("aten::tensor"),
334+
torch::jit::ArrayRef<torch::jit::Value*>{
335+
neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val},
336+
1);
313337
g->insertNode(tensor_node);
314-
auto index_put_node = g->create(torch::jit::Symbol::fromQualString("aten::index_put_"),
315-
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, 1);
338+
auto index_put_node = g->create(
339+
torch::jit::Symbol::fromQualString("aten::index_put_"),
340+
torch::jit::ArrayRef<torch::jit::Value*>{
341+
get_node->output(), list_node->output(), tensor_node->output(), false_const_val},
342+
1);
316343
g->insertNode(index_put_node);
317-
auto out_set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"),
318-
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()}, 0);
344+
auto out_set_node = g->create(
345+
torch::jit::Symbol::fromQualString("aten::_set_item"),
346+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()},
347+
0);
319348
g->insertNode(out_set_node);
320-
auto get_ins_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
349+
auto get_ins_node = g->create(
350+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
351+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
352+
1);
321353
g->insertNode(get_ins_node);
322-
auto get_outs_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val}, 1);
354+
auto get_outs_node = g->create(
355+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
356+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val},
357+
1);
323358
g->insertNode(get_outs_node);
324359
g->registerOutput(get_ins_node->output());
325360
g->registerOutput(get_outs_node->output());
@@ -337,10 +372,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
337372
input_types.insert({g->inputs()[i], {at::kFloat}});
338373
}
339374
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
340-
auto segmented_blocks =
341-
torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
375+
auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
342376

343-
int torch_block_cnt = 0, trt_block_cnt = 0;
377+
int torch_block_cnt = 0, trt_block_cnt = 0;
344378
for (const auto& segmented_block : segmented_blocks) {
345379
if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
346380
++trt_block_cnt;
@@ -353,12 +387,12 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
353387
bool input_dict = false;
354388
auto dict_type = dict_node->output()->type();
355389
for (auto in : segmented_block.raw_inputs()) {
356-
if(in->type()->isSubtypeOf(dict_type)){
390+
if (in->type()->isSubtypeOf(dict_type)) {
357391
input_dict = true;
358392
}
359393
}
360394
for (auto out : segmented_block.raw_outputs()) {
361-
if(out->type()->isSubtypeOf(dict_type)){
395+
if (out->type()->isSubtypeOf(dict_type)) {
362396
output_dict = true;
363397
}
364398
}

0 commit comments

Comments
 (0)