Skip to content

Commit 9823fff

Browse files
committed
chore: improve the implementation of prim::if
Signed-off-by: Bo Wang <[email protected]>
1 parent a64c501 commit 9823fff

File tree

3 files changed

+72
-54
lines changed

3 files changed

+72
-54
lines changed

core/compiler.cpp

Lines changed: 69 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,62 @@ void AddSegmentedBlockToGraph(
182182
return;
183183
}
184184

185-
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>> GraphAndMapping;
185+
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
186+
GraphAndMapping;
186187

187-
GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torch::jit::Block* block,
188-
std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
189-
CompileSpec cfg, int& trt_engine_id, conversion::GraphParams named_params) {
188+
void AddIfBlockToGraph(
189+
std::shared_ptr<torch::jit::Graph>& new_g,
190+
torch::jit::Node* if_node,
191+
const std::vector<GraphAndMapping>& graph_and_mappings,
192+
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
193+
torch::jit::IfView if_view(if_node);
194+
195+
// create a new if node in new_g and add corresponding inputs
196+
auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
197+
new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
198+
199+
for (auto graph_and_mapping : graph_and_mappings) {
200+
auto new_if_block = new_if->addBlock();
201+
auto cur_block_graph = graph_and_mapping.first;
202+
auto cur_block_mapping = graph_and_mapping.second;
203+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
204+
for (auto& i : cur_block_mapping) {
205+
// for every pair (old_value => block value) in cur_block_mapping: if old_value also appears in
206+
// old_to_new_g, then it is an input of this block's graph
207+
if (old_to_new_g.count(i.first)) {
208+
block_graph_to_new_g[i.second] = old_to_new_g[i.first];
209+
}
210+
}
211+
212+
auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
213+
new_if_block->cloneFrom(cur_block_graph->block(), env);
214+
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
215+
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
216+
auto self = new_g->insertInput(0, "self_1");
217+
self->setType(cur_block_graph->inputs()[0]->type());
218+
}
219+
block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
220+
}
221+
for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
222+
new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
223+
new_if_block->eraseInput(i);
224+
}
225+
}
226+
for (auto ov : if_view.outputs()) {
227+
auto no = new_if->addOutput();
228+
old_to_new_g[ov] = no;
229+
no->copyMetadata(ov);
230+
}
231+
return;
232+
}
233+
234+
GraphAndMapping ConstructFallbackGraph(
235+
torch::jit::script::Module& new_mod,
236+
torch::jit::Block* block,
237+
std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
238+
CompileSpec cfg,
239+
int& trt_engine_id,
240+
conversion::GraphParams named_params) {
190241
auto convert_cfg = cfg.convert_info;
191242
auto partition_info = cfg.partition_info;
192243

@@ -218,51 +269,16 @@ GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torc
218269
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
219270
} else {
220271
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
221-
auto outer_node = seg_block.raw_nodes()[0];
222-
torch::jit::IfView if_view(outer_node);
223-
272+
auto if_node = seg_block.raw_nodes()[0];
224273

225274
// convert the 2 blocks in prim::if and get the converted graph with mappings
226275
std::vector<GraphAndMapping> graph_and_mappings;
227-
for (auto cur_block : outer_node->blocks()) {
228-
graph_and_mappings.push_back(ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
276+
for (auto cur_block : if_node->blocks()) {
277+
graph_and_mappings.push_back(
278+
ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
229279
}
280+
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
230281

231-
// create a new if node in new_g and add corresponding inputs
232-
auto new_if =
233-
new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
234-
new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
235-
236-
237-
for (auto graph_and_mapping : graph_and_mappings) {
238-
auto new_if_block = new_if->addBlock();
239-
auto cur_block_graph = graph_and_mapping.first;
240-
auto cur_block_mapping = graph_and_mapping.second;
241-
std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
242-
for (auto& i : cur_block_mapping) {
243-
// for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's then graph's input
244-
if (old_to_new_g.count(i.first)) {
245-
block_graph_to_new_g[i.second] = old_to_new_g[i.first];
246-
}
247-
}
248-
249-
auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
250-
new_if_block->cloneFrom(cur_block_graph->block(), env);
251-
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
252-
block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
253-
}
254-
for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
255-
new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
256-
new_if_block->eraseInput(i);
257-
}
258-
}
259-
for (auto ov : if_view.outputs()) {
260-
auto no = new_if->addOutput();
261-
old_to_new_g[ov] = no;
262-
no->copyMetadata(ov);
263-
}
264-
265-
LOG_INFO(*new_g << "new g with if\n");
266282
} else {
267283
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
268284
}
@@ -294,23 +310,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
294310
auto named_params = conversion::get_named_params(g->inputs(), params);
295311
LOG_INFO(*g << "(LoweringGraph)\n");
296312

297-
// segment the graph and convert segmented TensorRT block
298-
// auto segmented_blocks = partitioning::Partition(g->block(), convert_cfg.input_ranges, cfg.partition_info);
299-
// if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
300-
// LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
301-
// return mod;
302-
// }
303-
304313
int trt_engine_id = 0;
305314
std::unordered_map<torch::jit::Value*, ir::InputRange> input_ranges;
306315
for (size_t i = 0; i < g->inputs().size(); ++i) {
307316
input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
308317
}
309318
auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
310-
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
319+
auto graph_and_mapping =
320+
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
311321
new_g = graph_and_mapping.first;
312322
LOG_INFO(*new_g << "(FallbackGraph)\n");
313323

324+
// if the fallback graph gained no TensorRT engine `self` input, nothing was converted; return the original
325+
// module unchanged
326+
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
327+
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
328+
return mod;
329+
}
330+
314331
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
315332
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
316333
new_mod.type()->addMethod(new_method);

core/partitioning/partitioning.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
223223
std::unordered_set<std::string> forced_fallback_operators(
224224
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
225225

226-
227226
auto nodes = block->nodes();
228227
std::vector<SegmentedBlock> segmented_blocks;
229228

core/partitioning/shape_analysis.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ void getSegmentsOutputByRunning(
5252

5353
// set inputs ivalues, now supports Tensor/Int to pass arguments between different segments
5454
for (auto& input : seg_block.raw_inputs()) {
55-
TRTORCH_CHECK(ivalues_maps.count(input), "Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
55+
TRTORCH_CHECK(
56+
ivalues_maps.count(input),
57+
"Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
5658
if (input->node()->kind() == torch::jit::prim::Param) {
5759
jit_inputs_ivalues.push_back(ivalues_maps[input]);
5860
} else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {

0 commit comments

Comments
 (0)