Skip to content

Commit efb3230

Browse files
committed
Support aten::hardswish using unpack_hardswish pass
Signed-off-by: itsliupeng <[email protected]>
1 parent eb39f9c commit efb3230

File tree

6 files changed

+117
-0
lines changed

6 files changed

+117
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
2525
}
2626

2727
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
28+
passes::UnpackHardSwish(g);
2829
torch::jit::EliminateRedundantGuards(g);
2930
torch::jit::RemoveListMutation(g);
3031
torch::jit::RemoveTensorMutation(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
"unpack_addmm.cpp",
2525
"unpack_batch_norm.cpp",
2626
"unpack_log_softmax.cpp",
27+
"unpack_hardswish.cpp"
2728
],
2829
hdrs = [
2930
"passes.h",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
2222
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
2323
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
24+
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
2425

2526
} // namespace passes
2627
} // namespace lowering
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string hardswish_pattern = R"IR(
12+
graph(%input):
13+
%result = aten::hardswish(%input)
14+
return (%result))IR";
15+
16+
std::string hardswish_pattern_inplace = R"IR(
17+
graph(%input):
18+
%result = aten::hardswish_(%input)
19+
return (%result))IR";
20+
21+
std::string new_pattern = R"IR(
22+
graph(%input):
23+
%1 : Scalar = prim::Constant[value=3.]()
24+
%2 : Scalar = prim::Constant[value=1.]()
25+
%3 = aten::add(%input, %1, %2)
26+
%4 : Scalar = prim::Constant[value=0.]()
27+
%5 : Scalar = prim::Constant[value=6.]()
28+
%6 = aten::hardtanh(%3, %4, %5)
29+
%7 = aten::div(%6, %5)
30+
%8 = aten::mul(%input, %7)
31+
return (%8))IR";
32+
33+
torch::jit::SubgraphRewriter rewriter;
34+
rewriter.RegisterRewritePattern(hardswish_pattern, new_pattern);
35+
rewriter.RegisterRewritePattern(hardswish_pattern_inplace, new_pattern);
36+
rewriter.runOnGraph(graph);
37+
38+
LOG_GRAPH("Post unpack hardswish: " << *graph);
39+
}
40+
41+
} // namespace passes
42+
} // namespace lowering
43+
} // namespace core
44+
} // namespace trtorch

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ lowering_test(
3535
name = "test_silu_to_sigmoid_multiplication",
3636
)
3737

38+
lowering_test(
39+
name = "test_unpack_hardswish",
40+
)
41+
3842
test_suite(
3943
name = "lowering_tests",
4044
tests = [
@@ -44,5 +48,6 @@ test_suite(
4448
":test_remove_detach_pass",
4549
":test_remove_dropout_pass",
4650
":test_remove_to",
51+
":test_unpack_hardswish"
4752
],
4853
)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, UnpackHardSwish) {
10+
std::string source_graph = R"IR(
11+
graph(%input):
12+
%result = aten::hardswish(%input)
13+
return (%result))IR";
14+
15+
std::string target_graph = R"IR(
16+
graph(%input):
17+
%1 : Scalar = prim::Constant[value=3.]()
18+
%2 : Scalar = prim::Constant[value=1.]()
19+
%3 = aten::add(%input, %1, %2)
20+
%4 : Scalar = prim::Constant[value=0.]()
21+
%5 : Scalar = prim::Constant[value=6.]()
22+
%6 = aten::hardtanh(%3, %4, %5)
23+
%7 = aten::div(%6, %5)
24+
%8 = aten::mul(%input, %7)
25+
return (%8))IR";
26+
27+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
28+
auto sg = std::make_shared<torch::jit::Graph>();
29+
torch::jit::parseIR(source_graph, &*sg);
30+
trtorch::core::lowering::passes::UnpackHardSwish(sg);
31+
32+
auto tg = std::make_shared<torch::jit::Graph>();
33+
torch::jit::parseIR(target_graph, &*tg);
34+
35+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
36+
}
37+
38+
TEST(LoweringPasses, UnpackHardInplaceSwish) {
39+
std::string source_graph = R"IR(
40+
graph(%input):
41+
%result = aten::hardswish_(%input)
42+
return (%result))IR";
43+
44+
std::string target_graph = R"IR(
45+
graph(%input):
46+
%1 : Scalar = prim::Constant[value=3.]()
47+
%2 : Scalar = prim::Constant[value=1.]()
48+
%3 = aten::add(%input, %1, %2)
49+
%4 : Scalar = prim::Constant[value=0.]()
50+
%5 : Scalar = prim::Constant[value=6.]()
51+
%6 = aten::hardtanh(%3, %4, %5)
52+
%7 = aten::div(%6, %5)
53+
%8 = aten::mul(%input, %7)
54+
return (%8))IR";
55+
56+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
57+
auto sg = std::make_shared<torch::jit::Graph>();
58+
torch::jit::parseIR(source_graph, &*sg);
59+
trtorch::core::lowering::passes::UnpackHardSwish(sg);
60+
61+
auto tg = std::make_shared<torch::jit::Graph>();
62+
torch::jit::parseIR(target_graph, &*tg);
63+
64+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
65+
}

0 commit comments

Comments
 (0)