Skip to content

Commit 181082b

Browse files
committed
Expand coverage of unary ops
Added support for: aten::cos aten::acos aten::cosh aten::sin aten::asin aten::sinh aten::tan aten::atan aten::abs aten::floor aten::reciprocal aten::ceil aten::sqrt aten::exp aten::neg Signed-off-by: Junjie Bai <[email protected]>
1 parent 8300d7c commit 181082b

File tree

2 files changed

+86
-32
lines changed

2 files changed

+86
-32
lines changed

core/conversion/converters/impl/unary.cpp

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,53 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
auto unary_registrations = RegisterNodeConversionPatterns()
12-
.pattern({
13-
"aten::log(Tensor self) -> Tensor",
14-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15-
auto in = args[0].ITensor();
16-
auto log = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kLOG);
11+
#define convert(unary, trt_type) \
12+
auto unary##_registrations TRTORCH_UNUSED = \
13+
RegisterNodeConversionPatterns().pattern( \
14+
{"aten::" #unary "(Tensor self) -> Tensor", \
15+
[](ConversionCtx *ctx, const torch::jit::Node *n, \
16+
args &args) -> bool { \
17+
auto in = args[0].ITensor(); \
18+
auto unary = \
19+
ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
20+
\
21+
TRTORCH_CHECK( \
22+
unary, \
23+
"Unable to create " #unary " layer from node: " << *n); \
24+
\
25+
unary->setName(util::node_info(n).c_str()); \
26+
auto out_value = n->outputs()[0]; \
27+
auto out_tensor = unary->getOutput(0); \
28+
out_tensor->setName(out_value->debugName().c_str()); \
29+
ctx->value_tensor_map[out_value] = out_tensor; \
30+
LOG_DEBUG( \
31+
"Output tensor shape: " << out_tensor->getDimensions()); \
32+
\
33+
return true; \
34+
}});
1735

18-
TRTORCH_CHECK(log, "Unable to create log layer from node: " << *n);
36+
convert(cos, kCOS);
37+
convert(acos, kACOS);
38+
convert(cosh, kCOSH);
39+
convert(sin, kSIN);
40+
convert(asin, kASIN);
41+
convert(sinh, kSINH);
42+
convert(tan, kTAN);
43+
convert(atan, kATAN);
44+
convert(abs, kABS);
45+
convert(floor, kFLOOR);
46+
convert(reciprocal, kRECIP);
47+
convert(log, kLOG);
48+
convert(ceil, kCEIL);
49+
convert(sqrt, kSQRT);
50+
convert(exp, kEXP);
51+
convert(neg, kNEG);
52+
53+
#undef convert
1954

20-
log->setName(util::node_info(n).c_str());
21-
auto out_value = n->outputs()[0];
22-
auto out_tensor = log->getOutput(0);
23-
out_tensor->setName(out_value->debugName().c_str());
24-
ctx->value_tensor_map[out_value] = out_tensor;
25-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
26-
27-
return true;
28-
}
29-
});
3055
} // namespace
3156
} // namespace impl
3257
} // namespace converters
3358
} // namespace conversion
3459
} // namespace core
35-
} // namespace trtorch
60+
} // namespace trtorch

tests/core/converters/test_unary.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,51 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7-
TEST(Converters, ATenLogConvertsCorrectly) {
8-
const auto graph = R"IR(
7+
namespace {
8+
std::string gen_test_graph(const std::string &unary) {
9+
return R"IR(
910
graph(%0 : Tensor):
10-
%3 : Tensor = aten::log(%0)
11+
%3 : Tensor = aten::)IR" +
12+
unary + R"IR((%0)
1113
return (%3))IR";
14+
}
15+
} // namespace
1216

13-
auto g = std::make_shared<torch::jit::Graph>();
14-
torch::jit::script::parseIR(graph, &*g);
15-
16-
auto in = at::randint(1, 5, {5}, {at::kCUDA});
17-
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
18-
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
17+
#define test_unary(unary, name) \
18+
TEST(Converters, ATen##name##ConvertsCorrectly) { \
19+
const auto graph = gen_test_graph(#unary); \
20+
\
21+
auto g = std::make_shared<torch::jit::Graph>(); \
22+
torch::jit::script::parseIR(graph, &*g); \
23+
\
24+
auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \
25+
auto params = \
26+
trtorch::core::conversion::get_named_params(g->inputs(), {}); \
27+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \
28+
\
29+
in = at::clone(in); \
30+
params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \
31+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); \
32+
\
33+
ASSERT_TRUE( \
34+
trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); \
35+
}
1936

20-
in = at::clone(in);
21-
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22-
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
37+
test_unary(cos, Cos);
38+
test_unary(acos, Acos);
39+
test_unary(cosh, Cosh);
40+
test_unary(sin, Sin);
41+
test_unary(asin, Asin);
42+
test_unary(sinh, Sinh);
43+
test_unary(tan, Tan);
44+
test_unary(atan, Atan);
45+
test_unary(abs, Abs);
46+
test_unary(floor, Floor);
47+
test_unary(reciprocal, Reciprocal);
48+
test_unary(log, Log);
49+
test_unary(ceil, Ceil);
50+
test_unary(sqrt, Sqrt);
51+
test_unary(exp, Exp);
52+
test_unary(neg, Neg);
2353

24-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
25-
}
54+
#undef test_unary

0 commit comments

Comments
 (0)