@@ -198,7 +198,8 @@ void AddIfBlockToGraph(
198
198
199
199
auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue (v, new_g, block_graph_to_new_g); };
200
200
new_if_block->cloneFrom (cur_block_graph->block (), env);
201
- if (cur_block_graph->inputs ()[0 ]->type ()->str ().find (" __torch__" ) != std::string::npos) {
201
+ if (cur_block_graph->inputs ().size () &&
202
+ cur_block_graph->inputs ()[0 ]->type ()->str ().find (" __torch__" ) != std::string::npos) {
202
203
if (new_g->inputs ()[0 ]->type ()->str ().find (" __torch__" ) == std::string::npos) {
203
204
auto self = new_g->insertInput (0 , " self_1" );
204
205
self->setType (cur_block_graph->inputs ()[0 ]->type ());
@@ -223,13 +224,14 @@ GraphAndMapping ConstructFallbackGraph(
223
224
torch::jit::Block* block,
224
225
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
225
226
CompileSpec cfg,
226
- ir::StaticParams static_params) {
227
+ ir::StaticParams static_params,
228
+ std::unordered_map<torch::jit::Node*, int >& fallback_nodes) {
227
229
auto convert_cfg = cfg.convert_info ;
228
230
auto partition_info = cfg.partition_info ;
229
231
230
232
auto new_g = std::make_shared<torch::jit::Graph>();
231
233
232
- auto segmented_blocks = partitioning::Partition (block, example_tensor_map, partition_info);
234
+ auto segmented_blocks = partitioning::Partition (block, example_tensor_map, partition_info, fallback_nodes );
233
235
234
236
// the mapping from lowering graph => fallback global graph
235
237
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -270,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph(
270
272
std::vector<GraphAndMapping> graph_and_mappings;
271
273
for (auto cur_block : if_node->blocks ()) {
272
274
graph_and_mappings.push_back (
273
- ConstructFallbackGraph (new_mod, cur_block, example_tensor_map, cfg, static_params));
275
+ ConstructFallbackGraph (new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes ));
274
276
}
275
277
AddIfBlockToGraph (new_g, if_node, graph_and_mappings, old_to_new_g);
276
278
@@ -293,7 +295,7 @@ GraphAndMapping ConstructFallbackGraph(
293
295
// Set the output as the produced tuple
294
296
new_g->registerOutput (return_tuple_node->outputs ()[0 ]);
295
297
} else {
296
- if (old_to_new_g.count (block->outputs ()[0 ])) {
298
+ if (block-> outputs (). size () && old_to_new_g.count (block->outputs ()[0 ])) {
297
299
new_g->registerOutput (old_to_new_g[block->outputs ()[0 ]]);
298
300
}
299
301
}
@@ -430,7 +432,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
430
432
!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
431
433
cfg.partition_info .forced_fallback_operators .size () == 0 && isBlockConvertible)) {
432
434
auto input_ivalues_map = partitioning::generateRandomInputs (cfg.convert_info .inputs , first_use_types);
433
- auto graph_and_mapping = ConstructFallbackGraph (new_mod, g->block (), input_ivalues_map, cfg, static_params);
435
+ std::unordered_map<torch::jit::Node*, int > fallback_nodes;
436
+ auto graph_and_mapping =
437
+ ConstructFallbackGraph (new_mod, g->block (), input_ivalues_map, cfg, static_params, fallback_nodes);
434
438
new_g = graph_and_mapping.first ;
435
439
LOG_INFO (" Segmented Graph: " << *new_g);
436
440
0 commit comments