12
12
#include " torch/csrc/jit/frontend/function_schema_parser.h"
13
13
#include " torch/csrc/jit/ir/ir.h"
14
14
#include " torch/csrc/jit/passes/graph_fuser.h"
15
+ #include " torch/csrc/jit/passes/loop_unrolling.h"
15
16
#include " torch/csrc/jit/passes/lower_graph.h"
16
17
#include " torch/csrc/jit/passes/pass_manager.h"
17
18
#include " torch/custom_class.h"
18
19
19
20
#include " core/compiler.h"
20
- #include " core/util/prelude.h"
21
21
22
22
#include " core/conversion/conversion.h"
23
23
#include " core/lowering/lowering.h"
24
+ #include " core/partitioning/partitioning.h"
24
25
#include " core/runtime/runtime.h"
25
26
26
27
namespace trtorch {
27
28
namespace core {
28
29
29
- c10::FunctionSchema GenerateGraphSchema (
30
- torch::jit::script::Module mod,
31
- std::string method_name,
32
- std::shared_ptr<torch::jit::Graph>& g) {
33
- std::vector<c10::Argument> args;
34
- for (auto in : g->inputs ()) {
35
- args.push_back (c10::Argument (in->debugName (), in->type ()));
36
- }
37
-
38
- std::vector<c10::Argument> returns;
39
- for (auto out : g->outputs ()) {
40
- returns.push_back (c10::Argument (out->debugName (), out->type ()));
41
- }
42
-
43
- return c10::FunctionSchema (method_name, method_name, args, returns);
44
- }
45
-
46
30
void AddEngineToGraph (
47
31
torch::jit::script::Module mod,
48
32
std::shared_ptr<torch::jit::Graph>& g,
49
- const std::string& serialized_engine) {
50
- auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue ()->name (), serialized_engine);
33
+ const std::string& serialized_engine,
34
+ std::string engine_id = " " ,
35
+ bool fallback = false ) {
36
+ auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue ()->name () + engine_id, serialized_engine);
51
37
// Get required metadata about the engine out
52
38
auto num_io = engine_ptr->num_io ;
53
39
auto name = engine_ptr->name ;
54
40
41
+ // ..
55
42
// Add the engine as an attribute of the module, this will let the engine be
56
43
// serialized and deserialized
57
44
mod.register_attribute (
@@ -108,17 +95,19 @@ void AddEngineToGraph(
108
95
g->block ()->appendNode (unpack_node);
109
96
110
97
// If there are multiple output tensors from TensorRT we wrap them in a tuple
111
- // to return
112
- if (unpack_node->outputs ().size () > 1 ) {
98
+ // to return, convert to tuple only when we only have 1 segmented graph
99
+ if (!fallback && unpack_node->outputs ().size () > 1 ) {
113
100
// Creates prim::TupleConstruct(<output tensors>) using outputs of the
114
101
// unpack node
115
102
auto return_tuple_node = g->createTuple (unpack_node->outputs ());
116
103
g->block ()->appendNode (return_tuple_node);
117
104
// Set the output as the produced tuple
118
105
g->registerOutput (return_tuple_node->outputs ()[0 ]);
119
106
} else {
120
- // Set the output as the sole output tensor
121
- g->registerOutput (unpack_node->outputs ()[0 ]);
107
+ // if fallback is enabled, multiple outputs will be registered
108
+ for (size_t i = 0 ; i < unpack_node->outputs ().size (); ++i) {
109
+ g->registerOutput (unpack_node->outputs ()[i]);
110
+ }
122
111
}
123
112
124
113
LOG_DEBUG (*g << " (AddEngineToGraph)\n " );
@@ -142,6 +131,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
142
131
143
132
auto convert_cfg = std::move (cfg.convert_info );
144
133
auto g = graph_and_parameters.first ;
134
+
145
135
auto params = graph_and_parameters.second ;
146
136
auto named_params = conversion::get_named_params (g->inputs (), params);
147
137
@@ -151,7 +141,115 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
151
141
return std::move (engine);
152
142
}
153
143
144
// Splices one partitioned segment (`seg`) into the global fallback graph `g`.
// Values are remapped through two tables:
//   old_to_new_g : original global graph value => new global graph value
//   mini_to_new_g: mini (segment) graph value  => new global graph value
// Side effect: extends old_to_new_g with this segment's outputs so later
// segments (and the final registerOutput pass) can resolve them.
void AddSegmentedBlockToGraph(
    std::shared_ptr<torch::jit::Graph>& g,
    partitioning::SegmentedBlock& seg,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
  size_t input_idx = 0;
  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
    // TensorRT segments carry a module "self" as their first input; if the
    // global graph doesn't already start with a module-typed ("__torch__")
    // input, insert one so the engine's self argument has something to map to.
    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
      auto self = g->insertInput(0, "self_1");
      self->setType(seg.inputs()[0]->type());
    }
    // Map the segment's self input to the global graph's first input and
    // advance past it; the remaining segment inputs are data values.
    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
  }

  // Wire each of the segment's raw (original-graph) inputs to the value
  // already created for it in the new global graph, in order.
  for (auto& raw_input : seg.raw_inputs()) {
    if (old_to_new_g.count(raw_input)) {
      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
    }
  }

  // Copy every node of the segment into the global graph, rewriting value
  // uses via mini_to_new_g as we go.
  for (const auto n : seg.nodes()) {
    util::cloneNode(n, g, mini_to_new_g);
  }

  // Publish the segment's outputs: original graph value => new global graph value.
  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
  }

  return;
}
177
+
178
// Compiles `mod` with Torch fallback enabled: each method's graph is lowered,
// partitioned into TensorRT-convertible and Torch-only segments, TensorRT
// segments are converted to engines, and all segments are stitched back into
// one new graph on a fresh module. Returns the original `mod` unchanged if
// partitioning produced no TensorRT segment at all.
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
  // TODO: Should be doing a functional transform but need PR #31978
  // [jit] More robust mangling
  // torch::jit::script::Module new_mod = mod.clone();
  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
  for (const torch::jit::script::Method& method : mod.get_methods()) {
    // Don't convert hidden methods (names beginning with "_"); rfind(s, 0)
    // returning nonzero means the name does NOT start with "_".
    if (method.name().rfind("_", 0)) {
      auto new_g = std::make_shared<torch::jit::Graph>();
      auto graph_and_parameters = lowering::Lower(mod, method.name());

      auto g = graph_and_parameters.first;
      auto params = graph_and_parameters.second;
      auto named_params = conversion::get_named_params(g->inputs(), params);
      // NOTE(review): convert_info is moved out of cfg here; cfg must not be
      // relied on for convert_info after the first method iteration.
      auto convert_cfg = std::move(cfg.convert_info);
      LOG_INFO(*g << "(LoweringGraph)\n");

      // segment the graph and convert segmented TensorRT block
      auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
      if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
        // Nothing was convertible; hand back the untouched input module.
        LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
        return mod;
      }

      std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
      // add global graph's input to old_to_new_g mapping
      for (auto input : g->inputs()) {
        util::getOrAddInputForValue(input, new_g, old_to_new_g);
      }
      for (auto& seg_block : segmented_blocks) {
        std::string cur_block_target =
            seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
        LOG_INFO(*g << "(MiniGraphIn" << cur_block_target << "Block\n");
        // Use the segment's address as a unique suffix for the engine name,
        // disambiguating multiple engines registered on the same module.
        std::ostringstream trt_engine_id;
        trt_engine_id << reinterpret_cast<const int*>(&seg_block);
        if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
          std::vector<ir::InputRange> input_ranges;
          for (auto& shape : seg_block.in_shape()) {
            input_ranges.push_back(ir::InputRange(shape));
          }
          // update the input ranges for each segments
          convert_cfg.input_ranges = input_ranges;
          auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
          auto temp_g = std::make_shared<torch::jit::Graph>();
          // fallback=true so multiple engine outputs stay flat instead of
          // being wrapped in a tuple (see AddEngineToGraph).
          AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);

          // Replace the segment's mini graph with the engine-invoking graph,
          // then splice it into the global graph.
          seg_block.update_graph(temp_g);
          AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
        } else {
          // Torch segment: splice its original nodes through unchanged.
          AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
        }
      }

      // Register the stitched graph's outputs via the value-remapping table.
      for (auto& output : g->outputs()) {
        new_g->registerOutput(old_to_new_g[output]);
      }

      LOG_INFO(*new_g << "(FallbackGraph)\n");

      // Install the new graph as a method (same name) on the new module.
      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
      new_mod.type()->addMethod(new_method);
      new_method->setSchema(schema);
    }
  }

  return new_mod;
}
247
+
154
248
torch::jit::script::Module CompileGraph (const torch::jit::script::Module& mod, CompileSpec cfg) {
249
+ // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
250
+ if (cfg.partition_info .enabled ) {
251
+ return CompileGraphWithFallback (mod, cfg);
252
+ }
155
253
// TODO: Should be doing a functional transform but need PR #31978
156
254
// [jit] More robust mangling
157
255
// torch::jit::script::Module new_mod = mod.clone();
@@ -164,7 +262,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
164
262
auto new_g = std::make_shared<torch::jit::Graph>();
165
263
AddEngineToGraph (new_mod, new_g, engine);
166
264
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
167
- auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
265
+ auto schema = util:: GenerateGraphSchema (new_method->name (), new_g);
168
266
new_mod.type ()->addMethod (new_method);
169
267
new_method->setSchema (schema);
170
268
}
@@ -180,7 +278,7 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
180
278
auto new_g = std::make_shared<torch::jit::Graph>();
181
279
AddEngineToGraph (new_mod, new_g, engine);
182
280
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (" forward" , new_g);
183
- auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
281
+ auto schema = util:: GenerateGraphSchema (new_method->name (), new_g);
184
282
new_mod.type ()->addMethod (new_method);
185
283
new_method->setSchema (schema);
186
284
0 commit comments