Skip to content

Commit 9081478

Browse files
authored
Merge pull request #304 from NVIDIA/logic_operators
Adding support for basic logic operators
2 parents 23f8e9d + a5491a5 commit 9081478

File tree

2 files changed

+260
-3
lines changed

2 files changed

+260
-3
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 175 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
290290
}})
291291
.pattern({"aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)",
292292
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
293-
// TODO: Remove with functionalization
294293
auto self = args[0].ITensorOrFreeze(ctx);
295294
auto other = args[1].ITensorOrFreeze(ctx);
296295
auto equal = add_elementwise(
@@ -354,7 +353,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
354353
}})
355354
.pattern({"aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
356355
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
357-
// TODO: Remove with functionalization
358356
auto self = args[0].ITensorOrFreeze(ctx);
359357
auto exponent = args[1].ITensorOrFreeze(ctx);
360358
auto pow =
@@ -369,7 +367,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
369367
}})
370368
.pattern({"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
371369
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
372-
// TODO: Remove with functionalization
373370
auto self = args[0].ITensorOrFreeze(ctx);
374371
auto exponentScalar = args[1].unwrapToScalar().to<float>();
375372
auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
@@ -380,6 +377,181 @@ auto element_wise_registrations TRTORCH_UNUSED =
380377
pow->setName(util::node_info(n).c_str());
381378
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));
382379

380+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
381+
return true;
382+
}})
383+
.pattern({"aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Elementwise greater-than between two tensors: freeze both
            // inputs to ITensors and emit a single TensorRT comparison layer.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create greater layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
396+
.pattern({"aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Greater-than against a scalar: materialize the scalar as a
            // single-element constant tensor, then reuse the tensor path.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto scalar_val = args[1].unwrapToScalar().to<float>();
            auto rhs = tensor_to_const(ctx, torch::tensor({scalar_val}));
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create greater layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
410+
.pattern({"aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Elementwise less-than between two tensors via a single
            // TensorRT kLESS comparison layer.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create less layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
423+
.pattern({"aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Less-than against a scalar: lift the scalar into a constant
            // tensor so the elementwise helper can broadcast it.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto scalar_val = args[1].unwrapToScalar().to<float>();
            auto rhs = tensor_to_const(ctx, torch::tensor({scalar_val}));
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create less layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
437+
.pattern({"aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Elementwise equality between two tensors via a single
            // TensorRT kEQUAL comparison layer.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create equal layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
450+
.pattern({"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Equality against a scalar: the scalar becomes a 1-element
            // constant tensor on the right-hand side of kEQUAL.
            // NOTE(review): the scalar is always narrowed to float here —
            // presumably fine for the supported dtypes; confirm for ints.
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto scalar_val = args[1].unwrapToScalar().to<float>();
            auto rhs = tensor_to_const(ctx, torch::tensor({scalar_val}));
            auto node_name = util::node_info(n);

            auto cmp = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name);
            TRTORCH_CHECK(cmp, "Unable to create equal layer from node: " << *n);

            cmp->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], cmp->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
464+
.pattern({"aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TensorRT has no single >= op, so compose it as
            // (a > b) OR (a == b).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto node_name = util::node_info(n);

            auto greater = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, node_name + "_greater");
            TRTORCH_CHECK(greater, "Unable to create Greater layer from node: " << *n);

            auto equal = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // The two comparison outputs already share a shape, so the raw
            // network call (no broadcast helper) is sufficient for the OR.
            auto or_op = ctx->net->addElementWise(
                *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);

            or_op->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
487+
.pattern({"aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Scalar variant of >=: lift the scalar to a constant tensor,
            // then compose (a > s) OR (a == s).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto scalar_val = args[1].unwrapToScalar().to<float>();
            auto rhs = tensor_to_const(ctx, torch::tensor({scalar_val}));
            auto node_name = util::node_info(n);

            auto greater = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kGREATER, lhs, rhs, node_name + "_greater");
            TRTORCH_CHECK(greater, "Unable to create Greater layer from node: " << *n);

            auto equal = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Comparison outputs share a shape; raw addElementWise is safe.
            auto or_op = ctx->net->addElementWise(
                *greater->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);

            or_op->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
511+
.pattern({"aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // TensorRT has no single <= op, so compose it as
            // (a < b) OR (a == b).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto rhs = args[1].ITensorOrFreeze(ctx);
            auto node_name = util::node_info(n);

            auto less = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, node_name + "_less");
            TRTORCH_CHECK(less, "Unable to create Less layer from node: " << *n);

            auto equal = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Comparison outputs share a shape; raw addElementWise is safe.
            auto or_op = ctx->net->addElementWise(
                *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);

            or_op->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }})
534+
.pattern({"aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Scalar variant of <=: lift the scalar to a constant tensor,
            // then compose (a < s) OR (a == s).
            auto lhs = args[0].ITensorOrFreeze(ctx);
            auto scalar_val = args[1].unwrapToScalar().to<float>();
            auto rhs = tensor_to_const(ctx, torch::tensor({scalar_val}));
            auto node_name = util::node_info(n);

            auto less = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kLESS, lhs, rhs, node_name + "_less");
            TRTORCH_CHECK(less, "Unable to create Less layer from node: " << *n);

            auto equal = add_elementwise(
                ctx, nvinfer1::ElementWiseOperation::kEQUAL, lhs, rhs, node_name + "_equal");
            TRTORCH_CHECK(equal, "Unable to create Equal layer from node: " << *n);

            // Comparison outputs share a shape; raw addElementWise is safe.
            auto or_op = ctx->net->addElementWise(
                *less->getOutput(0), *equal->getOutput(0), nvinfer1::ElementWiseOperation::kOR);
            TRTORCH_CHECK(or_op, "Unable to create Or layer from node: " << *n);

            or_op->setName(node_name.c_str());
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], or_op->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out->getDimensions());
            return true;
          }});

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,88 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
172172
pointwise_test_helper(graph, true, false, {3, 4, 2});
173173
;
174174
}
175+
176+
// aten::gt on two same-shape tensors should lower to a TensorRT
// comparison layer whose output matches Torch elementwise.
TEST(Converters, ATenGreaterThanConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::gt(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
183+
184+
// aten::gt against a float constant exercises the Scalar overload of
// the converter (scalar lifted to a constant tensor).
TEST(Converters, ATenGreaterThanScalarConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::gt(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
192+
193+
// aten::lt on two same-shape tensors — tensor/tensor path of the
// less-than converter.
TEST(Converters, ATenLessThanConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::lt(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
200+
201+
// aten::lt against a float constant — Scalar overload of the
// less-than converter.
TEST(Converters, ATenLessThanScalarConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::lt(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
209+
210+
// aten::eq on two same-shape tensors — tensor/tensor path of the
// equality converter.
TEST(Converters, ATenEqualConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::eq(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
217+
218+
// aten::eq against a float constant — Scalar overload of the
// equality converter.
TEST(Converters, ATenEqualScalarConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::eq(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
226+
227+
// aten::ge is lowered as (greater OR equal); verify the composed
// layers produce the Torch result for tensor/tensor inputs.
TEST(Converters, ATenGEConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::ge(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
234+
235+
// aten::ge against a float constant — Scalar overload of the
// composed (greater OR equal) converter.
TEST(Converters, ATenGEScalarConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::ge(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}
243+
244+
// aten::le is lowered as (less OR equal); verify the composed layers
// produce the Torch result for tensor/tensor inputs.
TEST(Converters, ATenLEConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : Tensor = aten::le(%0, %1)
        return (%2))IR";
  pointwise_test_helper(graph, false, false, {5, 5}, {5, 5});
}
251+
252+
// aten::le against a float constant — Scalar overload of the
// composed (less OR equal) converter.
TEST(Converters, ATenLEScalarConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=3]()
        %2 : Tensor = aten::le(%0, %scalar)
        return (%2))IR";
  pointwise_test_helper(graph, true, false, {5, 5});
}

0 commit comments

Comments
 (0)