Skip to content

Commit 983eabd

Browse files
authored
Merge pull request #1140 from pytorch/pr1067
fix(tests/core/partitioning): Fix tests of refactoring segmentation in partitioning
2 parents a3432e2 + 85306d8 commit 983eabd

File tree

10 files changed

+225
-352
lines changed

10 files changed

+225
-352
lines changed

core/compiler.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ void AddIfBlockToGraph(
198198

199199
auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
200200
new_if_block->cloneFrom(cur_block_graph->block(), env);
201-
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
201+
if (cur_block_graph->inputs().size() &&
202+
cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
202203
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
203204
auto self = new_g->insertInput(0, "self_1");
204205
self->setType(cur_block_graph->inputs()[0]->type());
@@ -223,13 +224,14 @@ GraphAndMapping ConstructFallbackGraph(
223224
torch::jit::Block* block,
224225
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
225226
CompileSpec cfg,
226-
ir::StaticParams static_params) {
227+
ir::StaticParams static_params,
228+
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
227229
auto convert_cfg = cfg.convert_info;
228230
auto partition_info = cfg.partition_info;
229231

230232
auto new_g = std::make_shared<torch::jit::Graph>();
231233

232-
auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
234+
auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
233235

234236
// the mapping from lowering graph => fallback global graph
235237
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -270,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph(
270272
std::vector<GraphAndMapping> graph_and_mappings;
271273
for (auto cur_block : if_node->blocks()) {
272274
graph_and_mappings.push_back(
273-
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
275+
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
274276
}
275277
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
276278

@@ -293,7 +295,7 @@ GraphAndMapping ConstructFallbackGraph(
293295
// Set the output as the produced tuple
294296
new_g->registerOutput(return_tuple_node->outputs()[0]);
295297
} else {
296-
if (old_to_new_g.count(block->outputs()[0])) {
298+
if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
297299
new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
298300
}
299301
}
@@ -430,7 +432,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
430432
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
431433
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
432434
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
433-
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
435+
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
436+
auto graph_and_mapping =
437+
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
434438
new_g = graph_and_mapping.first;
435439
LOG_INFO("Segmented Graph: " << *new_g);
436440

0 commit comments

Comments
 (0)