Upgrade stack in release/1.3 branch #1485


Merged
54 commits merged on Nov 30, 2022

Changes from all commits

Commits (54)
e2cf026  feat: Upgrade CUDA to 11.7 (peri044, Oct 25, 2022)
beea264  chore: Update circle CI config.yaml (peri044, Nov 1, 2022)
e3db145  chore: Fix install-cuda part (peri044, Nov 1, 2022)
7464b0b  chore: fix circle ci config (peri044, Nov 1, 2022)
ff3c488  chore: modify cudnn version (peri044, Nov 1, 2022)
fc8c186  chore: minor fix (peri044, Nov 1, 2022)
04700aa  chore: minor fix (peri044, Nov 1, 2022)
97e81d5  chore: minor fix (peri044, Nov 1, 2022)
91ac151  chore: minor fix (peri044, Nov 1, 2022)
5e99485  chore: uninstall cuda-11.4 (peri044, Nov 1, 2022)
a2eee26  chore: minor fix (peri044, Nov 1, 2022)
031eeef  chore: minor fix (peri044, Nov 1, 2022)
551f35b  chore: minor fix (peri044, Nov 1, 2022)
d65d489  chore: minor fix (peri044, Nov 1, 2022)
9785845  chore: minor fix (peri044, Nov 1, 2022)
7df7191  chore: minor fix (peri044, Nov 1, 2022)
d74c8d2  chore: minor fix (peri044, Nov 1, 2022)
4cebdff  chore: Minor fix (peri044, Nov 1, 2022)
5a9b0b3  chore: minor fixes (peri044, Nov 1, 2022)
02d9538  chore: fix cudnn conflicts (peri044, Nov 1, 2022)
1ce8f29  chore: minor fix (peri044, Nov 2, 2022)
42c1f94  feat: upgrade to TRT 8.5 GA (peri044, Nov 2, 2022)
d9cf3f7  chore: minor fix for trt 8.5 (peri044, Nov 2, 2022)
b213e58  chore: Upgrade TRT and lint files (peri044, Nov 2, 2022)
7705585  chore: Upgrade to 1.13 pyt release (peri044, Nov 2, 2022)
d1ae7e4  chore: minor fix (peri044, Nov 3, 2022)
6318e1d  chore: Update circle ci to add more fx tests (peri044, Nov 15, 2022)
06a878a  chore: update config (peri044, Nov 15, 2022)
a6ccefa  chore: minor fix (peri044, Nov 15, 2022)
2506b58  chore: minor fix (peri044, Nov 15, 2022)
d787aaf  chore: minor fix (peri044, Nov 15, 2022)
b035fec  chore: minor fix (peri044, Nov 15, 2022)
58f9b7d  chore: minor fix (peri044, Nov 15, 2022)
0349ba3  chore: minor fix (peri044, Nov 15, 2022)
7c7fbf2  chore: minor fix (peri044, Nov 15, 2022)
ed753ac  chore: Fix threshold failure for stack converter (peri044, Nov 18, 2022)
621b2fb  chore: linter fixes (peri044, Nov 21, 2022)
c49f991  chore: Add FX test warning for conv3d (peri044, Nov 22, 2022)
23e1131  chore: resolve merge conflicts (peri044, Nov 22, 2022)
90b13a9  chore: remove debug statements from execute_engine (narendasan, Nov 22, 2022)
f11300b  chore: Fix destructor calls of cuda engine (peri044, Nov 22, 2022)
ab9c236  chore: Resolve TRT engine merge conflicts (peri044, Nov 23, 2022)
55f1842  chore: Linter fixes (peri044, Nov 23, 2022)
83f5c17  chore: minor fixes (peri044, Nov 23, 2022)
c14d41b  chore: Use in and out binding maps to preserve order (peri044, Nov 23, 2022)
34f5742  chore: resolve merge conflicts (peri044, Nov 28, 2022)
dbc9557  chore: add conv3d patch (peri044, Nov 18, 2022)
acad32a  chore: Add changes from conv3d patch (peri044, Nov 22, 2022)
7a0ee23  chore: disable test_permute_linear (peri044, Nov 29, 2022)
fd7e535  chore: Remove dev versions of pytorch (peri044, Nov 29, 2022)
c00292b  chore: Linter fixes (peri044, Nov 29, 2022)
7f10168  chore: Address review comments (peri044, Nov 29, 2022)
df452e5  chore: add unittest dep (peri044, Nov 29, 2022)
d4f4717  chore: Linter fixes (peri044, Nov 29, 2022)

354 changes: 180 additions & 174 deletions .circleci/config.yml

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions README.md
@@ -113,10 +113,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

 - Bazel 5.2.0
-- Libtorch 1.12.1 (built with CUDA 11.6)
-- CUDA 11.6
+- Libtorch 1.13.0 (built with CUDA 11.7)
+- CUDA 11.7
 - cuDNN 8.4.1
-- TensorRT 8.4.3.1
+- TensorRT 8.5.1.7

 ## Prebuilt Binaries and Wheel files
10 changes: 5 additions & 5 deletions WORKSPACE
@@ -41,7 +41,7 @@ local_repository(
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-11.6/",
+    path = "/usr/local/cuda-11.7/",
 )

 new_local_repository(
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "486106ddc5b5ad532f030f447940a571b924da821b9534d25c0cef5503cdfaea",
+    sha256 = "0a013dceedb252f4965b666a2ad772d962135597db5889bd5d43644697c17dbc",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-cxx11-abi-shared-with-deps-1.13.0.dev20220921%2Bcu116.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "b304ebf26effcbbefcec99134bcfb0127c499306343fbe2e2cd127213448a4a6",
+    sha256 = "cdbd43985ad9d5886793d5dc455d665cf3fd4b4617ef1094479678ff210ed0af",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-shared-with-deps-1.13.0.dev20220921%2Bcu116.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-shared-with-deps-1.13.0%2Bcu117.zip"],
 )

 # Download these tarballs manually from the NVIDIA website
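The WORKSPACE entries above pin each libtorch archive to a sha256 digest, so a corrupted or substituted download fails the build instead of silently linking the wrong libtorch. A minimal sketch of the same check outside Bazel; the local file path is an assumption for illustration:

```python
# Verify a downloaded libtorch archive against the sha256 pinned in WORKSPACE.
# The local filename below is hypothetical; use wherever the zip was saved.
import hashlib

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

expected = "0a013dceedb252f4965b666a2ad772d962135597db5889bd5d43644697c17dbc"
assert sha256_of("libtorch-cxx11-abi-shared-with-deps-1.13.0+cu117.zip") == expected
```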
14 changes: 14 additions & 0 deletions core/conversion/converters/impl/conv_deconv.cpp
@@ -102,6 +102,20 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   }

   auto w = Weights(ctx, args[1].unwrapToTensor());
+  // TODO: Remove this when conv3d with kernel size=1 bug is fixed.
+  // Github issue: https://github.com/pytorch/TensorRT/issues/1445
+  bool is_kernel_size_one = true;
+  bool is_3d_kernel = w.kernel_shape.nbDims == 3;
+  for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
+    if (w.kernel_shape.d[i] != 1.0f) {
+      is_kernel_size_one = false;
+    }
+  }
+  if (is_kernel_size_one && is_3d_kernel) {
+    LOG_WARNING(
+        "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+        Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue.");
+  }
   auto dims = in->getDimensions();
   auto orig_dims = dims;
   LOG_DEBUG("Input dims: " << orig_dims);
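For context, the configuration this warning targets is a 3D convolution whose kernel is 1 in every spatial dimension (issue #1445). A minimal sketch of a module that would trip the warning when compiled through the TorchScript path; the channel counts and input shape are illustrative, not taken from the PR:

```python
# Hypothetical repro for the conv3d kernel_size=1 case flagged above.
import torch
import torch_tensorrt

class PointwiseConv3d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # kernel_size=1 in all three spatial dims is the flagged configuration
        self.conv = torch.nn.Conv3d(16, 32, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

model = PointwiseConv3d().eval().cuda()
# Compilation should log the warning above; per issue #1445 the TensorRT
# tactic optimizer can then fail in some cases.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 16, 8, 32, 32))],
    enabled_precisions={torch.float},
)
```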
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/einsum.cpp
@@ -12,7 +12,7 @@ namespace impl {
 namespace {

 auto einsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::einsum(str equation, Tensor[] tensors) -> (Tensor)",
+    {"aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        // Extract equation and list of tensors
        auto equation = args[0].unwrapToString();
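This schema change tracks PyTorch 1.13, where torch.einsum gained an optional opt_einsum contraction-path keyword; without the updated signature, scripted graphs emitting the new aten::einsum overload would no longer match this converter. A short sketch (the function itself is illustrative):

```python
# Under PyTorch 1.13 this scripts to
# aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None),
# the schema the converter above now registers.
import torch

@torch.jit.script
def batched_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.einsum("bij,bjk->bik", a, b)
```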
98 changes: 49 additions & 49 deletions core/runtime/TRTEngine.cpp
@@ -68,8 +68,8 @@ TRTEngine::TRTEngine(
     uint64_t inputs = 0;
     uint64_t outputs = 0;

-    for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-      std::string bind_name = cuda_engine->getBindingName(x);
+    for (int64_t trt_idx = 0; trt_idx < cuda_engine->getNbIOTensors(); trt_idx++) {
+      std::string bind_name = cuda_engine->getIOTensorName(trt_idx);
       LOG_DEBUG("Binding name: " << bind_name);
       auto delim = bind_name.find(".");
       if (delim == std::string::npos) {
@@ -80,46 +80,45 @@ TRTEngine::TRTEngine(
             << bind_name
             << "\nEnsure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions");
       }

       std::string idx_s = bind_name.substr(delim + 1);
-      uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
+      uint64_t pyt_idx = static_cast<uint64_t>(std::stoi(idx_s));

-      if (cuda_engine->bindingIsInput(x)) {
+      if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
         inputs++;
-        in_binding_map[x] = idx;
-        LOG_DEBUG("TRT Binding: " << x << ": PYT Input: " << idx);
+        in_binding_map[trt_idx] = pyt_idx;
+        LOG_DEBUG("TRT Binding index: " << trt_idx << "corresponds to PYT Input index: " << pyt_idx);
       } else {
         outputs++;
-        out_binding_map[x] = idx;
-        LOG_DEBUG("TRT Binding: " << x << ": PYT Output: " << idx);
+        out_binding_map[trt_idx] = pyt_idx;
+        LOG_DEBUG("TRT Binding index: " << trt_idx << "corresponds to PYT Output: " << pyt_idx);
       }
     }

     num_io = std::make_pair(inputs, outputs);
     in_binding_names.resize(inputs);
     out_binding_names.resize(outputs);

-    for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-      std::string bind_name = cuda_engine->getBindingName(x);
-      if (cuda_engine->bindingIsInput(x)) {
+    for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
+      std::string bind_name = cuda_engine->getIOTensorName(x);
+      if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
         in_binding_names[in_binding_map.at(x)] = bind_name;
       } else {
         out_binding_names[out_binding_map.at(x)] = bind_name;
       }
     }
   } else {
-    uint64_t inputs = _in_binding_names.size();
-    in_binding_names.resize(inputs);
-    for (size_t pyt_idx = 0; pyt_idx < inputs; pyt_idx++) {
+    uint64_t inputs_size = _in_binding_names.size();
+    in_binding_names.resize(inputs_size);
+    for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
       auto binding_name = _in_binding_names[pyt_idx];
       auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      TORCHTRT_CHECK((trt_idx >= 0), "Could not find a TensorRT engine binding for input named " << binding_name);
+      std::string engine_binded_name = cuda_engine->getIOTensorName(pyt_idx);
       TORCHTRT_CHECK(
-          cuda_engine->bindingIsInput(trt_idx),
+          (binding_name == engine_binded_name),
+          "Could not find a TensorRT engine binding for input named " << binding_name);
+      TORCHTRT_CHECK(
+          (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
          "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
-      LOG_DEBUG(
-          "Input binding name: " << binding_name << " (trt binding idx: " << trt_idx << ", "
-                                 << "pyt arg idx: " << pyt_idx << ")");
+      LOG_DEBUG("Input binding name: " << binding_name << "pyt arg idx: " << pyt_idx << ")");
       in_binding_map[trt_idx] = pyt_idx;
       in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
     }
@@ -129,30 +128,31 @@ TRTEngine::TRTEngine(
     for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
       auto binding_name = _out_binding_names[pyt_idx];
       auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      TORCHTRT_CHECK((trt_idx >= 0), "Could not find a TensorRT engine binding for output named " << binding_name);
+      std::string engine_binded_name = cuda_engine->getIOTensorName(inputs_size + pyt_idx);
       TORCHTRT_CHECK(
-          !cuda_engine->bindingIsInput(trt_idx),
+          (binding_name == engine_binded_name),
+          "Could not find a TensorRT engine binding for output named " << binding_name);
+      TORCHTRT_CHECK(
+          !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
          "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
-      LOG_DEBUG(
-          "Output binding name: " << binding_name << " (trt binding idx: " << trt_idx << ", "
-                                  << "pyt return idx: " << pyt_idx << ")");
+      LOG_DEBUG("Output binding name: " << binding_name << "pyt return idx: " << inputs_size + pyt_idx << ")");
       out_binding_map[trt_idx] = pyt_idx;
       out_binding_names[pyt_idx] = binding_name;
     }
-    num_io = std::make_pair(inputs, outputs);
+    num_io = std::make_pair(inputs_size, outputs);
   }

 #ifndef NDEBUG
   this->enable_profiling();
 #endif
   LOG_DEBUG(*this);
 }

 TRTEngine::~TRTEngine() {
-  rt.reset();
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
+  rt.reset();
 }

 void TRTEngine::disable_profiling() {
@@ -164,7 +164,7 @@ void TRTEngine::disable_profiling() {
 }

 void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) {
-  auto inspector = cuda_engine->createEngineInspector();
+  auto inspector = make_trt(cuda_engine->createEngineInspector());

   [Inline review comment (Collaborator): Just a note for later, should we be keeping this object around?]

   std::ofstream f(path);
   f << std::string(inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON));
   f.close();
@@ -208,23 +208,23 @@ std::string TRTEngine::to_str() const {
   std::stringstream ss;
   ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
   ss << " Name: " << name << std::endl;
-  ss << " Bindings: {" << std::endl;
-  for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
-    if (cuda_engine->bindingIsInput(x)) {
-      const uint64_t pyt_idx = in_binding_map.at(x);
-      ss << " (" << x << ": " << in_binding_names.at(pyt_idx) << ") Input: [" << std::endl;
-      ss << " pytorch arg idx: " << pyt_idx << std::endl;
-      ss << " shape: " << exec_ctx->getBindingDimensions(x) << std::endl;
-      ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(x)) << std::endl;
-      ss << " ]" << std::endl;
-    } else {
-      const uint64_t pyt_idx = out_binding_map.at(x);
-      ss << " (" << x << ": " << out_binding_names.at(pyt_idx) << ") Output: [" << std::endl;
-      ss << " pytorch return idx: " << pyt_idx << std::endl;
-      ss << " shape: " << exec_ctx->getBindingDimensions(x) << std::endl;
-      ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(x)) << std::endl;
-      ss << " ]" << std::endl;
-    }
+  ss << " Inputs: [" << std::endl;
+  for (uint64_t i = 0; i < num_io.first; i++) {
+    ss << " id: " << i << std::endl;
+    ss << " shape: " << exec_ctx->getTensorShape(std::string("input_" + str(i)).c_str()) << std::endl;
+    ss << " dtype: "
+       << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(std::string("input_" + str(i)).c_str()))
+       << std::endl;
+  }
+  ss << " ]" << std::endl;
+  ss << " Outputs: [" << std::endl;
+  for (uint64_t o = 0; o < num_io.second; o++) {
+    ss << " id: " << o << std::endl;
+    ss << " shape: " << exec_ctx->getTensorShape(std::string("output_" + str(o)).c_str()) << std::endl;
+    ss << " dtype: "
+       << util::TRTDataTypeToScalarType(
+              exec_ctx->getEngine().getTensorDataType(std::string("output_" + str(o)).c_str()))
+       << std::endl;
+  }
   ss << " }" << std::endl;
   ss << " Device: " << device_info << std::endl;
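The recurring pattern in this file is the TensorRT 8.5 move from integer binding indices to named I/O tensors. A sketch of the same enumeration in the TensorRT Python API, which mirrors the C++ calls above; the function and its use of a deserialized engine are assumptions for illustration, not code from the PR:

```python
import tensorrt as trt

def enumerate_io(engine: trt.ICudaEngine):
    """Split I/O tensor names using the TRT 8.5 named-tensor API."""
    inputs, outputs = [], []
    for i in range(engine.num_io_tensors):  # replaces engine.num_bindings
        name = engine.get_tensor_name(i)    # replaces engine.get_binding_name(i)
        # get_tensor_mode(name) replaces engine.binding_is_input(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            inputs.append(name)
        else:
            outputs.append(name)
    return inputs, outputs
```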
43 changes: 19 additions & 24 deletions core/runtime/execute_engine.cpp
@@ -121,36 +121,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }

-  std::vector<void*> gpu_handles;
   std::vector<at::Tensor> contig_inputs{};
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
       input_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }

     contig_inputs.reserve(inputs.size());

     for (size_t i = 0; i < inputs.size(); i++) {
-      uint64_t pyt_idx = compiled_engine->in_binding_map[i];
+      std::string name = compiled_engine->in_binding_names[i];
       TORCHTRT_CHECK(
-          inputs[pyt_idx].is_cuda(),
-          "Expected input tensors to have device cuda, found device " << inputs[pyt_idx].device());
-      auto expected_type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(i));
+          inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
+      auto expected_type =
+          util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
       TORCHTRT_CHECK(
-          inputs[pyt_idx].dtype() == expected_type,
-          "Expected input tensors to have type " << expected_type << ", found type " << inputs[pyt_idx].dtype());
-      auto dims = core::util::toDimsPad(inputs[pyt_idx].sizes(), 1);
+          inputs[i].dtype() == expected_type,
+          "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
+      auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
       auto shape = core::util::toVec(dims);
-      contig_inputs.push_back(inputs[pyt_idx].view(shape).contiguous());
-      LOG_DEBUG("Input shape: " << dims);
-      compiled_engine->exec_ctx->setBindingDimensions(i, dims);
-      gpu_handles.push_back(contig_inputs.back().data_ptr());
+      LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
+      compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
+      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
     }

     TORCHTRT_CHECK(
-        compiled_engine->exec_ctx->allInputDimensionsSpecified(),
-        "Not enough inputs provided (torch.ops.tensorrt.execute_engine)");
+        compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
   }

   std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
@@ -163,26 +157,27 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

     for (size_t o = inputs.size(); o < (compiled_engine->num_io.first + compiled_engine->num_io.second); o++) {

   [Inline review comment (Collaborator): Same thing here, use output_binding_names to pull out the TensorRT binding index / address]

       uint64_t pyt_idx = compiled_engine->out_binding_map[o];
-      auto out_shape = compiled_engine->exec_ctx->getBindingDimensions(o);
-      LOG_DEBUG("Output shape: " << out_shape);
+      std::string name = compiled_engine->out_binding_names[pyt_idx];
+      auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
+      LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
       auto dims = core::util::toVec(out_shape);
-      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(o));
+      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
       outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
-      gpu_handles.push_back(outputs[pyt_idx].data_ptr());
+      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr());
     }
   }

   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
     if (compiled_engine->profile_execution) {
       enqueue_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
     }

     c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());

     // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
     std::unique_lock<std::mutex> lock(compiled_engine->mu);
-    compiled_engine->exec_ctx->enqueueV2(gpu_handles.data(), stream, nullptr);
+    compiled_engine->exec_ctx->enqueueV3(stream);
     if (compiled_engine->profile_execution) {
       LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
       dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
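Execution follows the same migration: instead of packing device pointers into a bindings array for enqueueV2, each address is registered by tensor name and enqueueV3 takes only the stream. A hedged sketch of that flow in the TensorRT Python API; the context, stream handle, and pointer maps are assumed inputs, not code from the PR:

```python
import tensorrt as trt

def run_v3(context: trt.IExecutionContext, stream_handle: int, inputs, outputs):
    # inputs: dict of tensor name -> (device_ptr, shape); outputs: name -> device_ptr
    for name, (ptr, shape) in inputs.items():
        context.set_input_shape(name, shape)   # replaces set_binding_shape(idx, shape)
        context.set_tensor_address(name, ptr)  # replaces a slot in the bindings array
    for name, ptr in outputs.items():
        context.set_tensor_address(name, ptr)
    # replaces execute_async_v2(bindings, stream_handle)
    return context.execute_async_v3(stream_handle)
```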
8 changes: 4 additions & 4 deletions docker/WORKSPACE.docker
@@ -51,17 +51,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "8d9e829ce9478db4f35bdb7943308cf02e8a2f58cf9bb10f742462c1d57bf287",
+    sha256 = "0a013dceedb252f4965b666a2ad772d962135597db5889bd5d43644697c17dbc",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad",
+    sha256 = "cdbd43985ad9d5886793d5dc455d665cf3fd4b4617ef1094479678ff210ed0af",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu117/libtorch-shared-with-deps-1.13.0%2Bcu117.zip"],
 )

 ####################################################################################
7 changes: 3 additions & 4 deletions py/requirements.txt
@@ -1,7 +1,6 @@
 numpy
 pybind11==2.6.2
---extra-index-url https://download.pytorch.org/whl/nightly/cu116
-torch==1.13.0.dev20220921+cu116
-torchvision==0.14.0.dev20220921+cu116
+torch==1.13.0
+torchvision==0.14.0
 --extra-index-url https://pypi.ngc.nvidia.com
-nvidia-tensorrt==8.4.3.1
+tensorrt==8.5.1.7
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/converters/convolution.py
@@ -2,6 +2,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import logging

 from ..converter_registry import tensorrt_converter

@@ -12,6 +13,8 @@
     to_numpy,
 )

+logger = logging.getLogger(__name__)
+

 def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
     if mod.padding_mode != "zeros":
@@ -139,6 +142,18 @@ def conv3d(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
     assert len(args) == 0
     input_val = kwargs["input"]
+    # TODO: Remove this warning when https://github.com/pytorch/TensorRT/issues/1445 is fixed
+    kernel = to_numpy(submod.weight)
+    kernel_size_one = True
+    if len(kernel.shape) == 5:
+        for filter_size in kernel.shape[2:]:
+            if filter_size != 1:
+                kernel_size_one = False
+    if kernel_size_one:
+        logger.warn(
+            "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
+            Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue."
+        )

     if not isinstance(input_val, trt.tensorrt.ITensor):
         raise RuntimeError(