@@ -90,7 +90,6 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
   } else {
     // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
-    new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
     std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
     bool prev_non_tensor_outputs = false;
     for (auto n : seg_block.raw_nodes()) {
@@ -204,17 +203,16 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
       }
     }
   }
-  std::for_each(
-      segmented_blocks.begin(),
-      segmented_blocks.end(),
-      [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); });
-  // erase segments which still have no output
-  segmented_blocks.erase(
-      std::remove_if(
-          segmented_blocks.begin(),
-          segmented_blocks.end(),
-          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
-      segmented_blocks.end());
+  std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) {
+    torch::jit::EliminateDeadCode(seg_block.g());
+  });
+  // erase segments which still have no output
+  segmented_blocks.erase(
+      std::remove_if(
+          segmented_blocks.begin(),
+          segmented_blocks.end(),
+          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
+      segmented_blocks.end());
 
   return;
 }
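
The `segmented_blocks.erase(std::remove_if(...), segmented_blocks.end())` call kept by this change is the standard C++ erase-remove idiom: `std::remove_if` compacts the kept elements to the front and `erase` trims the leftover tail. A minimal, self-contained sketch of that pattern, using a hypothetical `Segment` struct in place of the library's `SegmentedBlock`:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for SegmentedBlock: only what is needed to show the idiom.
struct Segment {
  std::string name;
  std::vector<int> outputs; // stand-in for raw_outputs()
};

int main() {
  std::vector<Segment> segments = {{"trt_block", {0, 1}}, {"empty_block", {}}, {"torch_block", {2}}};

  // std::remove_if shifts the kept elements to the front and returns an iterator
  // to the new logical end; erase() then removes the remaining tail in one pass.
  segments.erase(
      std::remove_if(
          segments.begin(), segments.end(), [](const Segment& s) { return s.outputs.empty(); }),
      segments.end());

  for (const auto& s : segments) {
    std::cout << s.name << "\n"; // prints trt_block, then torch_block
  }
  return 0;
}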