Commit fe37a86

Browse files
peri044narendasan
andauthored
feat: Upgrade stack in release/1.3 branch (#1485)
Signed-off-by: Dheeraj Peri <[email protected]>
Co-authored-by: Naren Dasan <[email protected]>
1 parent 4611b1e commit fe37a86

File tree

18 files changed: +328 additions, -286 deletions


.circleci/config.yml

Lines changed: 180 additions & 174 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 3 additions & 3 deletions
@@ -113,10 +113,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 5.2.0
-- Libtorch 1.12.1 (built with CUDA 11.6)
-- CUDA 11.6
+- Libtorch 1.13.0 (built with CUDA 11.7)
+- CUDA 11.7
 - cuDNN 8.4.1
-- TensorRT 8.4.3.1
+- TensorRT 8.5.1.7
 
 ## Prebuilt Binaries and Wheel files
 

WORKSPACE

Lines changed: 5 additions & 5 deletions
@@ -41,7 +41,7 @@ local_repository(
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-11.6/",
+    path = "/usr/local/cuda-11.7/",
 )
 
 new_local_repository(
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "486106ddc5b5ad532f030f447940a571b924da821b9534d25c0cef5503cdfaea",
+    sha256 = "0a013dceedb252f4965b666a2ad772d962135597db5889bd5d43644697c17dbc",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-cxx11-abi-shared-with-deps-1.13.0.dev20220921%2Bcu116.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "b304ebf26effcbbefcec99134bcfb0127c499306343fbe2e2cd127213448a4a6",
+    sha256 = "cdbd43985ad9d5886793d5dc455d665cf3fd4b4617ef1094479678ff210ed0af",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-shared-with-deps-1.13.0.dev20220921%2Bcu116.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-shared-with-deps-1.13.0%2Bcu117.zip"],
 )
 
 # Download these tarballs manually from the NVIDIA website

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 14 additions & 0 deletions
@@ -102,6 +102,20 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   }
 
   auto w = Weights(ctx, args[1].unwrapToTensor());
+  // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
+  // Github issue: https://github.com/pytorch/TensorRT/issues/1445
+  bool is_kernel_size_one = true;
+  bool is_3d_kernel = w.kernel_shape.nbDims == 3;
+  for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
+    if (w.kernel_shape.d[i] != 1.0f) {
+      is_kernel_size_one = false;
+    }
+  }
+  if (is_kernel_size_one && is_3d_kernel) {
+    LOG_WARNING(
+        "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+        Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue.");
+  }
   auto dims = in->getDimensions();
   auto orig_dims = dims;
   LOG_DEBUG("Input dims: " << orig_dims);

core/conversion/converters/impl/einsum.cpp

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ namespace impl {
 namespace {
 
 auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
+    {"aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        // Extract equation and list of tensors
        auto equation = args[0].unwrapToString();

core/runtime/TRTEngine.cpp

Lines changed: 49 additions & 49 deletions
@@ -68,8 +68,8 @@ TRTEngine::TRTEngine(
     uint64_t inputs = 0;
     uint64_t outputs = 0;
 
-    for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-      std::string bind_name = cuda_engine->getBindingName(x);
+    for (int64_t trt_idx = 0; trt_idx < cuda_engine->getNbIOTensors(); trt_idx++) {
+      std::string bind_name = cuda_engine->getIOTensorName(trt_idx);
       LOG_DEBUG("Binding name: " << bind_name);
       auto delim = bind_name.find(".");
       if (delim == std::string::npos) {
@@ -80,46 +80,45 @@ TRTEngine::TRTEngine(
             << bind_name
             << "\nEnsure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions");
       }
-
       std::string idx_s = bind_name.substr(delim + 1);
-      uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
+      uint64_t pyt_idx = static_cast<uint64_t>(std::stoi(idx_s));
 
-      if (cuda_engine->bindingIsInput(x)) {
+      if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
         inputs++;
-        in_binding_map[x] = idx;
-        LOG_DEBUG("TRT Binding: " << x << ": PYT Input: " << idx);
+        in_binding_map[trt_idx] = pyt_idx;
+        LOG_DEBUG("TRT Binding index: " << trt_idx << "corresponds to PYT Input index: " << pyt_idx);
       } else {
         outputs++;
-        out_binding_map[x] = idx;
-        LOG_DEBUG("TRT Binding: " << x << ": PYT Output: " << idx);
+        out_binding_map[trt_idx] = pyt_idx;
+        LOG_DEBUG("TRT Binding index: " << trt_idx << "corresponds to PYT Output: " << pyt_idx);
       }
     }
 
     num_io = std::make_pair(inputs, outputs);
     in_binding_names.resize(inputs);
     out_binding_names.resize(outputs);
-
-    for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-      std::string bind_name = cuda_engine->getBindingName(x);
-      if (cuda_engine->bindingIsInput(x)) {
+    for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
+      std::string bind_name = cuda_engine->getIOTensorName(x);
+      if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
         in_binding_names[in_binding_map.at(x)] = bind_name;
       } else {
         out_binding_names[out_binding_map.at(x)] = bind_name;
       }
     }
   } else {
-    uint64_t inputs = _in_binding_names.size();
-    in_binding_names.resize(inputs);
-    for (size_t pyt_idx = 0; pyt_idx < inputs; pyt_idx++) {
+    uint64_t inputs_size = _in_binding_names.size();
+    in_binding_names.resize(inputs_size);
+    for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
       auto binding_name = _in_binding_names[pyt_idx];
       auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      TORCHTRT_CHECK((trt_idx >= 0), "Could not find a TensorRT engine binding for input named " << binding_name);
+      std::string engine_binded_name = cuda_engine->getIOTensorName(pyt_idx);
       TORCHTRT_CHECK(
-          cuda_engine->bindingIsInput(trt_idx),
+          (binding_name == engine_binded_name),
+          "Could not find a TensorRT engine binding for input named " << binding_name);
+      TORCHTRT_CHECK(
+          (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
           "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
-      LOG_DEBUG(
-          "Input binding name: " << binding_name << " (trt binding idx: " << trt_idx << ", "
-                                 << "pyt arg idx: " << pyt_idx << ")");
+      LOG_DEBUG("Input binding name: " << binding_name << "pyt arg idx: " << pyt_idx << ")");
       in_binding_map[trt_idx] = pyt_idx;
       in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
     }
@@ -129,30 +128,31 @@ TRTEngine::TRTEngine(
     for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
       auto binding_name = _out_binding_names[pyt_idx];
       auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      TORCHTRT_CHECK((trt_idx >= 0), "Could not find a TensorRT engine binding for output named " << binding_name);
+      std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx);
+      TORCHTRT_CHECK(
+          (binding_name == engine_binded_name),
+          "Could not find a TensorRT engine binding for output named " << binding_name);
       TORCHTRT_CHECK(
-          !cuda_engine->bindingIsInput(trt_idx),
+          !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
           "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
-      LOG_DEBUG(
-          "Output binding name: " << binding_name << " (trt binding idx: " << trt_idx << ", "
-                                  << "pyt return idx: " << pyt_idx << ")");
+      LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")");
       out_binding_map[trt_idx] = pyt_idx;
       out_binding_names[pyt_idx] = binding_name;
     }
-    num_io = std::make_pair(inputs, outputs);
+    num_io = std::make_pair(inputs_size, outputs);
   }
 
-#ifndef NDEBUG
-  this->enable_profiling();
-#endif
+#ifndef NDEBUG
+  this->enable_profiling();
+#endif
   LOG_DEBUG(*this);
 }
 
 TRTEngine::~TRTEngine() {
+  rt.reset();
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
-  rt.reset();
 }
 
 void TRTEngine::disable_profiling() {
@@ -164,7 +164,7 @@ void TRTEngine::disable_profiling() {
 }
 
 void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) {
-  auto inspector = cuda_engine->createEngineInspector();
+  auto inspector = make_trt(cuda_engine->createEngineInspector());
   std::ofstream f(path);
   f << std::string(inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON));
   f.close();
@@ -208,23 +208,23 @@ std::string TRTEngine::to_str() const {
   std::stringstream ss;
   ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
   ss << " Name: " << name << std::endl;
-  ss << " Bindings: {" << std::endl;
-  for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-    if (cuda_engine->bindingIsInput(x)) {
-      const uint64_t pyt_idx = in_binding_map.at(x);
-      ss << " (" << x << ": " << in_binding_names.at(pyt_idx) << ") Input: [" << std::endl;
-      ss << " pytorch arg idx: " << pyt_idx << std::endl;
-      ss << " shape: " << exec_ctx->getBindingDimensions(x) << std::endl;
-      ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(x)) << std::endl;
-      ss << " ]" << std::endl;
-    } else {
-      const uint64_t pyt_idx = out_binding_map.at(x);
-      ss << " (" << x << ": " << out_binding_names.at(pyt_idx) << ") Output: [" << std::endl;
-      ss << " pytorch return idx: " << pyt_idx << std::endl;
-      ss << " shape: " << exec_ctx->getBindingDimensions(x) << std::endl;
-      ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(x)) << std::endl;
-      ss << " ]" << std::endl;
-    }
+  ss << " Inputs: [" << std::endl;
+  for (uint64_t i = 0; i < num_io.first; i++) {
+    ss << " id: " << i << std::endl;
+    ss << " shape: " << exec_ctx->getTensorShape(std::string("input_" + str(i)).c_str()) << std::endl;
+    ss << " dtype: "
+       << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(std::string("input_" + str(i)).c_str()))
+       << std::endl;
+  }
+  ss << " ]" << std::endl;
+  ss << " Outputs: [" << std::endl;
+  for (uint64_t o = 0; o < num_io.second; o++) {
+    ss << " id: " << o << std::endl;
+    ss << " shape: " << exec_ctx->getTensorShape(std::string("output_" + str(o)).c_str()) << std::endl;
+    ss << " dtype: "
+       << util::TRTDataTypeToScalarType(
+              exec_ctx->getEngine().getTensorDataType(std::string("output_" + str(o)).c_str()))
+       << std::endl;
   }
   ss << " }" << std::endl;
   ss << " Device: " << device_info << std::endl;

core/runtime/execute_engine.cpp

Lines changed: 19 additions & 24 deletions
@@ -121,36 +121,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }
 
-  std::vector<void*> gpu_handles;
-  std::vector<at::Tensor> contig_inputs{};
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
       input_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }
-
-    contig_inputs.reserve(inputs.size());
-
     for (size_t i = 0; i < inputs.size(); i++) {
-      uint64_t pyt_idx = compiled_engine->in_binding_map[i];
+      std::string name = compiled_engine->in_binding_names[i];
       TORCHTRT_CHECK(
-          inputs[pyt_idx].is_cuda(),
-          "Expected input tensors to have device cuda, found device " << inputs[pyt_idx].device());
-      auto expected_type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(i));
+          inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
+      auto expected_type =
+          util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
       TORCHTRT_CHECK(
-          inputs[pyt_idx].dtype() == expected_type,
-          "Expected input tensors to have type " << expected_type << ", found type " << inputs[pyt_idx].dtype());
-      auto dims = core::util::toDimsPad(inputs[pyt_idx].sizes(), 1);
+          inputs[i].dtype() == expected_type,
+          "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
+      auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
       auto shape = core::util::toVec(dims);
-      contig_inputs.push_back(inputs[pyt_idx].view(shape).contiguous());
-      LOG_DEBUG("Input shape: " << dims);
-      compiled_engine->exec_ctx->setBindingDimensions(i, dims);
-      gpu_handles.push_back(contig_inputs.back().data_ptr());
+      LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
+      compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
+      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
     }
+
     TORCHTRT_CHECK(
-        compiled_engine->exec_ctx->allInputDimensionsSpecified(),
-        "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
+        compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
   }
 
   std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
@@ -163,26 +157,27 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
     for (size_t o = inputs.size(); o < (compiled_engine->num_io.first + compiled_engine->num_io.second); o++) {
       uint64_t pyt_idx = compiled_engine->out_binding_map[o];
-      auto out_shape = compiled_engine->exec_ctx->getBindingDimensions(o);
-      LOG_DEBUG("Output shape: " << out_shape);
+      std::string name = compiled_engine->out_binding_names[pyt_idx];
+      auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
+      LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
       auto dims = core::util::toVec(out_shape);
-      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(o));
+      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
       outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
-      gpu_handles.push_back(outputs[pyt_idx].data_ptr());
+      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr());
     }
   }
+
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
     if (compiled_engine->profile_execution) {
       enqueue_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
     }
-
    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
 
    // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
    std::unique_lock<std::mutex> lock(compiled_engine->mu);
-    compiled_engine->exec_ctx->enqueueV2(gpu_handles.data(), stream, nullptr);
+    compiled_engine->exec_ctx->enqueueV3(stream);
    if (compiled_engine->profile_execution) {
      LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
      dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
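
With enqueueV3 there is no positional gpu_handles array: shapes and device addresses are bound to the execution context per tensor name before launch. A minimal sketch of that flow against the same TensorRT 8.5 API, assuming one input and one output with caller-provided names and pre-allocated device buffers (illustrative, not code from this commit):

#include <NvInfer.h>
#include <cuda_runtime_api.h>

// Bind one input and one output by name, then launch on a CUDA stream.
bool run_engine_v3(
    nvinfer1::IExecutionContext& ctx,
    const char* in_name,
    void* in_dev,
    const nvinfer1::Dims& in_dims,
    const char* out_name,
    void* out_dev,
    cudaStream_t stream) {
  ctx.setInputShape(in_name, in_dims); // resolve (possibly dynamic) input dims by name
  ctx.setTensorAddress(in_name, in_dev); // device pointer for the input
  ctx.setTensorAddress(out_name, out_dev); // device pointer for the output
  if (!ctx.allInputShapesSpecified()) {
    return false; // an input shape was left unbound
  }
  return ctx.enqueueV3(stream); // no bindings array, unlike enqueueV2
}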

docker/WORKSPACE.docker

Lines changed: 4 additions & 4 deletions
@@ -51,17 +51,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "8d9e829ce9478db4f35bdb7943308cf02e8a2f58cf9bb10f742462c1d57bf287",
+    sha256 = "0a013dceedb252f4965b666a2ad772d962135597db5889bd5d43644697c17dbc",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad",
+    sha256 = "cdbd43985ad9d5886793d5dc455d665cf3fd4b4617ef1094479678ff210ed0af",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-shared-with-deps-1.13.0%2Bcu117.zip"],
 )
 
 ####################################################################################

py/requirements.txt

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 numpy
 pybind11==2.6.2
---extra-index-url https://download.pytorch.org/whl/nightly/cu116
-torch==1.13.0.dev20220921+cu116
-torchvision==0.14.0.dev20220921+cu116
+torch==1.13.0
+torchvision==0.14.0
 --extra-index-url https://pypi.ngc.nvidia.com
-nvidia-tensorrt==8.4.3.1
+tensorrt==8.5.1.7

py/torch_tensorrt/fx/converters/convolution.py

Lines changed: 15 additions & 0 deletions
@@ -2,6 +2,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import logging
 
 from ..converter_registry import tensorrt_converter
 
@@ -12,6 +13,8 @@
     to_numpy,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
     if mod.padding_mode != "zeros":
@@ -139,6 +142,18 @@ def conv3d(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
     assert len(args) == 0
     input_val = kwargs["input"]
+    # TODO: Remove this warning when https://github.com/pytorch/TensorRT/issues/1445 is fixed
+    kernel = to_numpy(submod.weight)
+    kernel_size_one = True
+    if len(kernel.shape) == 5:
+        for filter_size in kernel.shape[2:]:
+            if filter_size != 1:
+                kernel_size_one = False
+    if kernel_size_one:
+        logger.warn(
+            "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+            Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue."
+        )
 
     if not isinstance(input_val, trt.tensorrt.ITensor):
         raise RuntimeError(
