Skip to content

Commit ae8b569

Browse files
committed
fix: Replace dropout lowering pass
- Remove the existing dropout removal lowering pass implementation due to a bug
- Use the Torch JIT dropout removal lowering pass to resolve the bug where nested dropouts resulted in an invalid graph
- The existing removal process left artifacts in the graph, which caused an internal assertion error
- Add a regression test to catch the nested dropout bug
- Update tests to remove testing for `feature_alpha_dropout` and `feature_alpha_dropout_`, which are not removed by the JIT lowering pass and can be added later
1 parent bdf6ad1 commit ae8b569

File tree

2 files changed

+13
-114
lines changed

2 files changed

+13
-114
lines changed

core/lowering/passes/remove_dropout.cpp

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "torch/csrc/jit/passes/remove_dropout.h"
12
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
23

34
#include "core/util/prelude.h"
@@ -8,85 +9,7 @@ namespace lowering {
89
namespace passes {
910

1011
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
11-
std::string dropout_pattern = R"IR(
12-
graph(%input, %4, %5):
13-
%6 = aten::dropout(%input, %4, %5)
14-
return (%6))IR";
15-
std::string no_dropout_pattern = R"IR(
16-
graph(%input, %4, %5):
17-
return (%input))IR";
18-
19-
torch::jit::SubgraphRewriter remove_dropout;
20-
remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern);
21-
remove_dropout.runOnGraph(graph);
22-
23-
std::string dropout_inplace_pattern = R"IR(
24-
graph(%input, %4, %5):
25-
%6 = aten::dropout_(%input, %4, %5)
26-
return (%6))IR";
27-
std::string no_dropout_inplace_pattern = R"IR(
28-
graph(%input, %4, %5):
29-
return (%input))IR";
30-
31-
torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
32-
remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern);
33-
remove_dropout_inplace_pattern.runOnGraph(graph);
34-
35-
// remove feature_dropout
36-
std::string feature_dropout_pattern = R"IR(
37-
graph(%input, %4, %5):
38-
%6 = aten::feature_dropout(%input, %4, %5)
39-
return (%6))IR";
40-
std::string no_feature_dropout_pattern = R"IR(
41-
graph(%input, %4, %5):
42-
return (%input))IR";
43-
44-
torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
45-
remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern);
46-
remove_feature_dropout_pattern.runOnGraph(graph);
47-
48-
// remove feature_dropout inplace
49-
std::string feature_dropout_inplace_pattern = R"IR(
50-
graph(%input, %4, %5):
51-
%6 = aten::feature_dropout_(%input, %4, %5)
52-
return (%6))IR";
53-
std::string no_feature_dropout_inplace_pattern = R"IR(
54-
graph(%input, %4, %5):
55-
return (%input))IR";
56-
57-
torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
58-
remove_feature_dropout_inplace_pattern.RegisterRewritePattern(
59-
feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
60-
remove_feature_dropout_inplace_pattern.runOnGraph(graph);
61-
62-
// remove feature_alpha_dropout
63-
std::string feature_alpha_dropout_pattern = R"IR(
64-
graph(%input, %4, %5):
65-
%6 = aten::feature_alpha_dropout(%input, %4, %5)
66-
return (%6))IR";
67-
std::string no_feature_alpha_dropout_pattern = R"IR(
68-
graph(%input, %4, %5):
69-
return (%input))IR";
70-
71-
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
72-
remove_feature_alpha_dropout_pattern.RegisterRewritePattern(
73-
feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
74-
remove_feature_alpha_dropout_pattern.runOnGraph(graph);
75-
76-
// remove feature_alpha_dropout inplace
77-
std::string feature_alpha_dropout_inplace_pattern = R"IR(
78-
graph(%input, %4, %5):
79-
%6 = aten::feature_alpha_dropout_(%input, %4, %5)
80-
return (%6))IR";
81-
std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
82-
graph(%input, %4, %5):
83-
return (%input))IR";
84-
85-
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
86-
remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern(
87-
feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
88-
remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph);
89-
12+
torch::jit::removeDropout(graph);
9013
LOG_GRAPH("Post remove dropout: " << *graph);
9114
}
9215

tests/core/lowering/test_remove_dropout_pass.cpp

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,39 +32,15 @@ TEST(LoweringPasses, RemoveDropoutLowersCorrectly) {
3232
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
3333
}
3434

35-
TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
36-
std::string source_graph = R"IR(
37-
graph(%x.1):
38-
%3 : float = prim::Constant[value=0.5]()
39-
%4 : bool = prim::Constant[value=0]()
40-
%y.1 : Tensor = aten::dropout_(%x.1, %3, %4)
41-
%11 : Tensor = aten::relu(%y.1)
42-
return (%11))IR";
43-
std::string target_graph = R"IR(
44-
graph(%x.1):
45-
%11 : Tensor = aten::relu(%x.1)
46-
return (%11))IR";
47-
48-
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
49-
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
50-
auto sg = std::make_shared<torch::jit::Graph>();
51-
torch::jit::parseIR(source_graph, sg.get());
52-
torch_tensorrt::core::lowering::passes::RemoveDropout(sg);
53-
54-
auto tg = std::make_shared<torch::jit::Graph>();
55-
torch::jit::parseIR(target_graph, tg.get());
56-
57-
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
58-
}
59-
60-
TEST(LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
35+
TEST(LoweringPasses, RemoveDropoutNestedLowersCorrectly) {
6136
std::string source_graph = R"IR(
6237
graph(%x.1):
6338
%3 : float = prim::Constant[value=0.5]()
6439
%4 : bool = prim::Constant[value=0]()
65-
%y.1 : Tensor = aten::feature_dropout(%x.1, %3, %4)
66-
%11 : Tensor = aten::relu(%y.1)
67-
return (%11))IR";
40+
%y.1 : Tensor = aten::dropout(%x.1, %3, %4)
41+
%z.1 : Tensor = aten::dropout(%y.1, %3, %4)
42+
%12 : Tensor = aten::relu(%z.1)
43+
return (%12))IR";
6844
std::string target_graph = R"IR(
6945
graph(%x.1):
7046
%11 : Tensor = aten::relu(%x.1)
@@ -82,12 +58,12 @@ TEST(LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
8258
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
8359
}
8460

85-
TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
61+
TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
8662
std::string source_graph = R"IR(
8763
graph(%x.1):
8864
%3 : float = prim::Constant[value=0.5]()
8965
%4 : bool = prim::Constant[value=0]()
90-
%y.1 : Tensor = aten::feature_dropout_(%x.1, %3, %4)
66+
%y.1 : Tensor = aten::dropout_(%x.1, %3, %4)
9167
%11 : Tensor = aten::relu(%y.1)
9268
return (%11))IR";
9369
std::string target_graph = R"IR(
@@ -107,12 +83,12 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
10783
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
10884
}
10985

110-
TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
86+
TEST(LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
11187
std::string source_graph = R"IR(
11288
graph(%x.1):
11389
%3 : float = prim::Constant[value=0.5]()
11490
%4 : bool = prim::Constant[value=0]()
115-
%y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
91+
%y.1 : Tensor = aten::feature_dropout(%x.1, %3, %4)
11692
%11 : Tensor = aten::relu(%y.1)
11793
return (%11))IR";
11894
std::string target_graph = R"IR(
@@ -132,12 +108,12 @@ TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
132108
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
133109
}
134110

135-
TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
111+
TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
136112
std::string source_graph = R"IR(
137113
graph(%x.1):
138114
%3 : float = prim::Constant[value=0.5]()
139115
%4 : bool = prim::Constant[value=0]()
140-
%y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
116+
%y.1 : Tensor = aten::feature_dropout_(%x.1, %3, %4)
141117
%11 : Tensor = aten::relu(%y.1)
142118
return (%11))IR";
143119
std::string target_graph = R"IR(

0 commit comments

Comments
 (0)