Skip to content

Commit 09fa2ad

Browse files
committed
Merge branch 'master' into pyt2.0
2 parents 925c76b + 5fa6374 commit 09fa2ad

File tree

117 files changed

+363
-229
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

117 files changed

+363
-229
lines changed

core/partitioning/partitioninginfo/PartitioningInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ struct PartitioningInfo {
1717
std::vector<std::string> forced_fallback_operators;
1818
bool truncate_long_and_double;
1919
ir::Device target_device;
20+
bool cast_int8_inputs = false;
2021

2122
std::string getGPUDeviceString() const {
2223
return "cuda:" + std::to_string(target_device.gpu_id);

core/partitioning/shape_analysis.cpp

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
9999
return nullptr;
100100
}
101101

102-
torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
102+
torch::jit::Node* createCastNode(
103+
SegmentedBlock& seg_block,
104+
size_t index,
105+
bool is_input,
106+
at::ScalarType dtype,
107+
std::string device,
108+
bool force_create_node = false) {
103109
auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
104110
auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
105111
torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
106112
auto g = seg_block.g();
107113
// if we can find upstream aten::to node, we use it's parameters for creating new cast node
108-
if (cast_node) {
114+
if (cast_node && !force_create_node) {
109115
std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
110116
value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
111117
if (!is_input) {
112118
// if this value is output, we need to cast it to int32
113-
auto const_val = g->insertConstant(3);
119+
auto const_val = g->insertConstant(dtype);
114120
if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
115121
value_map.insert({cast_node->inputs()[2], const_val});
116122
} else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
122128
// auto cast_node = g->prependNode(g->createClone(cast_node, env));
123129
} else {
124130
// if there is no explicit cast aten::to operation, we need to create a node
125-
auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
131+
auto const_type = g->insertConstant(dtype);
126132
auto const_zero = g->insertConstant(0);
127133
const_zero->setType(torch::jit::BoolType::get());
128134
auto cuda = g->insertConstant(device);
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(
222228

223229
auto target_device = partitioning_info.getGPUDeviceString();
224230

225-
// auto int64 <=> int32 conversion
226-
if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
231+
// auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
232+
if (seg_block.target() == SegmentedBlock::kTorch) {
227233
// First, check if there is Int64 input
228234
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
229235
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
230236
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
231237
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
232-
if (t == at::kLong) {
238+
if (t == at::kLong && partitioning_info.truncate_long_and_double) {
239+
LOG_DEBUG(
240+
"Detected graph Long tensor input type during shape analysis, "
241+
<< "inserting aten::to cast to Long to ensure this Torch block receives "
242+
<< "a Long-type tensor input.");
233243
// we add a cast operation to cast the type to Int64
234-
auto cast_node = createCastNode(seg_block, i, true, target_device);
244+
auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
245+
seg_block.g()->prependNode(cast_node);
246+
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
247+
} else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
248+
LOG_DEBUG(
249+
"Detected graph Byte tensor input type during shape analysis, "
250+
<< "inserting aten::to cast to Byte to ensure this Torch block receives "
251+
<< "a Byte-type tensor input.");
252+
// If the input has type Byte, ensure it is casted to the correct type
253+
auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
235254
seg_block.g()->prependNode(cast_node);
236255
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
237256
}
238257
}
239258
}
259+
240260
for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
241261
if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
242262
auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
243263
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
244-
if (t == at::kLong) {
245-
auto cast_node = createCastNode(seg_block, i, false, target_device);
264+
265+
// If the output has type Long and truncation was requested, insert truncate
266+
if (t == at::kLong && partitioning_info.truncate_long_and_double) {
267+
LOG_DEBUG(
268+
"Detected graph Long tensor output type during shape analysis, "
269+
<< "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
270+
<< "receives an Int-type tensor input.");
271+
auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
272+
seg_block.g()->appendNode(cast_node);
273+
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
274+
} else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
275+
LOG_DEBUG(
276+
"Detected graph Byte tensor output type during shape analysis, "
277+
<< "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
278+
<< "receives an Int-type tensor input.");
279+
// If the output has type Byte and casting was requested, insert Integer cast
280+
auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
246281
seg_block.g()->appendNode(cast_node);
247282
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
248283
}
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
254289
std::vector<std::vector<int64_t>> input_shapes;
255290
std::vector<at::ScalarType> input_types;
256291
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
257-
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
292+
auto current_input = seg_block.raw_inputs()[i];
293+
294+
if (ivalues_maps[current_input].isTensor()) {
258295
// set the input_shape and data_type
259296
// we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
260297
// shape inference
261-
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
298+
auto cur_ivalue = ivalues_maps[current_input];
262299
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
263300

264301
if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
271308
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
272309
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
273310
}
311+
274312
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
275313
if (dtype == c10::nullopt) {
276314
TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
315+
} else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
316+
// Special case to ensure input IValues to TensorRT engine are not Int8 type if the
317+
// model itself is not quantized
318+
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
277319
}
320+
278321
if (cur_ivalue.toTensor().sizes().size() == 0) {
279322
// handle Scalar types, which has sizes of []
280323
input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +340,7 @@ void runShapeAnalysis(
297340
const ir::ShapeMode& shape_mode) {
298341
// register every segment's input shape, and it's running output IValues
299342
for (auto& seg_block : ctx->partitioned_blocks[block]) {
343+
LOG_GRAPH("Running shape analysis on block " << seg_block);
300344
torch::jit::ConstantPooling(seg_block.g());
301345
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
302346
}

core/util/trt_util.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
252252
{at::kHalf, nvinfer1::DataType::kHALF},
253253
{at::kInt, nvinfer1::DataType::kINT32},
254254
{at::kChar, nvinfer1::DataType::kINT8},
255+
{at::kByte, nvinfer1::DataType::kINT8},
255256
{at::kBool, nvinfer1::DataType::kBOOL}};
256257
return at_trt_type_map;
257258
}

cpp/src/compile_spec.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
167167
internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
168168
internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;
169169

170+
internal.partitioning_info.cast_int8_inputs = true;
171+
170172
if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
171173
internal.convert_info.engine_settings.enabled_precisions.end()) {
174+
internal.partitioning_info.cast_int8_inputs = false;
172175
if (external.ptq_calibrator) {
173176
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
174177
} else {

docs/_cpp_api/classtorch__tensorrt_1_1DataType.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class DataType &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Class DataType &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define STR &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define STR &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCHTRT_API &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define TORCHTRT_API &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define XSTR &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define XSTR &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCHTRT_HIDDEN &mdash; Torch-TensorRT v1.4.0dev0+b638e78 documentation</title>
13+
<title>Define TORCHTRT_HIDDEN &mdash; Torch-TensorRT v1.4.0dev0+544654f documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b638e78
218+
v1.4.0dev0+544654f
219219
</div>
220220

221221

0 commit comments

Comments
 (0)