Commit 391af52

feat(hardtanh): Adds support for the hard tanh operator
Signed-off-by: Naren Dasan <[email protected]>
1 parent 73bfd4c commit 391af52

2 files changed: 98 additions & 7 deletions


core/conversion/converters/impl/activation.cpp

Lines changed: 49 additions & 4 deletions
@@ -28,11 +28,11 @@ namespace {
   auto act##_registrations TRTORCH_UNUSED = \
     RegisterNodeConversionPatterns() \
       .pattern({"aten::" #act "(Tensor input) -> (Tensor)", \
-                [](ConversionCtx *ctx, const torch::jit::Node *n, \
-                   args &args) -> bool { return act(ctx, n, args); }}) \
+                [](ConversionCtx* ctx, const torch::jit::Node* n, \
+                   args& args) -> bool { return act(ctx, n, args); }}) \
       .pattern({"aten::" #act "_(Tensor(a!) self) -> (Tensor(a!))", \
-                [](ConversionCtx *ctx, const torch::jit::Node *n, \
-                   args &args) -> bool { return act(ctx, n, args); }});
+                [](ConversionCtx* ctx, const torch::jit::Node* n, \
+                   args& args) -> bool { return act(ctx, n, args); }});
 
 //TODO: remove support for conversion of inplace operators and move to the functionalization pass
 
@@ -41,6 +41,51 @@ convert(sigmoid, kSIGMOID);
 convert(tanh, kTANH);
 
 #undef convert
+
+auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+  .pattern({
+    "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto in = args[0].ITensor();
+      auto min = args[1].unwrapToDouble();
+      auto max = args[2].unwrapToDouble();
+
+      auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP);
+      TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh");
+
+      new_layer->setAlpha(min);
+      new_layer->setBeta(max);
+
+      new_layer->setName(util::node_info(n).c_str());
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
+      LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+      return true;
+    }
+  }).pattern({
+    //TODO: Remove after functionalization
+    "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto in = args[0].ITensor();
+      auto min = args[1].unwrapToDouble();
+      auto max = args[2].unwrapToDouble();
+
+      auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP);
+      TRTORCH_CHECK(new_layer, "Unable to create layer for aten::hardtanh");
+
+      new_layer->setAlpha(min);
+      new_layer->setBeta(max);
+
+      new_layer->setName(util::node_info(n).c_str());
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
+      LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+      return true;
+    }
+  });
+
+
+
 } // namespace
 } // namespace impl
 } // namespace converters
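
Editorial note, not part of the commit: aten::hardtanh(x, min_val, max_val) clamps each element of x to the interval [min_val, max_val], and TensorRT's kCLIP activation computes y = max(alpha, min(beta, x)), so the converter above only needs to forward min_val into setAlpha and max_val into setBeta. A minimal standalone C++ sketch of that shared semantics:

#include <algorithm>
#include <cstdio>
#include <vector>

// Reference semantics only: both aten::hardtanh and TensorRT's kCLIP compute
// y = max(min_val, min(max_val, x)) elementwise.
double hardtanh_ref(double x, double min_val = -1.0, double max_val = 1.0) {
  return std::max(min_val, std::min(max_val, x));
}

int main() {
  const std::vector<double> in = {-3.0, -0.5, 0.0, 2.0, 7.5};
  for (double x : in) {
    // The default range [-1, 1] matches the schema's min_val=-1, max_val=1.
    std::printf("hardtanh(%+.1f) = %+.1f\n", x, hardtanh_ref(x));
  }
  return 0;
}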

tests/core/converters/test_activation.cpp

Lines changed: 49 additions & 3 deletions
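
Editorial note, not part of the diff: in the hunks below, the three existing activation tests switch almostEqual to an explicit 2e-6 absolute tolerance, since TensorRT kernels are not guaranteed to produce bit-identical results to ATen. The helper's implementation is not shown in this commit; a hypothetical sketch of such a check, assuming ATen's C++ API (the real trtorch::tests::util::almostEqual may differ):

#include <ATen/ATen.h>

// Hypothetical stand-in for trtorch::tests::util::almostEqual; the real
// helper lives in the test utilities and may be implemented differently.
bool almost_equal(const at::Tensor& a, const at::Tensor& b, double atol = 2e-6) {
  // The largest elementwise absolute difference must stay within the tolerance.
  auto max_diff = (a - b).abs().max().item<double>();
  return max_diff <= atol;
}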
@@ -21,7 +21,7 @@ TEST(Converters, ATenReLUConvertsCorrectly) {
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenSigmoidConvertsCorrectly) {
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
 TEST(Converters, ATenTanhConvertsCorrectly) {
@@ -61,5 +61,51 @@ TEST(Converters, ATenTanhConvertsCorrectly) {
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
 
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+//TODO: Seems like the IR parser is not handling negative numbers well, need to follow up with the PyTorch Team
+// TEST(Converters, ATenHardTanhConvertsCorrectly) {
+//   const auto graph = R"IR(
+//     graph(%0 : Tensor):
+//       %1 : float = prim::Constant[value=-1.0]()
+//       %2 : float = prim::Constant[value=1.0]()
+//       %3 : Tensor = aten::hardtanh(%0, %1, %2)
+//       return (%3))IR";
+
+//   auto g = std::make_shared<torch::jit::Graph>();
+//   torch::jit::script::parseIR(graph, &*g);
+
+//   auto in = at::randint(-5, 5, {5}, {at::kCUDA});
+//   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+//   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+//   in = at::clone(in);
+//   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+//   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+//   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+// }
+
+TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : float = prim::Constant[value=0.0]()
+      %2 : float = prim::Constant[value=6.0]()
+      %3 : Tensor = aten::hardtanh(%0, %1, %2)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::script::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {5}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
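
Editorial note, not part of the diff: the default-range test stays commented out because the IR parser mishandles the negative constant -1.0, per the TODO above; the enabled test sidesteps negatives by using the range [0, 6], where hardtanh coincides with ReLU6. A self-contained C++ sketch, with hypothetical sample draws, of the values TensorRT must reproduce:

#include <algorithm>
#include <cstdio>

int main() {
  // Hypothetical draws from at::randint(-5, 5, {5}); actual values vary per run.
  const double samples[] = {-5.0, -2.0, 0.0, 3.0, 4.0};
  for (double x : samples) {
    // hardtanh(x, 0, 6) == relu6(x): clamp into [0, 6].
    const double y = std::max(0.0, std::min(6.0, x));
    std::printf("hardtanh(%+.0f, 0, 6) = %.0f\n", x, y);
  }
  return 0;
}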
