@@ -132,3 +132,79 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
132
132
133
133
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
134
134
}
135
+
136
+ TEST (LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
137
+ std::string source_graph = R"IR(
138
+ graph(%x.1):
139
+ %3 : float = prim::Constant[value=0.5]()
140
+ %4 : bool = prim::Constant[value=0]()
141
+ %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
142
+ %11 : Tensor = aten::relu(%y.1)
143
+ return (%11))IR" ;
144
+ std::string target_graph = R"IR(
145
+ graph(%x.1):
146
+ %11 : Tensor = aten::relu(%x.1)
147
+ return (%11))IR" ;
148
+
149
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
150
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
151
+ auto sg = std::make_shared<torch::jit::Graph>();
152
+ torch::jit::parseIR (source_graph, sg.get ());
153
+ torch_tensorrt::core::lowering::passes::RemoveDropout (sg);
154
+
155
+ auto tg = std::make_shared<torch::jit::Graph>();
156
+ torch::jit::parseIR (target_graph, tg.get ());
157
+
158
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
159
+ }
160
+
161
+ TEST (LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
162
+ std::string source_graph = R"IR(
163
+ graph(%x.1):
164
+ %3 : float = prim::Constant[value=0.5]()
165
+ %4 : bool = prim::Constant[value=0]()
166
+ %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
167
+ %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
168
+ %12 : Tensor = aten::relu(%z.1)
169
+ return (%12))IR" ;
170
+ std::string target_graph = R"IR(
171
+ graph(%x.1):
172
+ %11 : Tensor = aten::relu(%x.1)
173
+ return (%11))IR" ;
174
+
175
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
176
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
177
+ auto sg = std::make_shared<torch::jit::Graph>();
178
+ torch::jit::parseIR (source_graph, sg.get ());
179
+ torch_tensorrt::core::lowering::passes::RemoveDropout (sg);
180
+
181
+ auto tg = std::make_shared<torch::jit::Graph>();
182
+ torch::jit::parseIR (target_graph, tg.get ());
183
+
184
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
185
+ }
186
+
187
+ TEST (LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
188
+ std::string source_graph = R"IR(
189
+ graph(%x.1):
190
+ %3 : float = prim::Constant[value=0.5]()
191
+ %4 : bool = prim::Constant[value=0]()
192
+ %y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
193
+ %11 : Tensor = aten::relu(%y.1)
194
+ return (%11))IR" ;
195
+ std::string target_graph = R"IR(
196
+ graph(%x.1):
197
+ %11 : Tensor = aten::relu(%x.1)
198
+ return (%11))IR" ;
199
+
200
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
201
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
202
+ auto sg = std::make_shared<torch::jit::Graph>();
203
+ torch::jit::parseIR (source_graph, sg.get ());
204
+ torch_tensorrt::core::lowering::passes::RemoveDropout (sg);
205
+
206
+ auto tg = std::make_shared<torch::jit::Graph>();
207
+ torch::jit::parseIR (target_graph, tg.get ());
208
+
209
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
210
+ }
0 commit comments