Commit 5df2536

fix: Bug fixes, code refactor and rebase with master

Signed-off-by: Dheeraj Peri <[email protected]>
2 parents 28da0e8 + 75e86e8

83 files changed, +2559 -456 lines


.github/workflows/docgen.yml

Lines changed: 5 additions & 0 deletions

@@ -17,6 +17,11 @@ jobs:
       username: $GITHUB_ACTOR
       password: ${{secrets.GITHUB_TOKEN}}
     steps:
+      - name: Reclaim space
+        run: |
+          rm -rf /usr/share/dotnet
+          rm -rf /opt/ghc
+          rm -rf "/usr/local/share/boost"
       - uses: actions/checkout@v2
         with:
           ref: ${{github.head_ref}}

core/compiler.cpp

Lines changed: 17 additions & 11 deletions

@@ -32,9 +32,11 @@ void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
+    runtime::CudaDevice& device_info,
     std::string engine_id = "",
     bool fallback = false) {
-  auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine);
+  auto engine_ptr =
+      c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
   auto name = engine_ptr->name;
@@ -265,7 +267,9 @@ GraphAndMapping ConstructFallbackGraph(
       convert_cfg.input_ranges = input_ranges;
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
-      AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);
+      auto device_spec = convert_cfg.engine_settings.device;
+      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

       seg_block.update_graph(temp_g);
       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
@@ -302,15 +306,15 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
   for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only forward methods. forward method contains the entire graph.
+    if (method.name().compare("forward") == 0) {
       auto new_g = std::make_shared<torch::jit::Graph>();
       auto graph_and_parameters = lowering::Lower(mod, method.name());

       auto g = graph_and_parameters.first;
       auto params = graph_and_parameters.second;
       auto named_params = conversion::get_named_params(g->inputs(), params);
-      LOG_INFO(*g << "(LoweringGraph)\n");
+      LOG_INFO("(LoweredGraph)\n" << *g);

       std::unordered_map<torch::jit::Value*, ir::InputRange> input_ranges;
       for (size_t i = 0; i < g->inputs().size(); ++i) {
@@ -319,7 +323,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
       auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
       auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
       new_g = graph_and_mapping.first;
-      LOG_INFO(*new_g << "(FallbackGraph)\n");
+      LOG_INFO("(FallbackGraph)\n" << *new_g);

       // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
       // module
@@ -349,11 +353,13 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
   for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only forward methods. forward method contains the entire graph.
+    if (method.name().compare("forward") == 0) {
       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      AddEngineToGraph(new_mod, new_g, engine);
+      auto device_spec = cfg.convert_info.engine_settings.device;
+      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
       new_mod.type()->addMethod(new_method);
@@ -364,12 +370,12 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
   std::ostringstream engine_id;
   engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
-  AddEngineToGraph(new_mod, new_g, engine);
+  AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);
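
A note on usage: a minimal sketch of how the reworked embedding API might be driven, assuming the trtorch::core namespace nesting and the CudaDevice constructor shown in this diff. LoadEngineBlob is a hypothetical helper, not part of the commit.

// Sketch only: a serialized TensorRT engine blob is assumed to be in hand.
std::string engine_blob = LoadEngineBlob("model.engine"); // hypothetical helper
// The target device now travels with the engine instead of being implicit.
auto cuda_device = trtorch::core::runtime::CudaDevice(/*gpu_id=*/0, nvinfer1::DeviceType::kGPU);
auto trt_mod = trtorch::core::EmbedEngineInNewModule(engine_blob, cuda_device);
trt_mod.save("engine_mod.ts");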

core/compiler.h

Lines changed: 2 additions & 1 deletion

@@ -5,6 +5,7 @@
 #include "core/conversion/conversion.h"
 #include "core/ir/ir.h"
 #include "core/partitioning/partitioning.h"
+#include "core/runtime/runtime.h"
 #include "torch/csrc/jit/api/module.h"

 namespace trtorch {
@@ -22,7 +23,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);

 void set_device(const int gpu_id);

core/conversion/conversion.cpp

Lines changed: 14 additions & 2 deletions

@@ -45,8 +45,15 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
     if (result) {
       // WARN: If the converter returns None then should pass through
       // but if repeated dep this section will get called each time
-      ctx->evaluated_value_map[eval_in] = std::move(result.value());
-      eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
+      auto val = result.value();
+      if (val.isCustomClass()) {
+        auto cont = val.toCustomClass<TensorContainer>();
+        ctx->AssociateValueAndTensor(eval_in, cont->tensor());
+        eval_args[eval_in] = ctx->value_tensor_map[eval_in];
+      } else {
+        ctx->AssociateValueAndIValue(eval_in, val);
+        eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
+      }
     }
   } else {
     TRTORCH_THROW_ERROR(
@@ -374,6 +381,11 @@ void ConvertBlockToNetDef(
       } else {
         TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
       }
+    } else if (eval.value().isCustomClass()) {
+      auto container = eval.value().toCustomClass<TensorContainer>();
+      auto tensor = container->tensor();
+      LOG_DEBUG(ctx->logger, "Found the value to be an ITensor of shape: " << tensor->getDimensions());
+      ctx->AssociateValueAndTensor(n->output(0), tensor);
     } else if (!eval.value().isTensor()) {
       LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
       ctx->AssociateValueAndIValue(n->output(0), eval.value());
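
Taken together, the two hunks let evaluators hand back TensorRT tensors, not just IValues. A minimal sketch of the shared dispatch, using only names that appear in this diff (the helper itself is hypothetical, not a function in the codebase):

// Sketch, assuming the ConversionCtx / TensorContainer APIs used above:
// a custom-class result is taken to box an nvinfer1::ITensor in a
// TensorContainer; anything else is recorded as a plain IValue.
void RecordEvalResult(ConversionCtx* ctx, const torch::jit::Value* v, torch::jit::IValue val) {
  if (val.isCustomClass()) {
    ctx->AssociateValueAndTensor(v, val.toCustomClass<TensorContainer>()->tensor());
  } else {
    ctx->AssociateValueAndIValue(v, val);
  }
}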

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 7 additions & 5 deletions

@@ -23,16 +23,15 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
      << "\n Max Workspace Size: " << s.workspace_size;

   if (s.max_batch_size != 0) {
-    os << "\n Max Batch Size: " << s.max_batch_size;
+    os << "\n Max Batch Size: " << s.max_batch_size;
   } else {
-    os << "\n Max Batch Size: Not set";
+    os << "\n Max Batch Size: Not set";
   }

   os << "\n Device Type: " << s.device.device_type \
      << "\n GPU ID: " << s.device.gpu_id;
-  if (s.device.device_type == nvinfer1::DeviceType::kDLA)
-  {
-    os << "\n DLACore: " << s.device.dla_core;
+  if (s.device.device_type == nvinfer1::DeviceType::kDLA) {
+    os << "\n DLACore: " << s.device.dla_core;
   }
   os << "\n Engine Capability: " << s.capability \
      << "\n Calibrator Created: " << (s.calibrator != nullptr);
@@ -146,6 +145,9 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val

 std::string ConversionCtx::SerializeEngine() {
   auto engine = builder->buildEngineWithConfig(*net, *cfg);
+  if (!engine) {
+    TRTORCH_THROW_ERROR("Building TensorRT engine failed");
+  }
   auto serialized_engine = engine->serialize();
   engine->destroy();
   auto engine_str = std::string((const char*)serialized_engine->data(), serialized_engine->size());
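
The new guard matters because TensorRT's buildEngineWithConfig signals failure by returning nullptr rather than throwing, so the old code would have crashed later on engine->serialize(). The pattern in isolation, as a hedged sketch against the TensorRT 7-era API used in this file:

// builder, net, and cfg are assumed to be populated TensorRT objects,
// as they are on ConversionCtx.
nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*net, *cfg);
if (!engine) {
  TRTORCH_THROW_ERROR("Building TensorRT engine failed"); // fail fast, not on a null deref
}
nvinfer1::IHostMemory* blob = engine->serialize();
std::string engine_str(static_cast<const char*>(blob->data()), blob->size());
engine->destroy();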

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ cc_library(
         "impl/batch_norm.cpp",
         "impl/concat.cpp",
         "impl/constant.cpp",
+        "impl/constant_pad.cpp",
         "impl/conv_deconv.cpp",
         "impl/cumsum.cpp",
         "impl/element_wise.cpp",

core/conversion/converters/converter_util.cpp

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ nvinfer1::ITensor* addUnpadding(
     TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer");
     shuffle_layer->setReshapeDimensions(newDims);
     shuffle_layer->setZeroIsPlaceholder(use_zeros);
-    shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims)).c_str() + ']');
+    shuffle_layer->setName((util::node_info(n) + " [Reshape to " + util::toStr(newDims) + "]").c_str());
     return shuffle_layer->getOutput(0);
   } else {
     return tensor;
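
This one-liner fixes a genuine pointer bug: the old code called .c_str() before appending ']', so the "+ ']'" performed pointer arithmetic on the returned const char* (advancing it by 93, the ASCII value of ']') instead of concatenating. A standalone illustration of the fixed pattern:

#include <iostream>
#include <string>

int main() {
  std::string name = "node [Reshape to [1,2,3]";
  // Buggy shape: name.c_str() + ']' moves the pointer 93 bytes past the
  // start of the buffer -- an out-of-bounds read, not concatenation.
  // Fixed shape: concatenate on std::string first, convert once at the end.
  std::string fixed = name + "]";
  std::cout << fixed.c_str() << "\n"; // node [Reshape to [1,2,3]]
  return 0;
}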
core/conversion/converters/impl/constant_pad.cpp

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+#include <ATen/ATen.h>
+#include <vector>
+#include "NvInfer.h"
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+#include "torch/torch.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto constant_pad_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensor();
+       auto inDims = in->getDimensions();
+       int64_t inRank = inDims.nbDims;
+       auto padding = args[1].unwrapToIntList().vec();
+       int64_t padSize = padding.size();
+       auto value = args[2].unwrapToScalar().to<float>();
+
+       TRTORCH_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
+
+       int64_t l_pad = padSize / 2;
+       TRTORCH_CHECK(
+           inRank >= (int64_t)l_pad,
+           "Length of pad should be no more than twice the number of "
+           "dimensions of the input. Pad length is "
+               << padSize << "while the input has " << inRank << "dimensions.");
+
+       // TODO negative padding. When the pad is negative, we need to crop the image.
+
+       std::vector<nvinfer1::ITensor*> tensors_vec;
+       // input: (N, C, D_in, H_in, W_in).
+       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
+       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
+       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
+       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
+       for (int64_t i = 0; i < l_pad; i++) {
+         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
+         int64_t padding_index = i * 2;
+
+         if (padding[padding_index] > 0) { // left/top/front padding value
+           tensors_vec.clear();
+           if (ctx->input_is_dynamic) {
+             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
+             auto indicesTensor = tensor_to_const(ctx, left_indices);
+             auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
+             auto left_gather_out = left_gather_layer->getOutput(0);
+
+             // fill the left_gather_out with value
+             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
+             auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
+             fill_layer->setInput(0, *shape_gather_out);
+             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
+             auto valueTensor = tensor_to_const(ctx, value_tensor);
+             fill_layer->setInput(1, *valueTensor);
+             at::Tensor delta_tensor = torch::zeros(inRank);
+             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
+             fill_layer->setInput(2, *deltaTensor);
+             auto padTensor = fill_layer->getOutput(0);
+
+             for (int i = 0; i < padding[padding_index]; i++) {
+               tensors_vec.push_back(padTensor);
+             }
+           } else {
+             inDims.d[axis] = padding[padding_index];
+             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
+             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
+             auto valueTensor = tensor_to_const(ctx, value_tensor);
+             fill_layer->setInput(1, *valueTensor);
+             at::Tensor delta_tensor = torch::zeros(inRank);
+             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
+             fill_layer->setInput(2, *deltaTensor);
+             auto padTensor = fill_layer->getOutput(0);
+
+             tensors_vec.push_back(padTensor);
+           }
+
+           tensors_vec.push_back(in);
+           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+           concat_layer->setAxis(axis);
+           in = concat_layer->getOutput(0);
+           inDims = in->getDimensions();
+         }
+
+         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
+           tensors_vec.clear();
+           tensors_vec.push_back(in);
+
+           nvinfer1::ITensor* indicesTensor = NULL;
+           if (inDims.d[axis] == -1) {
+             auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
+             at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
+             auto dimTensor = tensor_to_const(ctx, dimValue);
+             indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
+             auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+             indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
+                                 ->getOutput(0);
+           } else {
+             auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
+             indicesTensor = tensor_to_const(ctx, indices);
+           }
+           auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
+           auto right_gather_out = right_gather_layer->getOutput(0);
+
+           if (ctx->input_is_dynamic) {
+             // fill the right_gather_out with value
+             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
+             auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
+             fill_layer->setInput(0, *shape_gather_out);
+             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
+             auto valueTensor = tensor_to_const(ctx, value_tensor);
+             fill_layer->setInput(1, *valueTensor);
+             at::Tensor delta_tensor = torch::zeros(inRank);
+             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
+             fill_layer->setInput(2, *deltaTensor);
+             auto padTensor = fill_layer->getOutput(0);
+
+             for (int i = 0; i < padding[padding_index + 1]; i++) {
+               tensors_vec.push_back(padTensor);
+             }
+           } else {
+             inDims.d[axis] = padding[padding_index + 1];
+             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
+             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
+             auto valueTensor = tensor_to_const(ctx, value_tensor);
+             fill_layer->setInput(1, *valueTensor);
+             at::Tensor delta_tensor = torch::zeros(inRank);
+             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
+             fill_layer->setInput(2, *deltaTensor);
+             auto padTensor = fill_layer->getOutput(0);
+
+             tensors_vec.push_back(padTensor);
+           }
+           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+           concat_layer->setAxis(axis);
+           in = concat_layer->getOutput(0);
+           inDims = in->getDimensions();
+         }
+       }
+
+       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+       return true;
+     }});
+
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
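
For reference, the op semantics the converter above must reproduce: aten::constant_pad_nd consumes the pad list in pairs from the last dimension inward, exactly as the axis loop assumes. A small standalone libtorch check (illustrative only; shapes chosen arbitrarily):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::ones({1, 3, 4, 5});
  // pad = {left, right, top, bottom}: last dim 5 -> 5 + 1 + 2 = 8,
  // second-to-last dim 4 -> 4 + 3 + 0 = 7.
  auto y = torch::constant_pad_nd(x, {1, 2, 3, 0}, /*value=*/0.5);
  std::cout << y.sizes() << "\n"; // prints [1, 3, 7, 8]
  return 0;
}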

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 24 additions & 5 deletions

@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
       }

       auto power = Weights(ctx, at::ones(expand_size));
-      auto scale_nd = ctx->net->addScaleNd(
-          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
-      auto scale_nd_out = scale_nd->getOutput(0);

-      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+      auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
+      auto scale_l = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
+
+      auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
+      auto shift_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kSUM,
+          scale_l->getOutput(0),
+          beta_tensor,
+          (util::node_info(n) + "_shift").c_str());
+
+      auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
+      auto power_l = add_elementwise(
+          ctx,
+          nvinfer1::ElementWiseOperation::kPOW,
+          shift_l->getOutput(0),
+          power_tensor,
+          (util::node_info(n) + "_power").c_str());
+
+      power_l->setName((util::node_info(n) + "_scale_nd").c_str());
+      auto power_l_out = power_l->getOutput(0);
+
+      ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
       return true;
     }});
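
A note on the math: the deleted IScaleLayer computed the fused elementwise form below, and the three new elementwise layers (kPROD, kSUM, kPOW) reproduce it stage by stage. With \hat{x} the normalized input (div_out):

y = \left(\hat{x} \cdot \gamma + \beta\right)^{p}, \qquad p = \mathbf{1}

Since power is built from at::ones(expand_size), the kPOW stage is an identity and the result is the standard layer-norm affine transform \hat{x} \cdot \gamma + \beta.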
