Skip to content

Commit aab80bf

Browse files
authored
Merge pull request #126 from NVIDIA/python_abi_option
Handle Dropout_ and ABI fixes for C++ and Python
2 parents 3f54833 + fdbd7d2 commit aab80bf

File tree

4 files changed

+57
-16
lines changed

4 files changed

+57
-16
lines changed

core/conversion/conversion_blacklist.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
2222
"prim::GetAttr",
2323
"prim::CallMethod",
2424
"prim::Drop",
25-
"aten:dropout",
25+
"aten::dropout",
26+
"aten::dropout_"
2627
};
2728
return nonconvertable_nodes;
2829
}

core/lowering/passes/remove_dropout.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
2020
remove_dropout.RegisterRewritePattern(
2121
dropout_pattern, no_dropout_pattern);
2222
remove_dropout.runOnGraph(graph);
23+
24+
std::string dropout_inplace_pattern = R"IR(
25+
graph(%input, %4, %5):
26+
%6 = aten::dropout_(%input, %4, %5)
27+
return (%6))IR";
28+
std::string no_dropout_inplace_pattern = R"IR(
29+
graph(%input, %4, %5):
30+
return (%input))IR";
31+
32+
torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
33+
remove_dropout_inplace_pattern.RegisterRewritePattern(
34+
dropout_inplace_pattern, no_dropout_inplace_pattern);
35+
remove_dropout_inplace_pattern.runOnGraph(graph);
36+
2337
LOG_GRAPH("Post remove dropout: " << *graph);
2438
}
2539

cpp/trtorchc/BUILD

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
package(default_visibility = ["//visibility:public"])
22

3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
310
cc_binary(
411
name = "trtorchc",
512
srcs = [
613
"main.cpp"
714
],
815
deps = [
9-
"@libtorch//:libtorch",
10-
"@libtorch//:caffe2",
1116
"//third_party/args",
1217
"//cpp/api:trtorch"
13-
],
14-
)
18+
] + select({
19+
":use_pre_cxx11_abi": [
20+
"@libtorch_pre_cxx11_abi//:libtorch",
21+
"@libtorch_pre_cxx11_abi//:caffe2",
22+
],
23+
"//conditions:default": [
24+
"@libtorch//:libtorch",
25+
"@libtorch//:caffe2",
26+
],
27+
}),
28+
)

py/setup.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
__version__ = '0.0.2'
2020

21-
def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True):
21+
CXX11_ABI = False
22+
23+
if "--use-cxx11-abi" in sys.argv:
24+
sys.argv.remove("--use-cxx11-abi")
25+
CXX11_ABI = True
26+
27+
def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=False):
2228
cmd = ["/usr/bin/bazel", "build"]
2329
cmd.append("//cpp/api/lib:libtrtorch.so")
2430
if develop:
@@ -27,7 +33,10 @@ def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True):
2733
cmd.append("--compilation_mode=opt")
2834
if use_dist_dir:
2935
cmd.append("--distdir=third_party/dist_dir/x86_64-linux-gnu")
30-
cmd.append("--config=python")
36+
if not cxx11_abi:
37+
cmd.append("--config=python")
38+
else:
39+
print("using CXX11 ABI build")
3140

3241
print("building libtrtorch")
3342
status_code = subprocess.run(cmd).returncode
@@ -64,7 +73,8 @@ def finalize_options(self):
6473
develop.finalize_options(self)
6574

6675
def run(self):
67-
build_libtrtorch_pre_cxx11_abi(develop=True)
76+
global CXX11_ABI
77+
build_libtrtorch_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI)
6878
gen_version_file()
6979
copy_libtrtorch()
7080
develop.run(self)
@@ -80,7 +90,8 @@ def finalize_options(self):
8090
install.finalize_options(self)
8191

8292
def run(self):
83-
build_libtrtorch_pre_cxx11_abi(develop=False)
93+
global CXX11_ABI
94+
build_libtrtorch_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI)
8495
gen_version_file()
8596
copy_libtrtorch()
8697
install.run(self)
@@ -95,7 +106,8 @@ def finalize_options(self):
95106
bdist_wheel.finalize_options(self)
96107

97108
def run(self):
98-
build_libtrtorch_pre_cxx11_abi(develop=False)
109+
global CXX11_ABI
110+
build_libtrtorch_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI)
99111
gen_version_file()
100112
copy_libtrtorch()
101113
bdist_wheel.run(self)
@@ -138,15 +150,16 @@ def run(self):
138150
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
139151
],
140152
extra_compile_args=[
141-
"-D_GLIBCXX_USE_CXX11_ABI=0",
142-
"-Wno-deprecated-declaration",
143-
],
153+
"-Wno-deprecated",
154+
"-Wno-deprecated-declarations",
155+
] + ["-D_GLIBCXX_USE_CXX11_ABI=1"] if CXX11_ABI else ["-D_GLIBCXX_USE_CXX11_ABI=0"],
144156
extra_link_args=[
145-
"-D_GLIBCXX_USE_CXX11_ABI=0"
157+
"-Wno-deprecated",
158+
"-Wno-deprecated-declarations",
146159
"-Wl,--no-as-needed",
147160
"-ltrtorch",
148161
"-Wl,-rpath,$ORIGIN/lib"
149-
],
162+
] + ["-D_GLIBCXX_USE_CXX11_ABI=1"] if CXX11_ABI else ["-D_GLIBCXX_USE_CXX11_ABI=0"],
150163
undef_macros=[ "NDEBUG" ]
151164
)
152165
]
@@ -178,7 +191,6 @@ def run(self):
178191
zip_safe=False,
179192
license="BSD",
180193
packages=find_packages(),
181-
platform="Linux",
182194
classifiers=[
183195
"Development Status :: 3 - Alpha",
184196
"Environment :: GPU :: NVIDIA CUDA",

0 commit comments

Comments
 (0)