Skip to content

Commit a09e8fb

Browse files
authored
Merge pull request #24 from bddppq/unary
Expand coverage of unary ops
2 parents fe48049 + 92a98b7 commit a09e8fb

File tree

6 files changed

+86
-144
lines changed

6 files changed

+86
-144
lines changed

core/conversion/converters/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ cc_library(
1616
"impl/element_wise.cpp",
1717
"impl/linear.cpp",
1818
"impl/pooling.cpp",
19-
"impl/scale.cpp",
2019
"impl/softmax.cpp",
2120
"impl/unary.cpp",
2221
],

core/conversion/converters/impl/scale.cpp

Lines changed: 0 additions & 81 deletions
This file was deleted.

core/conversion/converters/impl/unary.cpp

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,53 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
auto unary_registrations = RegisterNodeConversionPatterns()
12-
.pattern({
13-
"aten::log(Tensor self) -> Tensor",
14-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15-
auto in = args[0].ITensor();
16-
auto log = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kLOG);
11+
#define convert(unary, trt_type) \
12+
auto unary##_registrations TRTORCH_UNUSED = \
13+
RegisterNodeConversionPatterns().pattern( \
14+
{"aten::" #unary "(Tensor self) -> Tensor", \
15+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
16+
args &args) -> bool { \
17+
auto in = args[0].ITensor(); \
18+
auto unary = \
19+
ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
20+
\
21+
TRTORCH_CHECK( \
22+
unary, \
23+
"Unable to create " #unary " layer from node: " << *n); \
24+
\
25+
unary->setName(util::node_info(n).c_str()); \
26+
auto out_value = n->outputs()[0]; \
27+
auto out_tensor = unary->getOutput(0); \
28+
out_tensor->setName(out_value->debugName().c_str()); \
29+
ctx->value_tensor_map[out_value] = out_tensor; \
30+
LOG_DEBUG( \
31+
"Output tensor shape: " << out_tensor->getDimensions()); \
32+
\
33+
return true; \
34+
}});
1735

18-
TRTORCH_CHECK(log, "Unable to create log layer from node: " << *n);
36+
convert(cos, kCOS);
37+
convert(acos, kACOS);
38+
convert(cosh, kCOSH);
39+
convert(sin, kSIN);
40+
convert(asin, kASIN);
41+
convert(sinh, kSINH);
42+
convert(tan, kTAN);
43+
convert(atan, kATAN);
44+
convert(abs, kABS);
45+
convert(floor, kFLOOR);
46+
convert(reciprocal, kRECIP);
47+
convert(log, kLOG);
48+
convert(ceil, kCEIL);
49+
convert(sqrt, kSQRT);
50+
convert(exp, kEXP);
51+
convert(neg, kNEG);
52+
53+
#undef convert
1954

20-
log->setName(util::node_info(n).c_str());
21-
auto out_value = n->outputs()[0];
22-
auto out_tensor = log->getOutput(0);
23-
out_tensor->setName(out_value->debugName().c_str());
24-
ctx->value_tensor_map[out_value] = out_tensor;
25-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
26-
27-
return true;
28-
}
29-
});
3055
} // namespace
3156
} // namespace impl
3257
} // namespace converters
3358
} // namespace conversion
3459
} // namespace core
35-
} // namespace trtorch
60+
} // namespace trtorch

tests/core/converters/BUILD

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ converter_test(
2828
name = "test_conv"
2929
)
3030

31-
converter_test(
32-
name = "test_scale"
33-
)
34-
3531
test_suite(
3632
name = "test_converters",
3733
tests = [
@@ -42,7 +38,6 @@ test_suite(
4238
":test_linear",
4339
":test_element_wise",
4440
":test_conv",
45-
":test_scale"
4641
]
4742
)
4843

tests/core/converters/test_scale.cpp

Lines changed: 0 additions & 25 deletions
This file was deleted.

tests/core/converters/test_unary.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,51 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7-
TEST(Converters, ATenLogConvertsCorrectly) {
8-
const auto graph = R"IR(
7+
namespace {
8+
std::string gen_test_graph(const std::string &unary) {
9+
return R"IR(
910
graph(%0 : Tensor):
10-
%3 : Tensor = aten::log(%0)
11+
%3 : Tensor = aten::)IR" +
12+
unary + R"IR((%0)
1113
return (%3))IR";
14+
}
15+
} // namespace
1216

13-
auto g = std::make_shared<torch::jit::Graph>();
14-
torch::jit::script::parseIR(graph, &*g);
15-
16-
auto in = at::randint(1, 5, {5}, {at::kCUDA});
17-
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
18-
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
17+
#define test_unary(unary, name) \
18+
TEST(Converters, ATen##name##ConvertsCorrectly) { \
19+
const auto graph = gen_test_graph(#unary); \
20+
\
21+
auto g = std::make_shared<torch::jit::Graph>(); \
22+
torch::jit::script::parseIR(graph, &*g); \
23+
\
24+
auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \
25+
auto params = \
26+
trtorch::core::conversion::get_named_params(g->inputs(), {}); \
27+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \
28+
\
29+
in = at::clone(in); \
30+
params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \
31+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); \
32+
\
33+
ASSERT_TRUE( \
34+
trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); \
35+
}
1936

20-
in = at::clone(in);
21-
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22-
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
37+
test_unary(cos, Cos);
38+
test_unary(acos, Acos);
39+
test_unary(cosh, Cosh);
40+
test_unary(sin, Sin);
41+
test_unary(asin, Asin);
42+
test_unary(sinh, Sinh);
43+
test_unary(tan, Tan);
44+
test_unary(atan, Atan);
45+
test_unary(abs, Abs);
46+
test_unary(floor, Floor);
47+
test_unary(reciprocal, Reciprocal);
48+
test_unary(log, Log);
49+
test_unary(ceil, Ceil);
50+
test_unary(sqrt, Sqrt);
51+
test_unary(exp, Exp);
52+
test_unary(neg, Neg);
2353

24-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
25-
}
54+
#undef test_unary

0 commit comments

Comments
 (0)