Skip to content

Commit e7a469d

Browse files
committed
fix: Adapt torch JIT pass for removeDropout
- Adapt the JIT dropout-removal pass to accommodate multiple dropout schemas
- Add test cases verifying the new removal code
1 parent ae8b569 commit e7a469d

File tree

2 files changed

+122
-4
lines changed

2 files changed

+122
-4
lines changed

core/lowering/passes/remove_dropout.cpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,57 @@
1-
#include "torch/csrc/jit/passes/remove_dropout.h"
2-
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
3-
41
#include "core/util/prelude.h"
2+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
53

64
namespace torch_tensorrt {
75
namespace core {
86
namespace lowering {
97
namespace passes {
108

9+
// Schemas for dropout variants
10+
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
11+
c10::Symbol::fromQualString("aten::dropout"),
12+
c10::Symbol::fromQualString("aten::dropout_"),
13+
c10::Symbol::fromQualString("aten::feature_dropout"),
14+
c10::Symbol::fromQualString("aten::feature_dropout_"),
15+
c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
16+
c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
17+
};
18+
19+
void removeDropoutInBlock(torch::jit::Block* block) {
20+
/*
21+
Function adapted from:
22+
torch/csrc/jit/passes/remove_dropout.cpp
23+
24+
Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
25+
*/
26+
std::vector<torch::jit::Node*> dropout_nodes_to_remove;
27+
28+
for (auto node : block->nodes()) {
29+
// Remove dropout for each member block within a node
30+
for (auto block : node->blocks()) {
31+
removeDropoutInBlock(block);
32+
}
33+
34+
// For each node having a dropout-variant Schema, remove the node
35+
if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) {
36+
// Extract input and output tensors of dropout operator
37+
auto input_value = node->inputs()[0];
38+
auto output_value = node->outputs()[0];
39+
40+
output_value->replaceAllUsesWith(input_value);
41+
dropout_nodes_to_remove.push_back(node);
42+
}
43+
}
44+
45+
// Delete dropout nodes
46+
for (auto del_node : dropout_nodes_to_remove) {
47+
del_node->destroy();
48+
}
49+
}
50+
1151
// Lowering-pass entry point: strips all dropout-variant operators from
// `graph` (dropout is the identity at inference time), then runs dead-code
// elimination to clean up the now-unused producer nodes (e.g. the dropout
// probability / training-flag constants).
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
  // Remove all instances of dropout variants from graph
  removeDropoutInBlock(graph->block());
  torch::jit::EliminateDeadCode(graph);
  LOG_GRAPH("Post remove dropout: " << *graph);
}
1557

tests/core/lowering/test_remove_dropout_pass.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,79 @@ TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
132132

133133
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
134134
}
135+
136+
// Verifies that the RemoveDropout lowering pass removes
// aten::feature_alpha_dropout and rewires its consumer (relu) to the
// original input tensor.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
  std::string source_graph = R"IR(
graph(%x.1):
%3 : float = prim::Constant[value=0.5]()
%4 : bool = prim::Constant[value=0]()
%y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
%11 : Tensor = aten::relu(%y.1)
return (%11))IR";
  // Expected graph after lowering: relu consumes %x.1 directly.
  std::string target_graph = R"IR(
graph(%x.1):
%11 : Tensor = aten::relu(%x.1)
return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered source graph must match the dropout-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
160+
161+
// Verifies that two chained (nested) aten::feature_alpha_dropout nodes are
// both removed in a single RemoveDropout pass, collapsing the chain so relu
// consumes the original input.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) {
  std::string source_graph = R"IR(
graph(%x.1):
%3 : float = prim::Constant[value=0.5]()
%4 : bool = prim::Constant[value=0]()
%y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
%z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4)
%12 : Tensor = aten::relu(%z.1)
return (%12))IR";
  // Expected graph after lowering: both dropout nodes gone, relu on %x.1.
  // (Value names differ from the source graph; pattern matching is by
  // structure, not by name.)
  std::string target_graph = R"IR(
graph(%x.1):
%11 : Tensor = aten::relu(%x.1)
return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered source graph must match the dropout-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
186+
187+
// Verifies that the in-place variant aten::feature_alpha_dropout_ is also
// removed by the RemoveDropout lowering pass.
TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
  std::string source_graph = R"IR(
graph(%x.1):
%3 : float = prim::Constant[value=0.5]()
%4 : bool = prim::Constant[value=0]()
%y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
%11 : Tensor = aten::relu(%y.1)
return (%11))IR";
  // Expected graph after lowering: relu consumes %x.1 directly.
  std::string target_graph = R"IR(
graph(%x.1):
%11 : Tensor = aten::relu(%x.1)
return (%11))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::RemoveDropout(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The lowered source graph must match the dropout-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 commit comments

Comments
 (0)