File tree Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Expand file tree Collapse file tree 1 file changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -280,9 +280,21 @@ GraphAndMapping ConstructFallbackGraph(
280
280
}
281
281
}
282
282
283
- for (auto & output : block->outputs ()) {
284
- if (old_to_new_g.count (output)) {
285
- new_g->registerOutput (old_to_new_g[output]);
283
+ if (block->outputs ().size () > 1 ) {
284
+ std::vector<torch::jit::Value*> fallback_graph_vector;
285
+ for (auto & output : block->outputs ()) {
286
+ if (old_to_new_g.count (output)) {
287
+ fallback_graph_vector.push_back (old_to_new_g[output]);
288
+ }
289
+ }
290
+ torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs (fallback_graph_vector);
291
+ auto return_tuple_node = new_g->createTuple (fallback_graph_outputs);
292
+ new_g->block ()->appendNode (return_tuple_node);
293
+ // Set the output as the produced tuple
294
+ new_g->registerOutput (return_tuple_node->outputs ()[0 ]);
295
+ } else {
296
+ if (old_to_new_g.count (block->outputs ()[0 ])) {
297
+ new_g->registerOutput (old_to_new_g[block->outputs ()[0 ]]);
286
298
}
287
299
}
288
300
return {new_g, old_to_new_g};
You can’t perform that action at this time.
0 commit comments