Skip to content

Commit 7d1d984

Browse files
committed
chore: rebase with master
Signed-off-by: Dheeraj Peri <[email protected]>
2 parents 14de59a + 6f73a23 commit 7d1d984

File tree

122 files changed

+431
-272
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+431
-272
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
114114
The following dependencies were used to verify the test cases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
115115

116116
- Bazel 5.2.0
117-
- Libtorch 1.14.0.dev20221114 (built with CUDA 11.7)
117+
- Libtorch 1.14.0.dev20221205 (built with CUDA 11.7)
118118
- CUDA 11.7
119119
- cuDNN 8.5.0
120120
- TensorRT 8.5.1.7

core/conversion/converters/impl/select.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,22 @@ auto select_registrations TORCHTRT_UNUSED =
736736
{"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
737737
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
738738
auto condition = args[0].ITensorOrFreeze(ctx);
739+
auto condition_nbDims = condition->getDimensions().nbDims;
739740
auto x = args[1].ITensorOrFreeze(ctx);
741+
auto x_nbDims = x->getDimensions().nbDims;
740742
auto y = args[2].ITensorOrFreeze(ctx);
743+
auto y_nbDims = y->getDimensions().nbDims;
744+
745+
// Get maximum rank of all input tensors
746+
auto max_nbDims = std::max(condition_nbDims, std::max(x_nbDims, y_nbDims));
747+
748+
// TensorRT requires all inputs to Select layers to have the same rank, so for each
749+
// tensor input, ensure that its rank is equal to the maximum number of dimensions.
750+
// If not, left-pad the tensor dimension with 1s until the max rank is achieved
751+
condition =
752+
addPadding(ctx, n, condition, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
753+
x = addPadding(ctx, n, x, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
754+
y = addPadding(ctx, n, y, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
741755

742756
auto layer = ctx->net->addSelect(*condition, *x, *y);
743757

core/lowering/lowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct LowerInfo {
2020
std::vector<std::string> forced_fallback_modules;
2121
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
2222

23-
std::string getGPUDeviceString() {
23+
std::string getGPUDeviceString() const {
2424
return "cuda:" + std::to_string(target_device.gpu_id);
2525
};
2626
};

core/partitioning/partitioninginfo/PartitioningInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ struct PartitioningInfo {
1616
uint64_t min_block_size = 1;
1717
std::vector<std::string> forced_fallback_operators;
1818
bool truncate_long_and_double;
19+
ir::Device target_device;
20+
21+
std::string getGPUDeviceString() const {
22+
return "cuda:" + std::to_string(target_device.gpu_id);
23+
};
1924
};
2025

2126
std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s);

core/partitioning/shape_analysis.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
9999
return nullptr;
100100
}
101101

102-
torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
102+
torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
103103
auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
104104
auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
105105
torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
@@ -125,8 +125,11 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
125125
auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
126126
auto const_zero = g->insertConstant(0);
127127
const_zero->setType(torch::jit::BoolType::get());
128+
auto cuda = g->insertConstant(device);
129+
cuda->setType(torch::jit::DeviceObjType::get());
128130
auto none_val = g->insertNode(g->createNone())->output();
129-
cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
131+
cast_node =
132+
g->create(torch::jit::aten::to, {cast_subgraph_value, cuda, const_type, const_zero, const_zero, none_val});
130133
}
131134
return cast_node;
132135
}
@@ -217,6 +220,8 @@ void getSegmentsOutputByRunning(
217220
ivalues_maps[output] = jit_results[idx++];
218221
}
219222

223+
auto target_device = partitioning_info.getGPUDeviceString();
224+
220225
// auto int64 <=> int32 conversion
221226
if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
222227
// First, check if there is Int64 input
@@ -226,7 +231,7 @@ void getSegmentsOutputByRunning(
226231
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
227232
if (t == at::kLong) {
228233
// we add a cast operation to cast the type to Int64
229-
auto cast_node = createCastNode(seg_block, i, true);
234+
auto cast_node = createCastNode(seg_block, i, true, target_device);
230235
seg_block.g()->prependNode(cast_node);
231236
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
232237
}
@@ -237,7 +242,7 @@ void getSegmentsOutputByRunning(
237242
auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
238243
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
239244
if (t == at::kLong) {
240-
auto cast_node = createCastNode(seg_block, i, false);
245+
auto cast_node = createCastNode(seg_block, i, false, target_device);
241246
seg_block.g()->appendNode(cast_node);
242247
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
243248
}

cpp/src/compile_spec.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
111111
internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
112112
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
113113
internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
114+
internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
114115

115116
TORCHTRT_CHECK(
116117
!(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -132,11 +133,13 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
132133
case Device::DeviceType::kDLA:
133134
internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
134135
internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
136+
internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
135137
break;
136138
case Device::DeviceType::kGPU:
137139
default:
138140
internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
139141
internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
142+
internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
140143
}
141144

142145
switch (external.capability) {
@@ -155,6 +158,9 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
155158
internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
156159
internal.lower_info.target_device.gpu_id = external.device.gpu_id;
157160
internal.lower_info.target_device.dla_core = external.device.dla_core;
161+
internal.partitioning_info.target_device.gpu_id = external.device.gpu_id;
162+
internal.partitioning_info.target_device.dla_core = external.device.dla_core;
163+
158164
internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
159165
internal.convert_info.engine_settings.workspace_size = external.workspace_size;
160166
internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;

docs/_cpp_api/classtorch__tensorrt_1_1DataType.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class DataType &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Class DataType &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Class Device::DeviceType &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Class TensorFormat &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8Calibrator.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Template Class Int8Calibrator &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a18d295a837ac71add5578860b55e5502.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define STR &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define STR &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a282fd3c0b1c3a215148ae372070e1268.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCH_TENSORRT_PATCH_VERSION &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a31398a6d4d27e28817afb0f0139e909e.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCH_TENSORRT_MAJOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1a35703561b26b1a9d2738ad7d58b27827.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCH_TENSORRT_MINOR_VERSION &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1abd1465eb38256d3f22cc1426b23d516b.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCHTRT_API &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCHTRT_API &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1abe87b341f562fd1cf40b7672e4d759da.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define XSTR &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define XSTR &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1ad19939408f7be171a74a89928b36eb59.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCHTRT_HIDDEN &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCHTRT_HIDDEN &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/define_macros_8h_1adad592a7b1b7eed529cdf6acd584c883.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Define TORCH_TENSORRT_VERSION &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Define TORCH_TENSORRT_VERSION &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -215,7 +215,7 @@
215215

216216

217217
<div class="version">
218-
v1.4.0dev0+b7ceedf
218+
v1.4.0dev0+af39c65
219219
</div>
220220

221221

docs/_cpp_api/dir_cpp.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
<meta name="viewport" content="width=device-width, initial-scale=1.0">
1212

13-
<title>Directory cpp &mdash; Torch-TensorRT v1.4.0dev0+b7ceedf documentation</title>
13+
<title>Directory cpp &mdash; Torch-TensorRT v1.4.0dev0+af39c65 documentation</title>
1414

1515

1616

@@ -213,7 +213,7 @@
213213

214214

215215
<div class="version">
216-
v1.4.0dev0+b7ceedf
216+
v1.4.0dev0+af39c65
217217
</div>
218218

219219

0 commit comments

Comments
 (0)