Skip to content

Commit 3736a7e

Browse files
authored
Merge pull request #1705 from mfeliz-cruise/michael.feliz/aten_logical_not
[feat] Add converter support for aten::logical_not
2 parents 4cd892b + d3cc1ac commit 3736a7e

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

core/conversion/converters/impl/unary.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().
3434
return true;
3535
}});
3636

37+
auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
38+
{"aten::logical_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
39+
auto in = args[0].ITensorOrFreeze(ctx);
40+
if (in->getType() != nvinfer1::DataType::kBOOL) {
41+
// unary not layer only supports bool inputs
42+
in = castITensor(ctx, in, nvinfer1::DataType::kBOOL, util::node_info(n).c_str());
43+
}
44+
auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
45+
TORCHTRT_CHECK(unary_layer, "Unable to create logical_not layer from node: " << *n);
46+
unary_layer->setName(util::node_info(n).c_str());
47+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
48+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
49+
return true;
50+
}});
51+
3752
#define convert(unary, trt_type) \
3853
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
3954
{"aten::" #unary "(Tensor self) -> Tensor", \

tests/core/conversion/converters/test_unary.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ TEST(Converters, ATenSignConvertsZerosCorrectly) {
8181
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
8282
}
8383

84+
TEST(Converters, ATenLogicalNotBoolConvertsCorrectly) {
85+
const auto graph = gen_test_graph("logical_not");
86+
auto g = std::make_shared<torch::jit::Graph>();
87+
torch::jit::parseIR(graph, g.get());
88+
auto in = at::randint(0, 2, {7, 3, 1, 5}, {at::kCUDA}).to(torch::kBool);
89+
90+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
91+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
92+
93+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
94+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
95+
96+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
97+
}
98+
8499
#define test_unary(unary, name) \
85100
TEST(Converters, ATen##name##ConvertsCorrectly) { \
86101
const auto graph = gen_test_graph(#unary); \
@@ -122,5 +137,6 @@ test_unary(erf, Erf);
122137
test_unary(asinh, Asinh);
123138
test_unary(acosh, Acosh);
124139
test_unary(atanh, Atanh);
140+
test_unary(logical_not, LogicalNot);
125141

126142
#undef test_unary

0 commit comments

Comments
 (0)