1
+ #include < string>
2
+ #include " core/compiler.h"
3
+ #include " core/lowering/passes/passes.h"
4
+ #include " gtest/gtest.h"
5
+ #include " tests/util/util.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+
9
+ TEST (LoweringPasses, UnpackHardSwish) {
10
+ std::string source_graph = R"IR(
11
+ graph(%input):
12
+ %result = aten::hardswish(%input)
13
+ return (%result))IR" ;
14
+
15
+ std::string target_graph = R"IR(
16
+ graph(%input):
17
+ %1 : Scalar = prim::Constant[value=3.]()
18
+ %2 : Scalar = prim::Constant[value=1.]()
19
+ %3 = aten::add(%input, %1, %2)
20
+ %4 : Scalar = prim::Constant[value=0.]()
21
+ %5 : Scalar = prim::Constant[value=6.]()
22
+ %6 = aten::hardtanh(%3, %4, %5)
23
+ %7 = aten::div(%6, %5)
24
+ %8 = aten::mul(%input, %7)
25
+ return (%8))IR" ;
26
+
27
+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
28
+ auto sg = std::make_shared<torch::jit::Graph>();
29
+ torch::jit::parseIR (source_graph, &*sg);
30
+
31
+ auto in = at::rand ({10 , 100 }, {at::kCUDA });
32
+ auto sg_params = trtorch::core::conversion::get_named_params (sg->inputs (), {});
33
+ auto sg_results = trtorch::tests::util::RunGraph (sg, sg_params, {in});
34
+
35
+ trtorch::core::lowering::passes::UnpackHardSwish (sg);
36
+
37
+ auto tg = std::make_shared<torch::jit::Graph>();
38
+ torch::jit::parseIR (target_graph, &*tg);
39
+
40
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
41
+
42
+ in = at::clone (in);
43
+ auto tg_params = trtorch::core::conversion::get_named_params (tg->inputs (), {});
44
+ auto tg_results = trtorch::tests::util::RunGraph (tg, tg_params, {in});
45
+
46
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (sg_results[0 ], tg_results[0 ], 2e-6 ));
47
+ }
48
+
49
+ TEST (LoweringPasses, UnpackHardInplaceSwish) {
50
+ std::string source_graph = R"IR(
51
+ graph(%input):
52
+ %result = aten::hardswish_(%input)
53
+ return (%result))IR" ;
54
+
55
+ std::string target_graph = R"IR(
56
+ graph(%input):
57
+ %1 : Scalar = prim::Constant[value=3.]()
58
+ %2 : Scalar = prim::Constant[value=1.]()
59
+ %3 = aten::add(%input, %1, %2)
60
+ %4 : Scalar = prim::Constant[value=0.]()
61
+ %5 : Scalar = prim::Constant[value=6.]()
62
+ %6 = aten::hardtanh(%3, %4, %5)
63
+ %7 = aten::div(%6, %5)
64
+ %8 = aten::mul(%input, %7)
65
+ return (%8))IR" ;
66
+
67
+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
68
+ auto sg = std::make_shared<torch::jit::Graph>();
69
+ torch::jit::parseIR (source_graph, &*sg);
70
+
71
+ auto in = at::rand ({10 , 100 }, {at::kCUDA });
72
+ auto sg_params = trtorch::core::conversion::get_named_params (sg->inputs (), {});
73
+ auto sg_results = trtorch::tests::util::RunGraph (sg, sg_params, {in});
74
+
75
+ trtorch::core::lowering::passes::UnpackHardSwish (sg);
76
+
77
+ auto tg = std::make_shared<torch::jit::Graph>();
78
+ torch::jit::parseIR (target_graph, &*tg);
79
+
80
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
81
+
82
+ in = at::clone (in);
83
+ auto tg_params = trtorch::core::conversion::get_named_params (tg->inputs (), {});
84
+ auto tg_results = trtorch::tests::util::RunGraph (tg, tg_params, {in});
85
+
86
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (sg_results[0 ], tg_results[0 ], 2e-6 ));
87
+ }
0 commit comments