Skip to content

Commit e9906b7

Browse files
authored
Merge pull request #500 from itsliupeng/master
Support aten::hardswish using unpack_hardswish pass
2 parents 74b1023 + 643c938 commit e9906b7

File tree

6 files changed

+139
-0
lines changed

6 files changed

+139
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
2525
}
2626

2727
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
28+
passes::UnpackHardSwish(g);
2829
torch::jit::EliminateRedundantGuards(g);
2930
torch::jit::RemoveListMutation(g);
3031
torch::jit::RemoveTensorMutation(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
"unpack_addmm.cpp",
2525
"unpack_batch_norm.cpp",
2626
"unpack_log_softmax.cpp",
27+
"unpack_hardswish.cpp"
2728
],
2829
hdrs = [
2930
"passes.h",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
2222
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
2323
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
24+
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
2425

2526
} // namespace passes
2627
} // namespace lowering
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string hardswish_pattern = R"IR(
12+
graph(%input):
13+
%result = aten::hardswish(%input)
14+
return (%result))IR";
15+
16+
std::string hardswish_pattern_inplace = R"IR(
17+
graph(%input):
18+
%result = aten::hardswish_(%input)
19+
return (%result))IR";
20+
21+
std::string new_pattern = R"IR(
22+
graph(%input):
23+
%1 : Scalar = prim::Constant[value=3.]()
24+
%2 : Scalar = prim::Constant[value=1.]()
25+
%3 = aten::add(%input, %1, %2)
26+
%4 : Scalar = prim::Constant[value=0.]()
27+
%5 : Scalar = prim::Constant[value=6.]()
28+
%6 = aten::hardtanh(%3, %4, %5)
29+
%7 = aten::div(%6, %5)
30+
%8 = aten::mul(%input, %7)
31+
return (%8))IR";
32+
33+
torch::jit::SubgraphRewriter rewriter;
34+
rewriter.RegisterRewritePattern(hardswish_pattern, new_pattern);
35+
rewriter.RegisterRewritePattern(hardswish_pattern_inplace, new_pattern);
36+
rewriter.runOnGraph(graph);
37+
38+
LOG_GRAPH("Post unpack hardswish: " << *graph);
39+
}
40+
41+
} // namespace passes
42+
} // namespace lowering
43+
} // namespace core
44+
} // namespace trtorch

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ lowering_test(
3535
name = "test_silu_to_sigmoid_multiplication",
3636
)
3737

38+
lowering_test(
39+
name = "test_unpack_hardswish",
40+
)
41+
3842
test_suite(
3943
name = "lowering_tests",
4044
tests = [
@@ -44,5 +48,6 @@ test_suite(
4448
":test_remove_detach_pass",
4549
":test_remove_dropout_pass",
4650
":test_remove_to",
51+
":test_unpack_hardswish"
4752
],
4853
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, UnpackHardSwish) {
10+
std::string source_graph = R"IR(
11+
graph(%input):
12+
%result = aten::hardswish(%input)
13+
return (%result))IR";
14+
15+
std::string target_graph = R"IR(
16+
graph(%input):
17+
%1 : Scalar = prim::Constant[value=3.]()
18+
%2 : Scalar = prim::Constant[value=1.]()
19+
%3 = aten::add(%input, %1, %2)
20+
%4 : Scalar = prim::Constant[value=0.]()
21+
%5 : Scalar = prim::Constant[value=6.]()
22+
%6 = aten::hardtanh(%3, %4, %5)
23+
%7 = aten::div(%6, %5)
24+
%8 = aten::mul(%input, %7)
25+
return (%8))IR";
26+
27+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
28+
auto sg = std::make_shared<torch::jit::Graph>();
29+
torch::jit::parseIR(source_graph, &*sg);
30+
31+
auto in = at::rand({10, 100}, {at::kCUDA});
32+
auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
33+
auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
34+
35+
trtorch::core::lowering::passes::UnpackHardSwish(sg);
36+
37+
auto tg = std::make_shared<torch::jit::Graph>();
38+
torch::jit::parseIR(target_graph, &*tg);
39+
40+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
41+
42+
in = at::clone(in);
43+
auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
44+
auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
45+
46+
ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
47+
}
48+
49+
TEST(LoweringPasses, UnpackHardInplaceSwish) {
50+
std::string source_graph = R"IR(
51+
graph(%input):
52+
%result = aten::hardswish_(%input)
53+
return (%result))IR";
54+
55+
std::string target_graph = R"IR(
56+
graph(%input):
57+
%1 : Scalar = prim::Constant[value=3.]()
58+
%2 : Scalar = prim::Constant[value=1.]()
59+
%3 = aten::add(%input, %1, %2)
60+
%4 : Scalar = prim::Constant[value=0.]()
61+
%5 : Scalar = prim::Constant[value=6.]()
62+
%6 = aten::hardtanh(%3, %4, %5)
63+
%7 = aten::div(%6, %5)
64+
%8 = aten::mul(%input, %7)
65+
return (%8))IR";
66+
67+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
68+
auto sg = std::make_shared<torch::jit::Graph>();
69+
torch::jit::parseIR(source_graph, &*sg);
70+
71+
auto in = at::rand({10, 100}, {at::kCUDA});
72+
auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
73+
auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
74+
75+
trtorch::core::lowering::passes::UnpackHardSwish(sg);
76+
77+
auto tg = std::make_shared<torch::jit::Graph>();
78+
torch::jit::parseIR(target_graph, &*tg);
79+
80+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
81+
82+
in = at::clone(in);
83+
auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
84+
auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
85+
86+
ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
87+
}

0 commit comments

Comments
 (0)