@@ -84,8 +84,14 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
84
84
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
85
85
// one new block
86
86
if (seg_block.target () == SegmentedBlock::kTorch || isAllNodesSupported (dependency_nodes)) {
87
- dependency_nodes.insert (dependency_nodes.end (), seg_block.raw_nodes ().begin (), seg_block.raw_nodes ().end ());
88
- new_seg_blocks.emplace_back (seg_block.target (), dependency_nodes);
87
+ // if current node is prim::If, just ensure that we have all required input in kTorch
88
+ if (seg_block.raw_nodes ()[0 ]->kind () == torch::jit::prim::If) {
89
+ new_seg_blocks.emplace_back (seg_block.target (), dependency_nodes);
90
+ new_seg_blocks.push_back (seg_block);
91
+ } else {
92
+ dependency_nodes.insert (dependency_nodes.end (), seg_block.raw_nodes ().begin (), seg_block.raw_nodes ().end ());
93
+ new_seg_blocks.emplace_back (seg_block.target (), dependency_nodes);
94
+ }
89
95
} else {
90
96
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
91
97
std::unordered_set<torch::jit::Value*> nontensor_inputs_set (nontensor_inputs.begin (), nontensor_inputs.end ());
@@ -141,8 +147,9 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
141
147
if (segmented_blocks[use_info.produce_id ].target () == SegmentedBlock::kTensorRT && !use_info.torch_use_id .empty ()) {
142
148
int first_torch_id = use_info.torch_use_id .front ();
143
149
if (!updated_segments.count (first_torch_id)) {
144
- auto new_torch_block = injectNodesForNonTensorInputs (segmented_blocks[first_torch_id]).front ();
145
- segmented_blocks[first_torch_id] = new_torch_block;
150
+ auto to_inject_blocks = injectNodesForNonTensorInputs (segmented_blocks[first_torch_id]);
151
+ segmented_blocks.erase (segmented_blocks.begin () + first_torch_id);
152
+ segmented_blocks.insert (segmented_blocks.begin () + first_torch_id, to_inject_blocks.begin (), to_inject_blocks.end ());
146
153
updated_segments.insert (first_torch_id);
147
154
}
148
155
} else {
0 commit comments