@@ -32,39 +32,15 @@ TEST(LoweringPasses, RemoveDropoutLowersCorrectly) {
32
32
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
33
33
}
34
34
35
- TEST (LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
36
- std::string source_graph = R"IR(
37
- graph(%x.1):
38
- %3 : float = prim::Constant[value=0.5]()
39
- %4 : bool = prim::Constant[value=0]()
40
- %y.1 : Tensor = aten::dropout_(%x.1, %3, %4)
41
- %11 : Tensor = aten::relu(%y.1)
42
- return (%11))IR" ;
43
- std::string target_graph = R"IR(
44
- graph(%x.1):
45
- %11 : Tensor = aten::relu(%x.1)
46
- return (%11))IR" ;
47
-
48
- torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
49
- torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
50
- auto sg = std::make_shared<torch::jit::Graph>();
51
- torch::jit::parseIR (source_graph, sg.get ());
52
- torch_tensorrt::core::lowering::passes::RemoveDropout (sg);
53
-
54
- auto tg = std::make_shared<torch::jit::Graph>();
55
- torch::jit::parseIR (target_graph, tg.get ());
56
-
57
- ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
58
- }
59
-
60
- TEST (LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
35
+ TEST (LoweringPasses, RemoveDropoutNestedLowersCorrectly) {
61
36
std::string source_graph = R"IR(
62
37
graph(%x.1):
63
38
%3 : float = prim::Constant[value=0.5]()
64
39
%4 : bool = prim::Constant[value=0]()
65
- %y.1 : Tensor = aten::feature_dropout(%x.1, %3, %4)
66
- %11 : Tensor = aten::relu(%y.1)
67
- return (%11))IR" ;
40
+ %y.1 : Tensor = aten::dropout(%x.1, %3, %4)
41
+ %z.1 : Tensor = aten::dropout(%y.1, %3, %4)
42
+ %12 : Tensor = aten::relu(%z.1)
43
+ return (%12))IR" ;
68
44
std::string target_graph = R"IR(
69
45
graph(%x.1):
70
46
%11 : Tensor = aten::relu(%x.1)
@@ -82,12 +58,12 @@ TEST(LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
82
58
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
83
59
}
84
60
85
- TEST (LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly ) {
61
+ TEST (LoweringPasses, RemoveDropoutInplaceLowersCorrectly ) {
86
62
std::string source_graph = R"IR(
87
63
graph(%x.1):
88
64
%3 : float = prim::Constant[value=0.5]()
89
65
%4 : bool = prim::Constant[value=0]()
90
- %y.1 : Tensor = aten::feature_dropout_ (%x.1, %3, %4)
66
+ %y.1 : Tensor = aten::dropout_ (%x.1, %3, %4)
91
67
%11 : Tensor = aten::relu(%y.1)
92
68
return (%11))IR" ;
93
69
std::string target_graph = R"IR(
@@ -107,12 +83,12 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
107
83
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
108
84
}
109
85
110
- TEST (LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly ) {
86
+ TEST (LoweringPasses, RemoveFeatureDropoutLowersCorrectly ) {
111
87
std::string source_graph = R"IR(
112
88
graph(%x.1):
113
89
%3 : float = prim::Constant[value=0.5]()
114
90
%4 : bool = prim::Constant[value=0]()
115
- %y.1 : Tensor = aten::feature_alpha_dropout (%x.1, %3, %4)
91
+ %y.1 : Tensor = aten::feature_dropout (%x.1, %3, %4)
116
92
%11 : Tensor = aten::relu(%y.1)
117
93
return (%11))IR" ;
118
94
std::string target_graph = R"IR(
@@ -132,12 +108,12 @@ TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
132
108
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
133
109
}
134
110
135
- TEST (LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly ) {
111
+ TEST (LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly ) {
136
112
std::string source_graph = R"IR(
137
113
graph(%x.1):
138
114
%3 : float = prim::Constant[value=0.5]()
139
115
%4 : bool = prim::Constant[value=0]()
140
- %y.1 : Tensor = aten::feature_alpha_dropout_ (%x.1, %3, %4)
116
+ %y.1 : Tensor = aten::feature_dropout_ (%x.1, %3, %4)
141
117
%11 : Tensor = aten::relu(%y.1)
142
118
return (%11))IR" ;
143
119
std::string target_graph = R"IR(
0 commit comments