Skip to content

Commit 3da4947

Browse files
authored
Merge pull request #18 from bddppq/simple-activations
Add support for sigmoid and tanh
2 parents 31e3964 + 42064f0 commit 3da4947

File tree

5 files changed

+107
-54
lines changed

5 files changed

+107
-54
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,42 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
bool relu(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
12-
auto in = args[0].ITensor();
11+
#define convert(act, trt_type) \
12+
bool act(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
13+
auto in = args[0].ITensor(); \
14+
\
15+
auto new_layer = \
16+
ctx->net->addActivation(*in, nvinfer1::ActivationType::trt_type); \
17+
TRTORCH_CHECK(new_layer, \
18+
"Unable to create " #act " layer from node: " << *n); \
19+
\
20+
new_layer->setName(util::node_info(n).c_str()); \
21+
auto out_value = n->outputs()[0]; \
22+
auto out_tensor = new_layer->getOutput(0); \
23+
out_tensor->setName(out_value->debugName().c_str()); \
24+
ctx->value_tensor_map[out_value] = out_tensor; \
25+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); \
26+
\
27+
return true; \
28+
} \
29+
\
30+
auto act##_registrations TRTORCH_UNUSED = \
31+
RegisterNodeConversionPatterns() \
32+
.pattern({"aten::" #act "(Tensor input) -> (Tensor)", \
33+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
34+
args &args) -> bool { return act(ctx, n, args); }}) \
35+
.pattern({"aten::" #act "_(Tensor(a!) self) -> (Tensor(a!))", \
36+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
37+
args &args) -> bool { return act(ctx, n, args); }});
1338

14-
auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
15-
TRTORCH_CHECK(new_layer, "Unable to create ReLU layer from node: " << *n);
39+
convert(relu, kRELU);
40+
convert(sigmoid, kSIGMOID);
41+
convert(tanh, kTANH);
1642

17-
new_layer->setName(util::node_info(n).c_str());
18-
auto out_value = n->outputs()[0];
19-
auto out_tensor = new_layer->getOutput(0);
20-
out_tensor->setName(out_value->debugName().c_str());
21-
ctx->value_tensor_map[out_value] = out_tensor;
22-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
23-
24-
return true;
25-
}
26-
27-
auto relu_registrations = RegisterNodeConversionPatterns()
28-
.pattern({
29-
"aten::relu(Tensor input) -> (Tensor)",
30-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
31-
return relu(ctx, n, args);
32-
}
33-
}).pattern({
34-
"aten::relu_(Tensor(a!) self) -> (Tensor(a!))",
35-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
36-
return relu(ctx, n, args);
37-
}
38-
});
43+
#undef convert
3944
} // namespace
4045
} // namespace impl
4146
} // namespace converters
4247
} // namespace conversion
4348
} // namespace core
44-
} // trtorch
49+
} // namespace trtorch

core/util/macros.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,11 @@
5353
TRTORCH_THROW_ERROR("Expected " << #cond \
5454
<< " to be true but got false\n" << __VA_ARGS__); \
5555
}
56+
57+
58+
// suppress an unused variable.
59+
#if defined(_MSC_VER) && !defined(__clang__)
60+
#define TRTORCH_UNUSED __pragma(warning(suppress: 4100 4101))
61+
#else
62+
#define TRTORCH_UNUSED __attribute__((__unused__))
63+
#endif //_MSC_VER

tests/core/converters/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ converter_test(
55
)
66

77
converter_test(
8-
name = "test_relu"
8+
name = "test_activation"
99
)
1010

1111
converter_test(
@@ -36,7 +36,7 @@ test_suite(
3636
name = "test_converters",
3737
tests = [
3838
":test_softmax",
39-
":test_relu",
39+
":test_activation",
4040
":test_pooling",
4141
":test_unary",
4242
":test_linear",
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenReLUConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%3 : Tensor = aten::relu(%0)
11+
return (%3))IR";
12+
13+
auto g = std::make_shared<torch::jit::Graph>();
14+
torch::jit::script::parseIR(graph, &*g);
15+
16+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
17+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
18+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
19+
20+
in = at::clone(in);
21+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
23+
24+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
25+
}
26+
27+
TEST(Converters, ATenSigmoidConvertsCorrectly) {
28+
const auto graph = R"IR(
29+
graph(%0 : Tensor):
30+
%3 : Tensor = aten::sigmoid(%0)
31+
return (%3))IR";
32+
33+
auto g = std::make_shared<torch::jit::Graph>();
34+
torch::jit::script::parseIR(graph, &*g);
35+
36+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
37+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
38+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
39+
40+
in = at::clone(in);
41+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
42+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
43+
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
45+
}
46+
47+
TEST(Converters, ATenTanhConvertsCorrectly) {
48+
const auto graph = R"IR(
49+
graph(%0 : Tensor):
50+
%3 : Tensor = aten::tanh(%0)
51+
return (%3))IR";
52+
53+
auto g = std::make_shared<torch::jit::Graph>();
54+
torch::jit::script::parseIR(graph, &*g);
55+
56+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
57+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
58+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
59+
60+
in = at::clone(in);
61+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
62+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
63+
64+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
65+
}

tests/core/converters/test_relu.cpp

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)