Skip to content

Commit 643c938

Browse files
committed
add functional test
Signed-off-by: itsliupeng <[email protected]>
1 parent efb3230 commit 643c938

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/core/lowering/test_unpack_hardswish.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,23 @@ TEST(LoweringPasses, UnpackHardSwish) {
2727
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
2828
auto sg = std::make_shared<torch::jit::Graph>();
2929
torch::jit::parseIR(source_graph, &*sg);
30+
31+
auto in = at::rand({10, 100}, {at::kCUDA});
32+
auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
33+
auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
34+
3035
trtorch::core::lowering::passes::UnpackHardSwish(sg);
3136

3237
auto tg = std::make_shared<torch::jit::Graph>();
3338
torch::jit::parseIR(target_graph, &*tg);
3439

3540
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
41+
42+
in = at::clone(in);
43+
auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
44+
auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
45+
46+
ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
3647
}
3748

3849
TEST(LoweringPasses, UnpackHardInplaceSwish) {
@@ -56,10 +67,21 @@ TEST(LoweringPasses, UnpackHardInplaceSwish) {
5667
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
5768
auto sg = std::make_shared<torch::jit::Graph>();
5869
torch::jit::parseIR(source_graph, &*sg);
70+
71+
auto in = at::rand({10, 100}, {at::kCUDA});
72+
auto sg_params = trtorch::core::conversion::get_named_params(sg->inputs(), {});
73+
auto sg_results = trtorch::tests::util::RunGraph(sg, sg_params, {in});
74+
5975
trtorch::core::lowering::passes::UnpackHardSwish(sg);
6076

6177
auto tg = std::make_shared<torch::jit::Graph>();
6278
torch::jit::parseIR(target_graph, &*tg);
6379

6480
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
81+
82+
in = at::clone(in);
83+
auto tg_params = trtorch::core::conversion::get_named_params(tg->inputs(), {});
84+
auto tg_results = trtorch::tests::util::RunGraph(tg, tg_params, {in});
85+
86+
ASSERT_TRUE(trtorch::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
6587
}

0 commit comments

Comments
 (0)