Skip to content

Commit 726b031

Browse files
authored
Merge pull request #817 from cyfwry/811
Fix the bug that fallback does not support more than one output
2 parents 68dd005 + a874e35 commit 726b031

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

core/compiler.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,21 @@ GraphAndMapping ConstructFallbackGraph(
280280
}
281281
}
282282

283-
for (auto& output : block->outputs()) {
284-
if (old_to_new_g.count(output)) {
285-
new_g->registerOutput(old_to_new_g[output]);
283+
if (block->outputs().size() > 1) {
284+
std::vector<torch::jit::Value*> fallback_graph_vector;
285+
for (auto& output : block->outputs()) {
286+
if (old_to_new_g.count(output)) {
287+
fallback_graph_vector.push_back(old_to_new_g[output]);
288+
}
289+
}
290+
torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
291+
auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
292+
new_g->block()->appendNode(return_tuple_node);
293+
// Set the output as the produced tuple
294+
new_g->registerOutput(return_tuple_node->outputs()[0]);
295+
} else {
296+
if (old_to_new_g.count(block->outputs()[0])) {
297+
new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
286298
}
287299
}
288300
return {new_g, old_to_new_g};

0 commit comments

Comments
 (0)