@@ -27,12 +27,23 @@ TEST(LoweringPasses, UnpackHardSwish) {
27
27
trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
28
28
auto sg = std::make_shared<torch::jit::Graph>();
29
29
torch::jit::parseIR (source_graph, &*sg);
30
+
31
+ auto in = at::rand ({10 , 100 }, {at::kCUDA });
32
+ auto sg_params = trtorch::core::conversion::get_named_params (sg->inputs (), {});
33
+ auto sg_results = trtorch::tests::util::RunGraph (sg, sg_params, {in});
34
+
30
35
trtorch::core::lowering::passes::UnpackHardSwish (sg);
31
36
32
37
auto tg = std::make_shared<torch::jit::Graph>();
33
38
torch::jit::parseIR (target_graph, &*tg);
34
39
35
40
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
41
+
42
+ in = at::clone (in);
43
+ auto tg_params = trtorch::core::conversion::get_named_params (tg->inputs (), {});
44
+ auto tg_results = trtorch::tests::util::RunGraph (tg, tg_params, {in});
45
+
46
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (sg_results[0 ], tg_results[0 ], 2e-6 ));
36
47
}
37
48
38
49
TEST (LoweringPasses, UnpackHardInplaceSwish) {
@@ -56,10 +67,21 @@ TEST(LoweringPasses, UnpackHardInplaceSwish) {
56
67
trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
57
68
auto sg = std::make_shared<torch::jit::Graph>();
58
69
torch::jit::parseIR (source_graph, &*sg);
70
+
71
+ auto in = at::rand ({10 , 100 }, {at::kCUDA });
72
+ auto sg_params = trtorch::core::conversion::get_named_params (sg->inputs (), {});
73
+ auto sg_results = trtorch::tests::util::RunGraph (sg, sg_params, {in});
74
+
59
75
trtorch::core::lowering::passes::UnpackHardSwish (sg);
60
76
61
77
auto tg = std::make_shared<torch::jit::Graph>();
62
78
torch::jit::parseIR (target_graph, &*tg);
63
79
64
80
ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
81
+
82
+ in = at::clone (in);
83
+ auto tg_params = trtorch::core::conversion::get_named_params (tg->inputs (), {});
84
+ auto tg_results = trtorch::tests::util::RunGraph (tg, tg_params, {in});
85
+
86
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (sg_results[0 ], tg_results[0 ], 2e-6 ));
65
87
}
0 commit comments