Skip to content

Commit 3f54833

Browse files
authored
Merge pull request #124 from andi4191/anuragd/aarch64-jetpack-support
(//bazel): Native compilation support for NVIDIA Jetson AGX platform
2 parents e1fa232 + 7986a8a commit 3f54833

File tree

9 files changed

+225
-56
lines changed

9 files changed

+225
-56
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
6363
| Platform | Support |
6464
| -------- | ------- |
6565
| Linux AMD64 / GPU | **Supported** |
66-
| Linux aarch64 / GPU | **Planned/Possible with Native Compiation but untested** |
67-
| Linux aarch64 / DLA | **Planned/Possible with Native Compilation but untested** |
66+
| Linux aarch64 / GPU | **Native Compilation Supported on JetPack-4.4** |
67+
| Linux aarch64 / DLA | **Native Compilation Supported on JetPack-4.4 but untested** |
6868
| Windows / GPU | - |
6969
| Linux ppc64le / GPU | - |
7070

71+
> Note: Refer NVIDIA NGC container(https://ngc.nvidia.com/catalog/containers/nvidia:l4t-pytorch) for PyTorch libraries on JetPack.
72+
7173
### Dependencies
7274

7375
- Bazel 3.2.0
@@ -171,6 +173,11 @@ bazel build //:libtrtorch --compilation_mode opt
171173
bazel build //:libtrtorch --compilation_mode=dbg
172174
```
173175

176+
### Native compilation on NVIDIA Jetson AGX
177+
``` shell
178+
bazel build //:libtrtorch --distdir third_party/distdir/aarch64-linux-gnu
179+
```
180+
174181
A tarball with the include files and library can then be found in bazel-bin
175182

176183
### Running TRTorch on a JIT Graph

WORKSPACE

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,30 @@ http_archive(
2525
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
2626
rules_pkg_dependencies()
2727

28+
git_repository(
29+
name = "googletest",
30+
remote = "https://github.com/google/googletest",
31+
commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
32+
shallow_since = "1570114335 -0400"
33+
)
34+
2835
# CUDA should be installed on the system locally
2936
new_local_repository(
3037
name = "cuda",
31-
path = "/usr/local/cuda-10.2/targets/x86_64-linux/",
38+
path = "/usr/local/cuda-10.2/",
3239
build_file = "@//third_party/cuda:BUILD",
3340
)
3441

35-
http_archive(
36-
name = "libtorch_pre_cxx11_abi",
37-
build_file = "@//third_party/libtorch:BUILD",
38-
strip_prefix = "libtorch",
39-
sha256 = "ea8de17c5f70015583f3a7a43c7a5cdf91a1d4bd19a6a7bc11f074ef6cd69e27",
40-
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-shared-with-deps-1.5.0.zip"],
42+
new_local_repository(
43+
name = "cublas",
44+
path = "/usr",
45+
build_file = "@//third_party/cublas:BUILD",
4146
)
4247

48+
#############################################################################################################
49+
# Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs)
50+
#############################################################################################################
51+
4352
http_archive(
4453
name = "libtorch",
4554
build_file = "@//third_party/libtorch:BUILD",
@@ -48,23 +57,18 @@ http_archive(
4857
sha256 = "0efdd4e709ab11088fa75f0501c19b0e294404231442bab1d1fb953924feb6b5"
4958
)
5059

51-
pip3_import(
52-
name = "trtorch_py_deps",
53-
requirements = "//py:requirements.txt"
54-
)
55-
56-
load("@trtorch_py_deps//:requirements.bzl", "pip_install")
57-
pip_install()
58-
59-
pip3_import(
60-
name = "py_test_deps",
61-
requirements = "//tests/py:requirements.txt"
60+
http_archive(
61+
name = "libtorch_pre_cxx11_abi",
62+
build_file = "@//third_party/libtorch:BUILD",
63+
strip_prefix = "libtorch",
64+
sha256 = "ea8de17c5f70015583f3a7a43c7a5cdf91a1d4bd19a6a7bc11f074ef6cd69e27",
65+
urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-shared-with-deps-1.5.0.zip"],
6266
)
6367

64-
load("@py_test_deps//:requirements.bzl", "pip_install")
65-
pip_install()
68+
# Download these tarballs manually from the NVIDIA website
69+
# Either place them in the distdir directory in third_party and use the --distdir flag
70+
# or modify the urls to "file:///<PATH TO TARBALL>/<TARBALL NAME>.tar.gz
6671

67-
# Downloaded distributions to use with --distdir
6872
http_archive(
6973
name = "cudnn",
7074
urls = ["https://developer.nvidia.com/compute/machine-learning/cudnn/secure/7.6.5.32/Production/10.2_20191118/cudnn-10.2-linux-x64-v7.6.5.32.tgz"],
@@ -81,22 +85,57 @@ http_archive(
8185
strip_prefix = "TensorRT-7.0.0.11"
8286
)
8387

84-
## Locally installed dependencies
85-
# new_local_repository(
88+
####################################################################################
89+
# Locally installed dependencies (use in cases of custom dependencies or aarch64)
90+
####################################################################################
91+
92+
# NOTE: In the case you are using just the pre-cxx11-abi path or just the cxx11 abi path
93+
# with your local libtorch, just point deps at the same path to satisfy bazel.
94+
95+
# NOTE: NVIDIA's aarch64 PyTorch (python) wheel file uses the CXX11 ABI unlike PyTorch's standard
96+
# x86_64 python distribution. If using NVIDIA's version just point to the root of the package
97+
# for both versions here and do not use --config=pre-cxx11-abi
98+
99+
#new_local_repository(
100+
# name = "libtorch",
101+
# path = "/usr/local/lib/python3.6/dist-packages/torch",
102+
# build_file = "third_party/libtorch/BUILD"
103+
#)
104+
105+
#new_local_repository(
106+
# name = "libtorch_pre_cxx11_abi",
107+
# path = "/usr/local/lib/python3.6/dist-packages/torch",
108+
# build_file = "third_party/libtorch/BUILD"
109+
#)
110+
111+
#new_local_repository(
86112
# name = "cudnn",
87113
# path = "/usr/",
88114
# build_file = "@//third_party/cudnn/local:BUILD"
89115
#)
90116

91-
# new_local_repository(
117+
#new_local_repository(
92118
# name = "tensorrt",
93119
# path = "/usr/",
94120
# build_file = "@//third_party/tensorrt/local:BUILD"
95121
#)
96122

97-
git_repository(
98-
name = "googletest",
99-
remote = "https://github.com/google/googletest",
100-
commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
101-
shallow_since = "1570114335 -0400"
123+
#########################################################################
124+
# Testing Dependencies (optional - comment out on aarch64)
125+
#########################################################################
126+
pip3_import(
127+
name = "trtorch_py_deps",
128+
requirements = "//py:requirements.txt"
102129
)
130+
131+
load("@trtorch_py_deps//:requirements.bzl", "pip_install")
132+
pip_install()
133+
134+
pip3_import(
135+
name = "py_test_deps",
136+
requirements = "//tests/py:requirements.txt"
137+
)
138+
139+
load("@py_test_deps//:requirements.bzl", "pip_install")
140+
pip_install()
141+

tests/core/converters/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,3 @@ test_suite(
8686
":test_stack"
8787
]
8888
)
89-
90-

tests/util/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
"//core/conversion",
2424
"//core/util:prelude",
2525
"//cpp/api:trtorch",
26+
"@tensorrt//:nvinfer"
2627
] + select({
2728
":use_pre_cxx11_abi": [
2829
"@libtorch_pre_cxx11_abi//:libtorch",

third_party/cublas/BUILD

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
config_setting(
4+
name = "aarch64_linux",
5+
constraint_values = [
6+
"@platforms//cpu:aarch64",
7+
"@platforms//os:linux",
8+
],
9+
)
10+
11+
cc_library(
12+
name = "cublas_headers",
13+
hdrs = ["include/cublas.h"] + glob(["include/cublas+.h"]),
14+
includes = ["include/"],
15+
visibility = ["//visibility:private"],
16+
)
17+
18+
cc_import(
19+
name = "cublas_lib",
20+
shared_library = select({
21+
":aarch64_linux": "lib/aarch64-linux-gnu/libcublas.so.10",
22+
"//conditions:default": "lib/x86_64-linux-gnu/libcublas.so.10",
23+
}),
24+
visibility = ["//visibility:private"],
25+
)
26+
27+
cc_library(
28+
name = "cublas",
29+
visibility = ["//visibility:public"],
30+
deps = [
31+
"cublas_headers",
32+
"cublas_lib",
33+
],
34+
)

third_party/cuda/BUILD

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
package(default_visibility = ["//visibility:public"])
22

3+
config_setting(
4+
name = "aarch64_linux",
5+
constraint_values = [
6+
"@platforms//cpu:aarch64",
7+
"@platforms//os:linux",
8+
],
9+
)
10+
311
cc_library(
412
name = "cudart",
5-
srcs = glob([
6-
"lib/**/libcudart.so",
7-
]),
13+
srcs = select({
14+
":aarch64_linux": [
15+
"targets/aarch64-linux/lib/libcudart.so",
16+
],
17+
"//conditions:default": [
18+
"targets/x86_64-linux/lib/libcudart.so",
19+
],
20+
}),
821
hdrs = glob([
922
"include/**/*.h",
1023
"include/**/*.hpp",
@@ -15,16 +28,26 @@ cc_library(
1528

1629
cc_library(
1730
name = "nvToolsExt",
18-
srcs = glob([
19-
"lib/**/libnvToolsExt.so.1"
20-
])
31+
srcs = select({
32+
":aarch64_linux": [
33+
"targets/aarch64-linux/lib/libnvToolsExt.so.1",
34+
],
35+
"//conditions:default": [
36+
"targets/x86_64-linux/lib/libnvToolsExt.so.1",
37+
],
38+
}),
2139
)
2240

2341
cc_library(
2442
name = "cuda",
25-
srcs = glob([
26-
"lib/**/*libcuda.so",
27-
]),
43+
srcs = select({
44+
":aarch64_linux": glob([
45+
"targets/aarch64-linux/lib/**/lib*.so",
46+
]),
47+
"//conditions:default": glob([
48+
"targets/x86_64-linux/lib/**/lib*.so",
49+
]),
50+
}),
2851
hdrs = glob([
2952
"include/**/*.h",
3053
"include/**/*.hpp",

third_party/cudnn/local/BUILD

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
package(default_visibility = ["//visibility:public"])
22

3+
config_setting(
4+
name = "aarch64_linux",
5+
constraint_values = [
6+
"@platforms//cpu:aarch64",
7+
"@platforms//os:linux",
8+
],
9+
)
10+
311
cc_library(
412
name = "cudnn_headers",
513
hdrs = ["include/cudnn.h"] + glob(["include/cudnn+.h"]),
@@ -9,7 +17,10 @@ cc_library(
917

1018
cc_import(
1119
name = "cudnn_lib",
12-
shared_library = "lib/x86_64-linux-gnu/libcudnn.so.7.6.5",
20+
shared_library = select({
21+
":aarch64_linux": "lib/aarch64-linux-gnu/libcudnn.so",
22+
"//conditions:default": "lib/x86_64-linux-gnu/libcudnn.so.7.6.5",
23+
}),
1324
visibility = ["//visibility:private"],
1425
)
1526

0 commit comments

Comments
 (0)