
Commit 11af59a

Merge pull request #971 from NVIDIA/torch_tensorrt_1.1.0
Torch-TensorRT 1.1.0
2 parents: c395c21 + 365dabe

File tree: 10 files changed (+51 −57 lines)

.bazelversion

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-4.2.1
+5.1.1

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 5 additions & 4 deletions
@@ -1,15 +1,15 @@

 #include <torch/csrc/jit/runtime/operator.h>
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/ir/alias_analysis.h"
 #include "torch/csrc/jit/jit_log.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
 #include "torch/csrc/jit/passes/peephole.h"
-#include "torch/csrc/jit/runtime/graph_executor.h"
-#include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
-#include "core/util/prelude.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"

 namespace torch_tensorrt {
 namespace core {

@@ -34,7 +34,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
       continue;
     } else {
       torch::jit::WithInsertPoint guard(*it);
-      std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();;
+      std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
+      ;
       torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
       new_output->setType(it->output()->type());
       it->output()->replaceAllUsesWith(new_output);
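
Note: this hunk is a clang-format reflow; the stray second `;` after graph() becomes an empty statement on its own line rather than being deleted. For context, the pass pulls the TorchScript decomposition graph for `linear` and splices it in at the call site with `insertGraph`, which copies the callee's body into the owning graph at the current insert point and returns the new outputs. Below is a minimal standalone sketch of that splice; the `aten::matmul` callee and the function name are illustrative assumptions, not the real `linear` decomposition.

// A minimal sketch (assumed names, illustrative callee) of splicing one JIT
// graph into another with torch::jit::insertGraph, as the pass does above.
#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

std::shared_ptr<torch::jit::Graph> splice_example() {
  // Callee standing in for the decomposition graph obtained via
  // toGraphFunction(...).graph(); the real one decomposes aten::linear.
  auto callee = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor, %w : Tensor):
      %out : Tensor = aten::matmul(%x, %w)
      return (%out))IR",
      callee.get());

  auto caller = std::make_shared<torch::jit::Graph>();
  auto* x = caller->addInput("x");
  auto* w = caller->addInput("w");

  // Insert the callee's body before the caller's return node; insertGraph
  // returns the Values corresponding to the callee's outputs.
  torch::jit::WithInsertPoint guard(caller->return_node());
  auto outs = torch::jit::insertGraph(*caller, *callee, {x, w});
  caller->registerOutput(outs.at(0));
  return caller;
}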

core/lowering/passes/reduce_gelu.cpp

Lines changed: 4 additions & 3 deletions
@@ -12,8 +12,8 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
       %out : Tensor = aten::gelu(%x)
       return (%out))IR";

-  // This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions use
-  // an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
+  // This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions
+  // use an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
   std::string gelu_approximate_pattern = R"IR(
     graph(%x : Tensor, %approx):
       %out : Tensor = aten::gelu(%x, %approx)

@@ -64,7 +64,8 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
   map_gelu_to_pointwise_ops.runOnGraph(graph);

   torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
-  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
+  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(
+      gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
   map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);

   LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
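
For context: the approximate GELU being reduced here is the tanh variant, roughly 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), and the rewrite registered above is string-to-string: `SubgraphRewriter` pattern-matches one IR snippet and substitutes another. A minimal sketch of that flow follows, with a made-up pattern pair (relu rewritten as clamp_min) standing in for the GELU patterns in this file.

// Minimal SubgraphRewriter sketch; the pattern pair is illustrative and is
// not the gelu_approximate_pattern registered by ReduceGelu above.
#include <string>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

void rewrite_relu_as_clamp(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string pattern = R"IR(
    graph(%x):
      %out = aten::relu(%x)
      return (%out))IR";
  std::string replacement = R"IR(
    graph(%x):
      %zero : int = prim::Constant[value=0]()
      %out = aten::clamp_min(%x, %zero)
      return (%out))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph); // rewrites every match in place
}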

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 11 deletions
@@ -231,11 +231,11 @@ std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
   return usage_counts;
 }

-std::unordered_map<size_t, std::list<SegmentedBlock>::iterator>
-getIdxtoIterMap(std::list<SegmentedBlock> &segmented_blocks_list) {
+std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> getIdxtoIterMap(
+    std::list<SegmentedBlock>& segmented_blocks_list) {
   std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
   auto iter = segmented_blocks_list.begin();
-  for (int i = 0; i < segmented_blocks_list.size(); ++i, ++iter) {
+  for (uint64_t i = 0; i < segmented_blocks_list.size(); ++i, ++iter) {
     idx_to_iter[i] = iter;
   }
   return idx_to_iter;

@@ -283,22 +283,24 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
 }

 void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {
-  // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which produces/contains it.
-  auto usage_counts = getInputUsageCounts(
-      segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); });
+  // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which
+  // produces/contains it.
+  auto usage_counts =
+      getInputUsageCounts(segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); });

   // Get idx of the segblock to its iterator mapping
   std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend());
   auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);

   std::unordered_set<int> updated_segments;
   // we need to re-segment TensorRT segments whose inputs are TensorLists
-  for (auto &use : usage_counts) {
+  for (auto& use : usage_counts) {
     auto use_info = use.second;
     // For a particular tensorlist input, traverse through all ids of segmented blocks whose target is TensorRT
     for (auto i : use_info.tensorrt_use_id) {
       if (!updated_segments.count(i)) {
-        // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this tensorlist input}
+        // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this
+        // tensorlist input}
         std::unordered_map<torch::jit::Value*, SegmentedBlock> tensorlistinput_to_segblock;
         for (auto input : segmented_blocks[i].raw_inputs()) {
           if (isTensorList(input)) {

@@ -308,18 +310,20 @@ void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {

         // For each tensorlist input in tensorlistinput_to_segblock, get the node which actually uses this input.
         // Once we retrieve the node, we remove it from the current TensorRT segmented_blocks[i]. This node should be
-        // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first place.
+        // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first
+        // place.
         auto seg_blocks = segmentBlocksWithTensorListInputs(segmented_blocks[i], tensorlistinput_to_segblock);
         auto append_blocks = seg_blocks.first;
         auto trt_block = seg_blocks.second;
-        // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that uses tensorlist input removed.
+        // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that
+        // uses tensorlist input removed.
         auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
         if (trt_block.raw_nodes().size() > 0) {
           segmented_blocks_list.insert(next_iter, trt_block);
         }

         // append blocks' nodes to the producer seg_block
-        for (auto append_block: append_blocks) {
+        for (auto append_block : append_blocks) {
           auto input = append_block.first; // corresponds to the tensorlist input
           auto block = append_block.second;
           // append nodes to segmented_blocks_list
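
Most of this file's hunks are clang-format reflows; the substantive one changes the loop index in getIdxtoIterMap from `int` to `uint64_t`, since `std::list::size()` returns an unsigned `size_t` and the signed comparison trips `-Wsign-compare`. A minimal sketch of the pattern (illustrative code, not from the project):

// Illustrative sketch of the signed/unsigned loop-index fix; std::list has
// no operator[], which is also why the pass pairs an index with an iterator.
#include <cstdint>
#include <list>

void walk_blocks(std::list<int>& blocks) {
  // Before: for (int i = 0; i < blocks.size(); ++i)  // warns: -Wsign-compare
  auto iter = blocks.begin();
  for (uint64_t i = 0; i < blocks.size(); ++i, ++iter) {
    // i and iter advance together, mirroring getIdxtoIterMap above
  }
}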

core/partitioning/shape_analysis.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
-#include <ATen/ATen.h>
 #include "core/partitioning/shape_analysis.h"
+#include <ATen/ATen.h>
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ RUN rm -rf /opt/pytorch/torch_tensorrt /usr/bin/bazel

 ARG ARCH="x86_64"
 ARG TARGETARCH="amd64"
-ARG BAZEL_VERSION=4.2.1
+ARG BAZEL_VERSION=5.1.1

 RUN [[ "$TARGETARCH" == "amd64" ]] && ARCH="x86_64" || ARCH="${TARGETARCH}" \
     && wget -q https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-linux-${ARCH} -O /usr/bin/bazel \

docker/Dockerfile.docs

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@ FROM nvcr.io/nvidia/tensorrt:22.01-py3
 RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
 RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list

-RUN apt-get update && apt-get install -y bazel-4.2.1 clang-format-9 libjpeg9 libjpeg9-dev
-RUN ln -s /usr/bin/bazel-4.2.1 /usr/bin/bazel
+RUN apt-get update && apt-get install -y bazel-5.1.1 clang-format-9 libjpeg9 libjpeg9-dev
+RUN ln -s /usr/bin/bazel-5.1.1 /usr/bin/bazel
 RUN ln -s $(which clang-format-9) /usr/bin/clang-format

 # Workaround for bazel expecting both static and shared versions, we only use shared libraries inside container

docker/WORKSPACE.docs

Lines changed: 20 additions & 32 deletions
@@ -3,33 +3,32 @@ workspace(name = "Torch-TensorRT")
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

-git_repository(
+http_archive(
     name = "rules_python",
-    remote = "https://github.com/bazelbuild/rules_python.git",
-    commit = "4fcc24fd8a850bdab2ef2e078b1de337eea751a6",
-    shallow_since = "1589292086 -0400"
+    sha256 = "778197e26c5fbeb07ac2a2c5ae405b30f6cb7ad1f5510ea6fdac03bded96cc6f",
+    url = "https://github.com/bazelbuild/rules_python/releases/download/0.2.0/rules_python-0.2.0.tar.gz",
 )

-load("@rules_python//python:repositories.bzl", "py_repositories")
-py_repositories()
-
-load("@rules_python//python:pip.bzl", "pip_repositories", "pip3_import")
-pip_repositories()
+load("@rules_python//python:pip.bzl", "pip_install")

 http_archive(
     name = "rules_pkg",
-    url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.4/rules_pkg-0.2.4.tar.gz",
-    sha256 = "4ba8f4ab0ff85f2484287ab06c0d871dcb31cc54d439457d28fd4ae14b18450a",
+    sha256 = "038f1caa773a7e35b3663865ffb003169c6a71dc995e39bf4815792f385d837d",
+    urls = [
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
+        "https://github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
+    ],
 )

 load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
 rules_pkg_dependencies()

 git_repository(
     name = "googletest",
-    remote = "https://github.com/google/googletest",
     commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
-    shallow_since = "1570114335 -0400"
+    remote = "https://github.com/google/googletest",
+    shallow_since = "1570114335 -0400",
 )

 # CUDA should be installed on the system locally

@@ -52,17 +51,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "190e963e739d5f7c2dcf94b3994de8fcd335706a4ebb333812ea7d8c841beb06",
+    sha256 = "8d9e829ce9478db4f35bdb7943308cf02e8a2f58cf9bb10f742462c1d57bf287",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.10.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.11.0%2Bcu113.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "0996a6a4ea8bbc1137b4fb0476eeca25b5efd8ed38955218dec1b73929090053",
+    sha256 = "90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.10.0%2Bcu113.zip"],
+    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip"],
 )

 ####################################################################################

@@ -84,18 +83,7 @@ new_local_repository(
 #########################################################################
 # Testing Dependencies (optional - comment out on aarch64)
 #########################################################################
-pip3_import(
-    name = "torch_tensorrt_py_deps",
-    requirements = "//py:requirements.txt"
-)
-
-load("@torch_tensorrt_py_deps//:requirements.bzl", "pip_install")
-pip_install()
-
-pip3_import(
-    name = "py_test_deps",
-    requirements = "//tests/py:requirements.txt"
-)
-
-load("@py_test_deps//:requirements.bzl", "pip_install")
-pip_install()
+pip_install(
+    name = "pylinter_deps",
+    requirements = "//tools/linter:requirements.txt",
+)

docker/setup_nox.sh

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ set -e
 post=${1:-""}

 # fetch bazel executable
-BAZEL_VERSION=4.2.1
+BAZEL_VERSION=5.1.1
 ARCH=$(uname -m)
 if [[ "$ARCH" == "aarch64" ]]; then ARCH="arm64"; fi
 wget -q https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-linux-${ARCH} -O /usr/bin/bazel

tests/util/util.h

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #pragma once

+#include <ATen/ATen.h>
 #include <string>
 #include <vector>
-#include <ATen/ATen.h>
 #include "ATen/Tensor.h"
 #include "core/ir/ir.h"
 #include "core/util/prelude.h"
