@@ -182,11 +182,62 @@ void AddSegmentedBlockToGraph(
  return;
}

-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>> GraphAndMapping;
+typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
+    GraphAndMapping;

-GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
-    CompileSpec cfg, int& trt_engine_id, conversion::GraphParams named_params) {
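+// Add a prim::If node to new_g whose then/else blocks are rebuilt from the already converted block graphs in
+// graph_and_mappings.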
+void AddIfBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& new_g,
+    torch::jit::Node* if_node,
+    const std::vector<GraphAndMapping>& graph_and_mappings,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  torch::jit::IfView if_view(if_node);
+
+  // create a new if node in new_g and add corresponding inputs
+  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
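+  // for each converted branch graph, add a block to the new prim::If and clone the branch graph into it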
+  for (auto graph_and_mapping : graph_and_mappings) {
+    auto new_if_block = new_if->addBlock();
+    auto cur_block_graph = graph_and_mapping.first;
+    auto cur_block_mapping = graph_and_mapping.second;
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+    for (auto& i : cur_block_mapping) {
+      // for every pair (old_value => block value) in cur_block_mapping: if old_value also appears in old_to_new_g,
+      // then it is an input of this block's graph
+      if (old_to_new_g.count(i.first)) {
+        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+      }
+    }
+
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+    new_if_block->cloneFrom(cur_block_graph->block(), env);
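+    // if the block graph takes the module self as its first input, make sure new_g also exposes a self input and
+    // map the block's self to it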
+    if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+        auto self = new_g->insertInput(0, "self_1");
+        self->setType(cur_block_graph->inputs()[0]->type());
+      }
+      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+    }
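+    // every block input has now been remapped to a value in new_g, so erase the cloned block's own inputs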
+    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+      new_if_block->eraseInput(i);
+    }
+  }
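+  // add outputs to the new prim::If and record how the original if outputs map to them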
+  for (auto ov : if_view.outputs()) {
+    auto no = new_if->addOutput();
+    old_to_new_g[ov] = no;
+    no->copyMetadata(ov);
+  }
+  return;
+}
+
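+// Partition this block, convert the TensorRT-capable segments, and stitch the results into a new graph, recursing
+// into prim::If sub-blocks via AddIfBlockToGraph.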
+GraphAndMapping ConstructFallbackGraph(
+    torch::jit::script::Module& new_mod,
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    CompileSpec cfg,
+    int& trt_engine_id,
+    conversion::GraphParams named_params) {
  auto convert_cfg = cfg.convert_info;
  auto partition_info = cfg.partition_info;
@@ -218,51 +269,16 @@ GraphAndMapping ConstructFallbackGraph(torch::jit::script::Module& new_mod, torc
      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
    } else {
      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto outer_node = seg_block.raw_nodes()[0];
-        torch::jit::IfView if_view(outer_node);
-
+        auto if_node = seg_block.raw_nodes()[0];

        // convert the 2 blocks in prim::if and get the converted graph with mappings
        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : outer_node->blocks()) {
-          graph_and_mappings.push_back(ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
+        for (auto cur_block : if_node->blocks()) {
+          graph_and_mappings.push_back(
+              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, trt_engine_id, named_params));
        }
+        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

-        // create a new if node in new_g and add corresponding inputs
-        auto new_if =
-            new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-        new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-
-        for (auto graph_and_mapping : graph_and_mappings) {
-          auto new_if_block = new_if->addBlock();
-          auto cur_block_graph = graph_and_mapping.first;
-          auto cur_block_mapping = graph_and_mapping.second;
-          std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-          for (auto& i : cur_block_mapping) {
-            // for every pair in then_mapping, old_value => then value, if old_value also appears in old_to_new_g, then it's then graph's input
-            if (old_to_new_g.count(i.first)) {
-              block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-            }
-          }
-
-          auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-          new_if_block->cloneFrom(cur_block_graph->block(), env);
-          if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-            block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-          }
-          for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-            new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-            new_if_block->eraseInput(i);
-          }
-        }
-        for (auto ov : if_view.outputs()) {
-          auto no = new_if->addOutput();
-          old_to_new_g[ov] = no;
-          no->copyMetadata(ov);
-        }
-
-        LOG_INFO(*new_g << "new g with if\n");
      } else {
        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
      }
@@ -294,23 +310,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
  auto named_params = conversion::get_named_params(g->inputs(), params);
  LOG_INFO(*g << "(LoweringGraph)\n");

-  // segment the graph and convert segmented TensorRT block
-  // auto segmented_blocks = partitioning::Partition(g->block(), convert_cfg.input_ranges, cfg.partition_info);
-  // if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
-  //   LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-  //   return mod;
-  // }
-
  int trt_engine_id = 0;
  std::unordered_map<torch::jit::Value*, ir::InputRange> input_ranges;
  for (size_t i = 0; i < g->inputs().size(); ++i) {
    input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
  }
  auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-  auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
+  auto graph_and_mapping =
+      ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, trt_engine_id, named_params);
  new_g = graph_and_mapping.first;
  LOG_INFO(*new_g << "(FallbackGraph)\n");

+  // if there is no TensorRT engine bound to the module self in the fallback graph, nothing was converted, so just
+  // return the initial module
+  if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+    LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+    return mod;
+  }
+
  auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
  auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
  new_mod.type()->addMethod(new_method);