@@ -21,7 +21,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) {
21
21
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
22
22
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
23
23
24
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
24
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
25
25
}
26
26
27
27
TEST (Converters, ATenSigmoidConvertsCorrectly) {
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
41
41
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
42
42
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
43
43
44
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
44
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
45
45
}
46
46
47
47
TEST (Converters, ATenTanhConvertsCorrectly) {
@@ -61,5 +61,51 @@ TEST(Converters, ATenTanhConvertsCorrectly) {
61
61
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
62
62
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
63
63
64
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
64
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
65
65
}
66
+
67
+ // TODO: Seems like the IR parser is not handling negative numbers well, need to follow up with the PyTorch Team
68
+ // TEST(Converters, ATenHardTanhConvertsCorrectly) {
69
+ // const auto graph = R"IR(
70
+ // graph(%0 : Tensor):
71
+ // %1 : float = prim::Constant[value=-1.0]()
72
+ // %2 : float = prim::Constant[value=1.0]()
73
+ // %3 : Tensor = aten::hardtanh(%0, %1, %2)
74
+ // return (%3))IR";
75
+
76
+ // auto g = std::make_shared<torch::jit::Graph>();
77
+ // torch::jit::script::parseIR(graph, &*g);
78
+
79
+ // auto in = at::randint(-5, 5, {5}, {at::kCUDA});
80
+ // auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
81
+ // auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
82
+
83
+ // in = at::clone(in);
84
+ // params = trtorch::core::conversion::get_named_params(g->inputs(), {});
85
+ // auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
86
+
87
+ // ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
88
+ // }
89
+
90
+ TEST (Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
91
+ const auto graph = R"IR(
92
+ graph(%0 : Tensor):
93
+ %1 : float = prim::Constant[value=0.0]()
94
+ %2 : float = prim::Constant[value=6.0]()
95
+ %3 : Tensor = aten::hardtanh(%0, %1, %2)
96
+ return (%3))IR" ;
97
+
98
+ auto g = std::make_shared<torch::jit::Graph>();
99
+ torch::jit::script::parseIR (graph, &*g);
100
+
101
+ auto in = at::randint (-5 , 5 , {5 }, {at::kCUDA });
102
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
103
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
104
+
105
+ in = at::clone (in);
106
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
107
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
108
+
109
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
110
+ }
111
+
0 commit comments