Skip to content

Commit 5a53c13

Browse files
committed
Merge remote-tracking branch 'origin/master' into plugins
2 parents b5d4055 + e81367b commit 5a53c13

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1900
-106
lines changed

core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ cc_library(
2626
"//core/conversion",
2727
"//core/runtime",
2828
"//core/lowering",
29+
"//core/partitioning",
2930
"//core/util/logging",
3031
"@tensorrt//:nvinfer",
3132
] + select({

core/compiler.cpp

Lines changed: 124 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,33 @@
1212
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1313
#include "torch/csrc/jit/ir/ir.h"
1414
#include "torch/csrc/jit/passes/graph_fuser.h"
15+
#include "torch/csrc/jit/passes/loop_unrolling.h"
1516
#include "torch/csrc/jit/passes/lower_graph.h"
1617
#include "torch/csrc/jit/passes/pass_manager.h"
1718
#include "torch/custom_class.h"
1819

1920
#include "core/compiler.h"
20-
#include "core/util/prelude.h"
2121

2222
#include "core/conversion/conversion.h"
2323
#include "core/lowering/lowering.h"
24+
#include "core/partitioning/partitioning.h"
2425
#include "core/runtime/runtime.h"
2526

2627
namespace trtorch {
2728
namespace core {
2829

29-
c10::FunctionSchema GenerateGraphSchema(
30-
torch::jit::script::Module mod,
31-
std::string method_name,
32-
std::shared_ptr<torch::jit::Graph>& g) {
33-
std::vector<c10::Argument> args;
34-
for (auto in : g->inputs()) {
35-
args.push_back(c10::Argument(in->debugName(), in->type()));
36-
}
37-
38-
std::vector<c10::Argument> returns;
39-
for (auto out : g->outputs()) {
40-
returns.push_back(c10::Argument(out->debugName(), out->type()));
41-
}
42-
43-
return c10::FunctionSchema(method_name, method_name, args, returns);
44-
}
45-
4630
void AddEngineToGraph(
4731
torch::jit::script::Module mod,
4832
std::shared_ptr<torch::jit::Graph>& g,
49-
const std::string& serialized_engine) {
50-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
33+
const std::string& serialized_engine,
34+
std::string engine_id = "",
35+
bool fallback = false) {
36+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine);
5137
// Get required metadata about the engine out
5238
auto num_io = engine_ptr->num_io;
5339
auto name = engine_ptr->name;
5440

41+
//..
5542
// Add the engine as an attribute of the module, this will let the engine be
5643
// serialized and deserialized
5744
mod.register_attribute(
@@ -108,17 +95,19 @@ void AddEngineToGraph(
10895
g->block()->appendNode(unpack_node);
10996

11097
// If there are multiple output tensors from TensorRT we wrap them in a tuple
111-
// to return
112-
if (unpack_node->outputs().size() > 1) {
98+
// to return, convert to tuple only when we only have 1 segmented graph
99+
if (!fallback && unpack_node->outputs().size() > 1) {
113100
// Creates prim::TupleConstruct(<output tensors>) using outputs of the
114101
// unpack node
115102
auto return_tuple_node = g->createTuple(unpack_node->outputs());
116103
g->block()->appendNode(return_tuple_node);
117104
// Set the output as the produced tuple
118105
g->registerOutput(return_tuple_node->outputs()[0]);
119106
} else {
120-
// Set the output as the sole output tensor
121-
g->registerOutput(unpack_node->outputs()[0]);
107+
// if fallback is enabled, multiple outputs will be registered
108+
for (size_t i = 0; i < unpack_node->outputs().size(); ++i) {
109+
g->registerOutput(unpack_node->outputs()[i]);
110+
}
122111
}
123112

124113
LOG_DEBUG(*g << "(AddEngineToGraph)\n");
@@ -142,6 +131,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
142131

143132
auto convert_cfg = std::move(cfg.convert_info);
144133
auto g = graph_and_parameters.first;
134+
145135
auto params = graph_and_parameters.second;
146136
auto named_params = conversion::get_named_params(g->inputs(), params);
147137

@@ -151,7 +141,115 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
151141
return std::move(engine);
152142
}
153143

144+
void AddSegmentedBlockToGraph(
145+
std::shared_ptr<torch::jit::Graph>& g,
146+
partitioning::SegmentedBlock& seg,
147+
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
148+
// old_to_new_g contains: original global graph value => new global graph value,
149+
// mini_to_new_g: mini graph value -> new graph value
150+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
151+
size_t input_idx = 0;
152+
if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
153+
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
154+
auto self = g->insertInput(0, "self_1");
155+
self->setType(seg.inputs()[0]->type());
156+
}
157+
mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
158+
}
159+
160+
for (auto& raw_input : seg.raw_inputs()) {
161+
if (old_to_new_g.count(raw_input)) {
162+
mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
163+
}
164+
}
165+
166+
for (const auto n : seg.nodes()) {
167+
util::cloneNode(n, g, mini_to_new_g);
168+
}
169+
170+
// original graph value => new global graph value
171+
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
172+
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
173+
}
174+
175+
return;
176+
}
177+
178+
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
179+
// TODO: Should be doing a functional transform but need PR #31978
180+
// [jit] More robust mangling
181+
// torch::jit::script::Module new_mod = mod.clone();
182+
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
183+
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
184+
for (const torch::jit::script::Method& method : mod.get_methods()) {
185+
// Don't convert hidden methods
186+
if (method.name().rfind("_", 0)) {
187+
auto new_g = std::make_shared<torch::jit::Graph>();
188+
auto graph_and_parameters = lowering::Lower(mod, method.name());
189+
190+
auto g = graph_and_parameters.first;
191+
auto params = graph_and_parameters.second;
192+
auto named_params = conversion::get_named_params(g->inputs(), params);
193+
auto convert_cfg = std::move(cfg.convert_info);
194+
LOG_INFO(*g << "(LoweringGraph)\n");
195+
196+
// segment the graph and convert segmented TensorRT block
197+
auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
198+
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
199+
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
200+
return mod;
201+
}
202+
203+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
204+
// add global graph's input to old_to_new_g mapping
205+
for (auto input : g->inputs()) {
206+
util::getOrAddInputForValue(input, new_g, old_to_new_g);
207+
}
208+
for (auto& seg_block : segmented_blocks) {
209+
std::string cur_block_target =
210+
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
211+
LOG_INFO(*g << "(MiniGraphIn" << cur_block_target << "Block\n");
212+
std::ostringstream trt_engine_id;
213+
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
214+
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
215+
std::vector<ir::InputRange> input_ranges;
216+
for (auto& shape : seg_block.in_shape()) {
217+
input_ranges.push_back(ir::InputRange(shape));
218+
}
219+
// update the input ranges for each segments
220+
convert_cfg.input_ranges = input_ranges;
221+
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
222+
auto temp_g = std::make_shared<torch::jit::Graph>();
223+
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);
224+
225+
seg_block.update_graph(temp_g);
226+
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
227+
} else {
228+
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
229+
}
230+
}
231+
232+
for (auto& output : g->outputs()) {
233+
new_g->registerOutput(old_to_new_g[output]);
234+
}
235+
236+
LOG_INFO(*new_g << "(FallbackGraph)\n");
237+
238+
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
239+
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
240+
new_mod.type()->addMethod(new_method);
241+
new_method->setSchema(schema);
242+
}
243+
}
244+
245+
return new_mod;
246+
}
247+
154248
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
249+
// TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
250+
if (cfg.partition_info.enabled) {
251+
return CompileGraphWithFallback(mod, cfg);
252+
}
155253
// TODO: Should be doing a functional transform but need PR #31978
156254
// [jit] More robust mangling
157255
// torch::jit::script::Module new_mod = mod.clone();
@@ -164,7 +262,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
164262
auto new_g = std::make_shared<torch::jit::Graph>();
165263
AddEngineToGraph(new_mod, new_g, engine);
166264
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
167-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
265+
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
168266
new_mod.type()->addMethod(new_method);
169267
new_method->setSchema(schema);
170268
}
@@ -180,7 +278,7 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
180278
auto new_g = std::make_shared<torch::jit::Graph>();
181279
AddEngineToGraph(new_mod, new_g, engine);
182280
auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
183-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
281+
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
184282
new_mod.type()->addMethod(new_method);
185283
new_method->setSchema(schema);
186284

core/compiler.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
#include <cuda_runtime.h>
44
#include <vector>
55
#include "core/conversion/conversion.h"
6+
#include "core/ir/ir.h"
7+
#include "core/partitioning/partitioning.h"
68
#include "torch/csrc/jit/api/module.h"
79

810
namespace trtorch {
911
namespace core {
1012

1113
struct CompileSpec {
12-
CompileSpec(std::vector<conversion::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
14+
CompileSpec(std::vector<ir::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
1315
conversion::ConversionInfo convert_info;
16+
partitioning::PartitionInfo partition_info;
1417
};
1518

1619
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

core/conversion/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
"//core/conversion/conversionctx",
2424
"//core/conversion/converters",
2525
"//core/conversion/evaluators",
26+
"//core/ir",
2627
"//core/util:prelude",
2728
] + select({
2829
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],

core/conversion/InterfaceTypes.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,55 +23,6 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect
2323
return std::move(named_params);
2424
}
2525

26-
InputRange::InputRange(std::vector<int64_t> d) {
27-
if (d.size() > 5) {
28-
LOG_WARNING("Verify that this dim size is accepted");
29-
}
30-
31-
opt = util::toDims(d);
32-
min = util::toDims(d);
33-
max = util::toDims(d);
34-
input_shape = util::toDims(d);
35-
input_is_dynamic = false;
36-
}
37-
38-
InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape) {
39-
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
40-
LOG_WARNING("Verify that this dim size is accepted");
41-
}
42-
43-
std::set<size_t> sizes;
44-
sizes.insert(min_shape.size());
45-
sizes.insert(opt_shape.size());
46-
sizes.insert(max_shape.size());
47-
48-
if (sizes.size() != 1) {
49-
LOG_ERROR(
50-
"Expected all input sizes have the same dimensions, but found dimensions: min("
51-
<< min_shape.size() << "), opt(" << opt_shape.size() << "), max(" << max_shape.size() << ")");
52-
}
53-
54-
min = util::toDims(min_shape);
55-
opt = util::toDims(opt_shape);
56-
max = util::toDims(max_shape);
57-
58-
std::vector<int64_t> dyn_shape;
59-
for (size_t i = 0; i < opt_shape.size(); i++) {
60-
std::set<uint64_t> dim;
61-
dim.insert(min_shape[i]);
62-
dim.insert(opt_shape[i]);
63-
dim.insert(max_shape[i]);
64-
if (dim.size() != 1) {
65-
dyn_shape.push_back(-1);
66-
input_is_dynamic = true;
67-
} else {
68-
dyn_shape.push_back(opt_shape[i]);
69-
}
70-
}
71-
72-
input_shape = util::toDims(dyn_shape);
73-
}
74-
7526
} // namespace conversion
7627
} // namespace core
7728
} // namespace trtorch

core/conversion/conversion.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
118118
<< "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
119119
}
120120

121-
void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<InputRange>& input_dims) {
121+
void AddInputs(
122+
ConversionCtx* ctx,
123+
at::ArrayRef<const torch::jit::Value*> inputs,
124+
std::vector<ir::InputRange>& input_dims) {
122125
std::vector<const torch::jit::Value*> input_tensors;
123126
for (auto in : inputs) {
124127
// Disregarding inputs that are not tensors

core/conversion/conversion.h

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,17 @@
44

55
#include "NvInfer.h"
66
#include "core/conversion/conversionctx/ConversionCtx.h"
7+
#include "core/ir/ir.h"
78
#include "torch/csrc/jit/ir/ir.h"
89

910
namespace trtorch {
1011
namespace core {
1112
namespace conversion {
1213

13-
struct InputRange {
14-
nvinfer1::Dims min;
15-
nvinfer1::Dims max;
16-
nvinfer1::Dims opt;
17-
nvinfer1::Dims input_shape;
18-
bool input_is_dynamic = false;
19-
// Should we restrict to unsigned?
20-
InputRange(std::vector<int64_t> d);
21-
InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
22-
};
23-
2414
struct ConversionInfo {
25-
std::vector<InputRange> input_ranges;
15+
std::vector<ir::InputRange> input_ranges;
2616
BuilderSettings engine_settings;
27-
ConversionInfo(std::vector<InputRange> input_ranges)
17+
ConversionInfo(std::vector<ir::InputRange> input_ranges)
2818
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
2919
};
3020

core/conversion/evaluators/aten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ auto aten_registrations TRTORCH_UNUSED =
450450
if (args.at(n->input(0)).IValue()->isInt()) {
451451
auto a = args.at(n->input(0)).unwrapToInt();
452452
auto b = args.at(n->input(1)).unwrapToInt();
453-
return std::floor(a / b);
453+
return static_cast<int>(std::floor(a / b));
454454
} else if (args.at(n->input(0)).IValue()->isDouble()) {
455455
auto a = args.at(n->input(0)).unwrapToDouble();
456456
auto b = args.at(n->input(1)).unwrapToDouble();

core/ir/BUILD

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
10+
cc_library(
11+
name = "ir",
12+
hdrs = [
13+
"ir.h"
14+
],
15+
srcs = [
16+
"InputRange.cpp",
17+
],
18+
deps = [
19+
"@tensorrt//:nvinfer",
20+
"//core/util:prelude",
21+
] + select({
22+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
23+
"//conditions:default": ["@libtorch//:libtorch"],
24+
}),
25+
)
26+
27+
load("@rules_pkg//:pkg.bzl", "pkg_tar")
28+
29+
pkg_tar(
30+
name = "include",
31+
package_dir = "core/ir/",
32+
srcs = [
33+
"ir.h",
34+
],
35+
)

0 commit comments

Comments (0)