Skip to content

Commit bfeb194

Browse files
committed
feat: Add converter for aten::sign unary op (#1391)
* feat: Add converter for sign unary operator - Add sign operator - Update test cases to test op - Ensure tests cover both int and float cases with negative and positive sign - Ensure tests cover cases where elements equal zero * Remove round unary from PR
1 parent 7ac12bd commit bfeb194

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

core/conversion/converters/impl/unary.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ convert(sqrt, kSQRT);
9595
convert(exp, kEXP);
9696
convert(neg, kNEG);
9797
convert(erf, kERF);
98+
convert(sign, kSIGN);
9899
convert(asinh, kASINH);
99100
convert(acosh, kACOSH);
100101
convert(atanh, kATANH);

tests/core/conversion/converters/test_unary.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,40 @@ TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
4747
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
4848
}
4949

50+
TEST(Converters, ATenSignConvertsCorrectly) {
51+
const auto graph = gen_test_graph("sign");
52+
auto g = std::make_shared<torch::jit::Graph>();
53+
torch::jit::parseIR(graph, g.get());
54+
55+
// Resize range to [-10, 10] to span negative values
56+
auto in = -20 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 10;
57+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
58+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
59+
60+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
61+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
62+
63+
ASSERT_TRUE(
64+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
65+
}
66+
67+
TEST(Converters, ATenSignConvertsZerosCorrectly) {
68+
const auto graph = gen_test_graph("sign");
69+
auto g = std::make_shared<torch::jit::Graph>();
70+
torch::jit::parseIR(graph, g.get());
71+
72+
// Resize range to [-1, 1] to span negative values, cast to int to include zero
73+
auto in = (-2 * at::rand({7, 3, 1, 5}, {at::kCUDA}) + 1).to(torch::kInt32);
74+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
75+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
76+
77+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
78+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
79+
80+
ASSERT_TRUE(
81+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
82+
}
83+
5084
#define test_unary(unary, name) \
5185
TEST(Converters, ATen##name##ConvertsCorrectly) { \
5286
const auto graph = gen_test_graph(#unary); \

0 commit comments

Comments
 (0)