Skip to content

Commit 42064f0

Browse files
committed
Add support for sigmoid and tanh
Signed-off-by: Junjie Bai <[email protected]>
1 parent 91a7416 commit 42064f0

File tree

5 files changed

+107
-54
lines changed

5 files changed

+107
-54
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,42 @@ namespace converters {
99
namespace impl {
1010
namespace {
1111

12-
bool relu(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
13-
auto in = args[0].ITensor();
12+
#define convert(act, trt_type) \
13+
bool act(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
14+
auto in = args[0].ITensor(); \
15+
\
16+
auto new_layer = \
17+
ctx->net->addActivation(*in, nvinfer1::ActivationType::trt_type); \
18+
TRTORCH_CHECK(new_layer, \
19+
"Unable to create " #act " layer from node: " << *n); \
20+
\
21+
new_layer->setName(util::node_info(n).c_str()); \
22+
auto out_value = n->outputs()[0]; \
23+
auto out_tensor = new_layer->getOutput(0); \
24+
out_tensor->setName(out_value->debugName().c_str()); \
25+
ctx->value_tensor_map[out_value] = out_tensor; \
26+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); \
27+
\
28+
return true; \
29+
} \
30+
\
31+
auto act##_registrations TRTORCH_UNUSED = \
32+
RegisterNodeConversionPatterns() \
33+
.pattern({"aten::" #act "(Tensor input) -> (Tensor)", \
34+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
35+
args &args) -> bool { return act(ctx, n, args); }}) \
36+
.pattern({"aten::" #act "_(Tensor(a!) self) -> (Tensor(a!))", \
37+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
38+
args &args) -> bool { return act(ctx, n, args); }});
1439

15-
auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
16-
TRTORCH_CHECK(new_layer, "Unable to create ReLU layer from node: " << *n);
40+
convert(relu, kRELU);
41+
convert(sigmoid, kSIGMOID);
42+
convert(tanh, kTANH);
1743

18-
new_layer->setName(util::node_info(n).c_str());
19-
auto out_value = n->outputs()[0];
20-
auto out_tensor = new_layer->getOutput(0);
21-
out_tensor->setName(out_value->debugName().c_str());
22-
ctx->value_tensor_map[out_value] = out_tensor;
23-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
24-
25-
return true;
26-
}
27-
28-
auto relu_registrations = RegisterNodeConversionPatterns()
29-
.pattern({
30-
"aten::relu(Tensor input) -> (Tensor)",
31-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
32-
return relu(ctx, n, args);
33-
}
34-
}).pattern({
35-
"aten::relu_(Tensor(a!) self) -> (Tensor(a!))",
36-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
37-
return relu(ctx, n, args);
38-
}
39-
});
44+
#undef convert
4045
} // namespace
4146
} // namespace impl
4247
} // namespace converters
4348
} // namespace conversion
4449
} // namespace core
45-
} // trtorch
50+
} // namespace trtorch

core/util/macros.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,11 @@
5353
TRTORCH_THROW_ERROR("Expected " << #cond \
5454
<< " to be true but got false\n" << __VA_ARGS__); \
5555
}
56+
57+
58+
// suppress an unused variable.
59+
#if defined(_MSC_VER) && !defined(__clang__)
60+
#define TRTORCH_UNUSED __pragma(warning(suppress: 4100 4101))
61+
#else
62+
#define TRTORCH_UNUSED __attribute__((__unused__))
63+
#endif //_MSC_VER

tests/core/converters/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ converter_test(
55
)
66

77
converter_test(
8-
name = "test_relu"
8+
name = "test_activation"
99
)
1010

1111
converter_test(
@@ -36,7 +36,7 @@ test_suite(
3636
name = "test_converters",
3737
tests = [
3838
":test_softmax",
39-
":test_relu",
39+
":test_activation",
4040
":test_pooling",
4141
":test_unary",
4242
":test_linear",
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenReLUConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%3 : Tensor = aten::relu(%0)
11+
return (%3))IR";
12+
13+
auto g = std::make_shared<torch::jit::Graph>();
14+
torch::jit::script::parseIR(graph, &*g);
15+
16+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
17+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
18+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
19+
20+
in = at::clone(in);
21+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
23+
24+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
25+
}
26+
27+
TEST(Converters, ATenSigmoidConvertsCorrectly) {
28+
const auto graph = R"IR(
29+
graph(%0 : Tensor):
30+
%3 : Tensor = aten::sigmoid(%0)
31+
return (%3))IR";
32+
33+
auto g = std::make_shared<torch::jit::Graph>();
34+
torch::jit::script::parseIR(graph, &*g);
35+
36+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
37+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
38+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
39+
40+
in = at::clone(in);
41+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
42+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
43+
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
45+
}
46+
47+
TEST(Converters, ATenTanhConvertsCorrectly) {
48+
const auto graph = R"IR(
49+
graph(%0 : Tensor):
50+
%3 : Tensor = aten::tanh(%0)
51+
return (%3))IR";
52+
53+
auto g = std::make_shared<torch::jit::Graph>();
54+
torch::jit::script::parseIR(graph, &*g);
55+
56+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
57+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
58+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
59+
60+
in = at::clone(in);
61+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
62+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
63+
64+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
65+
}

tests/core/converters/test_relu.cpp

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)