Skip to content

Commit 3e1cc88

Browse files
authored
Merge pull request #300 from NVIDIA/ne
Add support for ne layer
2 parents 885439c + 5edf0d4 commit 3e1cc88

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,70 @@ auto element_wise_registrations TRTORCH_UNUSED =
260260
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
261261
return true;
262262
}})
263+
.pattern({"aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)",
264+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
265+
// TODO: Remove with functionalization
266+
auto self = args[0].ITensorOrFreeze(ctx);
267+
auto other = args[1].ITensorOrFreeze(ctx);
268+
auto equal = add_elementwise(
269+
ctx,
270+
nvinfer1::ElementWiseOperation::kEQUAL,
271+
self,
272+
other,
273+
util::node_info(n) + std::string("is_equal"));
274+
TRTORCH_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
275+
// XOR with ones negates and produces not_equal result
276+
auto options = torch::TensorOptions().dtype(torch::kFloat32);
277+
auto ones = at::full({1}, 1, {options});
278+
auto ones_tensor = tensor_to_const(ctx, ones);
279+
nvinfer1::IIdentityLayer* cast_layer = ctx->net->addIdentity(*ones_tensor);
280+
cast_layer->setOutputType(0, nvinfer1::DataType::kBOOL);
281+
282+
auto sub = add_elementwise(
283+
ctx,
284+
nvinfer1::ElementWiseOperation::kXOR,
285+
cast_layer->getOutput(0),
286+
equal->getOutput(0),
287+
util::node_info(n));
288+
TRTORCH_CHECK(sub, "Unable to create ne (not equal) layer from node: " << *n);
289+
290+
sub->setName(util::node_info(n).c_str());
291+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
292+
LOG_DEBUG("Not equal layer output tensor shape: " << out->getDimensions());
293+
return true;
294+
}})
295+
.pattern({"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
296+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
297+
auto self = args[0].ITensorOrFreeze(ctx);
298+
auto scalar = args[1].unwrapToScalar().to<float>();
299+
auto scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar}));
300+
auto equal = add_elementwise(
301+
ctx,
302+
nvinfer1::ElementWiseOperation::kEQUAL,
303+
self,
304+
scalar_tensor,
305+
util::node_info(n) + std::string("is_equal"));
306+
TRTORCH_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
307+
// XOR with ones negates and produces not_equal result
308+
auto options = torch::TensorOptions().dtype(torch::kFloat32);
309+
auto ones = at::full({1}, 1, {options});
310+
auto ones_tensor = tensor_to_const(ctx, ones);
311+
nvinfer1::IIdentityLayer* cast_layer = ctx->net->addIdentity(*ones_tensor);
312+
cast_layer->setOutputType(0, nvinfer1::DataType::kBOOL);
313+
314+
auto sub = add_elementwise(
315+
ctx,
316+
nvinfer1::ElementWiseOperation::kXOR,
317+
cast_layer->getOutput(0),
318+
equal->getOutput(0),
319+
util::node_info(n));
320+
TRTORCH_CHECK(sub, "Unable to create ne (not equal) layer from node: " << *n);
321+
322+
sub->setName(util::node_info(n).c_str());
323+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
324+
LOG_DEBUG("Not equal layer output tensor shape: " << out->getDimensions());
325+
return true;
326+
}})
263327
.pattern({"aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
264328
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
265329
// TODO: Remove with functionalization

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
7676
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
7777
}
7878

79+
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
80+
const auto graph = R"IR(
81+
graph(%0 : Tensor):
82+
%2 : int = prim::Constant[value=1]()
83+
%scalar : float = prim::Constant[value=2.4]()
84+
%3 : Tensor = aten::add(%0, %scalar, %2)
85+
return (%3))IR";
86+
pointwise_test_helper(graph, true);
87+
}
88+
7989
TEST(Converters, ATenSubConvertsCorrectly) {
8090
const auto graph = R"IR(
8191
graph(%0 : Tensor, %1 : Tensor):
@@ -134,12 +144,22 @@ TEST(Converters, ATenPowScalarConvertsCorrectly) {
134144
pointwise_test_helper(graph, true);
135145
}
136146

137-
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
147+
TEST(Converters, ATenNeTensorConvertsCorrectly) {
138148
const auto graph = R"IR(
139-
graph(%0 : Tensor):
140-
%2 : int = prim::Constant[value=1]()
141-
%scalar : float = prim::Constant[value=2.4]()
142-
%3 : Tensor = aten::add(%0, %scalar, %2)
149+
graph(%x.1 : Tensor,
150+
%y.1 : Tensor):
151+
%3 : Tensor = aten::ne(%x.1, %y.1)
143152
return (%3))IR";
144-
pointwise_test_helper(graph, true);
153+
pointwise_test_helper(graph, false, false, {3, 4}, {3, 4});
154+
pointwise_test_helper(graph, false, true, {3, 4}, {3, 4});
155+
}
156+
157+
TEST(Converters, ATenNeScalarConvertsCorrectly) {
158+
const auto graph = R"IR(
159+
graph(%x.1 : Tensor):
160+
%2 : int = prim::Constant[value=2]()
161+
%3 : Tensor = aten::ne(%x.1, %2)
162+
return (%3))IR";
163+
pointwise_test_helper(graph, true, false, {3, 4, 2});
164+
;
145165
}

0 commit comments

Comments
 (0)