Skip to content

Commit d00627f

Browse files
committed
fix: make sure that prim::if is in raw_nodes()[0] in dependency analysis
Signed-off-by: Bo Wang <[email protected]>
1 parent 9823fff commit d00627f

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

core/compiler.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,15 @@ void AddIfBlockToGraph(
196196
auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
197197
new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
198198

199+
// iterate over all blocks and add them to new created prim::If
199200
for (auto graph_and_mapping : graph_and_mappings) {
200201
auto new_if_block = new_if->addBlock();
201202
auto cur_block_graph = graph_and_mapping.first;
202203
auto cur_block_mapping = graph_and_mapping.second;
203204
std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
204205
for (auto& i : cur_block_mapping) {
205-
// for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's
206-
// then graph's input
206+
// for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
207+
// mini graph's input
207208
if (old_to_new_g.count(i.first)) {
208209
block_graph_to_new_g[i.second] = old_to_new_g[i.first];
209210
}
@@ -214,7 +215,7 @@ void AddIfBlockToGraph(
214215
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
215216
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
216217
auto self = new_g->insertInput(0, "self_1");
217-
self->setType(loop_graph->inputs()[0]->type());
218+
self->setType(cur_block_graph->inputs()[0]->type());
218219
}
219220
block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
220221
}

core/partitioning/partitioning.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,14 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
8484
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
8585
// one new block
8686
if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
87-
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
88-
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
87+
// if current node is prim::If, just ensure that we have all required input in kTorch
88+
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
89+
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
90+
new_seg_blocks.push_back(seg_block);
91+
} else {
92+
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
93+
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
94+
}
8995
} else {
9096
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
9197
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
@@ -141,8 +147,9 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
141147
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
142148
int first_torch_id = use_info.torch_use_id.front();
143149
if (!updated_segments.count(first_torch_id)) {
144-
auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
145-
segmented_blocks[first_torch_id] = new_torch_block;
150+
auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
151+
segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
152+
segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
146153
updated_segments.insert(first_torch_id);
147154
}
148155
} else {

0 commit comments

Comments
 (0)