Skip to content

Commit 8698045

Browse files
committed
fix: Adapt torch JIT pass for removeDropout
- Adapt the JIT dropout-removal pass to accommodate multiple dropout operator schemas - Add test cases verifying the new removal code
1 parent ae8b569 commit 8698045

File tree

2 files changed

+122
-3
lines changed

2 files changed

+122
-3
lines changed

core/lowering/passes/remove_dropout.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#include "torch/csrc/jit/passes/remove_dropout.h"
2-
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
1+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
32

43
#include "core/util/prelude.h"
54

@@ -8,8 +7,52 @@ namespace core {
87
namespace lowering {
98
namespace passes {
109

10+
// Schemas for dropout variants
11+
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
12+
c10::Symbol::fromQualString("aten::dropout"),
13+
c10::Symbol::fromQualString("aten::dropout_"),
14+
c10::Symbol::fromQualString("aten::feature_dropout"),
15+
c10::Symbol::fromQualString("aten::feature_dropout_"),
16+
c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
17+
c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
18+
};
19+
20+
void removeDropoutInBlock(torch::jit::Block* block) {
21+
/*
22+
Function adapted from:
23+
torch/csrc/jit/passes/remove_dropout.cpp
24+
25+
Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
26+
*/
27+
std::vector<torch::jit::Node*> dropout_nodes_to_remove;
28+
29+
for (auto node : block->nodes()) {
30+
// Remove dropout for each member block within a node
31+
for (auto block : node->blocks()) {
32+
removeDropoutInBlock(block);
33+
}
34+
35+
// For each node having a dropout-variant Schema, remove the node
36+
if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
37+
// Extract input and output tensors of dropout operator
38+
auto input_value = node->inputs()[0];
39+
auto output_value = node->outputs()[0];
40+
41+
output_value->replaceAllUsesWith(input_value);
42+
dropout_nodes_to_remove.push_back(node);
43+
}
44+
}
45+
46+
// Delete dropout nodes
47+
for (auto del_node : dropout_nodes_to_remove) {
48+
del_node->destroy();
49+
}
50+
}
51+
1152
// Lowering-pass entry point: strips every dropout variant from `graph`,
// then runs dead-code elimination to clean up the now-unused inputs of the
// removed nodes (e.g. probability/training constants).
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
  removeDropoutInBlock(graph->block());
  torch::jit::EliminateDeadCode(graph);
  LOG_GRAPH("Post remove dropout: " << *graph);
}
1558

tests/core/lowering/test_remove_dropout_pass.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,79 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
132132

133133
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
134134
}
135+
136+
// Verifies that a single aten::feature_alpha_dropout node is removed and its
// consumer (relu) is rewired to consume the dropout's input directly.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
  const std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
      %11 : Tensor = aten::relu(%y.1)
      return (%11))IR";
  const std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  // Parse the source IR and run the dropout-removal lowering pass on it.
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  // Parse the expected (dropout-free) IR for comparison.
  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered graph must contain the dropout-free target pattern.
  ASSERT_FALSE(torch::jit::findPatternMatches(*tg, *sg).empty());
}
160+
161+
// Verifies that two chained aten::feature_alpha_dropout nodes (the output of
// one feeding the other) are both removed in a single pass invocation.
// NOTE(review): despite the "Nested" name, this graph chains dropouts
// sequentially rather than placing one inside a sub-block.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
  const std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
      %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
      %12 : Tensor = aten::relu(%z.1)
      return (%12))IR";
  const std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  // Parse the source IR and run the dropout-removal lowering pass on it.
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  // Parse the expected (dropout-free) IR for comparison.
  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered graph must contain the dropout-free target pattern.
  ASSERT_FALSE(torch::jit::findPatternMatches(*tg, *sg).empty());
}
186+
187+
// Verifies that the in-place variant, aten::feature_alpha_dropout_, is also
// removed and its consumer rewired to the original input tensor.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
  const std::string source_graph = R"IR(
    graph(%x.1):
      %3 : float = prim::Constant[value=0.5]()
      %4 : bool = prim::Constant[value=0]()
      %y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
      %11 : Tensor = aten::relu(%y.1)
      return (%11))IR";
  const std::string target_graph = R"IR(
    graph(%x.1):
      %11 : Tensor = aten::relu(%x.1)
      return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);

  // Parse the source IR and run the dropout-removal lowering pass on it.
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  // Parse the expected (dropout-free) IR for comparison.
  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered graph must contain the dropout-free target pattern.
  ASSERT_FALSE(torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 commit comments

Comments
 (0)