@@ -290,7 +290,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
290
290
}})
291
291
.pattern({" aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
292
292
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
293
- // TODO: Remove with functionalization
294
293
auto self = args[0 ].ITensorOrFreeze (ctx);
295
294
auto other = args[1 ].ITensorOrFreeze (ctx);
296
295
auto equal = add_elementwise (
@@ -354,7 +353,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
354
353
}})
355
354
.pattern({" aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)" ,
356
355
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
357
- // TODO: Remove with functionalization
358
356
auto self = args[0 ].ITensorOrFreeze (ctx);
359
357
auto exponent = args[1 ].ITensorOrFreeze (ctx);
360
358
auto pow =
@@ -369,7 +367,6 @@ auto element_wise_registrations TRTORCH_UNUSED =
369
367
}})
370
368
.pattern({" aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)" ,
371
369
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
372
- // TODO: Remove with functionalization
373
370
auto self = args[0 ].ITensorOrFreeze (ctx);
374
371
auto exponentScalar = args[1 ].unwrapToScalar ().to <float >();
375
372
auto exponent = tensor_to_const (ctx, torch::tensor ({exponentScalar}));
@@ -380,6 +377,181 @@ auto element_wise_registrations TRTORCH_UNUSED =
380
377
pow->setName (util::node_info (n).c_str ());
381
378
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], pow->getOutput (0 ));
382
379
380
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
381
+ return true ;
382
+ }})
383
+ .pattern({" aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
384
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
385
+ auto self = args[0 ].ITensorOrFreeze (ctx);
386
+ auto other = args[1 ].ITensorOrFreeze (ctx);
387
+ auto gt =
388
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n));
389
+ TRTORCH_CHECK (gt, " Unable to create greater layer from node: " << *n);
390
+
391
+ gt->setName (util::node_info (n).c_str ());
392
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gt->getOutput (0 ));
393
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
394
+ return true ;
395
+ }})
396
+ .pattern({" aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
397
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
398
+ auto self = args[0 ].ITensorOrFreeze (ctx);
399
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
400
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
401
+ auto gt =
402
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n));
403
+ TRTORCH_CHECK (gt, " Unable to create greater layer from node: " << *n);
404
+
405
+ gt->setName (util::node_info (n).c_str ());
406
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gt->getOutput (0 ));
407
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
408
+ return true ;
409
+ }})
410
+ .pattern({" aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
411
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
412
+ auto self = args[0 ].ITensorOrFreeze (ctx);
413
+ auto other = args[1 ].ITensorOrFreeze (ctx);
414
+ auto lt =
415
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n));
416
+ TRTORCH_CHECK (lt, " Unable to create less layer from node: " << *n);
417
+
418
+ lt->setName (util::node_info (n).c_str ());
419
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], lt->getOutput (0 ));
420
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
421
+ return true ;
422
+ }})
423
+ .pattern({" aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
424
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
425
+ auto self = args[0 ].ITensorOrFreeze (ctx);
426
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
427
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
428
+ auto lt =
429
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n));
430
+ TRTORCH_CHECK (lt, " Unable to create less layer from node: " << *n);
431
+
432
+ lt->setName (util::node_info (n).c_str ());
433
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], lt->getOutput (0 ));
434
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
435
+ return true ;
436
+ }})
437
+ .pattern({" aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
438
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
439
+ auto self = args[0 ].ITensorOrFreeze (ctx);
440
+ auto other = args[1 ].ITensorOrFreeze (ctx);
441
+ auto eq =
442
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n));
443
+ TRTORCH_CHECK (eq, " Unable to create equal layer from node: " << *n);
444
+
445
+ eq->setName (util::node_info (n).c_str ());
446
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], eq->getOutput (0 ));
447
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
448
+ return true ;
449
+ }})
450
+ .pattern({" aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
451
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
452
+ auto self = args[0 ].ITensorOrFreeze (ctx);
453
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
454
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
455
+ auto eq =
456
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n));
457
+ TRTORCH_CHECK (eq, " Unable to create equal layer from node: " << *n);
458
+
459
+ eq->setName (util::node_info (n).c_str ());
460
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], eq->getOutput (0 ));
461
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
462
+ return true ;
463
+ }})
464
+ .pattern({" aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
465
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
466
+ auto self = args[0 ].ITensorOrFreeze (ctx);
467
+ auto other = args[1 ].ITensorOrFreeze (ctx);
468
+
469
+ auto greater = add_elementwise (
470
+ ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n) + " _greater" );
471
+ TRTORCH_CHECK (greater, " Unable to create Greater layer from node: " << *n);
472
+
473
+ auto equal = add_elementwise (
474
+ ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n) + " _equal" );
475
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
476
+
477
+ auto or_op = ctx->net ->addElementWise (
478
+ *greater->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
479
+
480
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
481
+ or_op->setName (util::node_info (n).c_str ());
482
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
483
+
484
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
485
+ return true ;
486
+ }})
487
+ .pattern({" aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
488
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
489
+ auto self = args[0 ].ITensorOrFreeze (ctx);
490
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
491
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
492
+
493
+ auto greater = add_elementwise (
494
+ ctx, nvinfer1::ElementWiseOperation::kGREATER , self, other, util::node_info (n) + " _greater" );
495
+ TRTORCH_CHECK (greater, " Unable to create Greater layer from node: " << *n);
496
+
497
+ auto equal = add_elementwise (
498
+ ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n) + " _equal" );
499
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
500
+
501
+ auto or_op = ctx->net ->addElementWise (
502
+ *greater->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
503
+
504
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
505
+ or_op->setName (util::node_info (n).c_str ());
506
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
507
+
508
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
509
+ return true ;
510
+ }})
511
+ .pattern({" aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)" ,
512
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
513
+ auto self = args[0 ].ITensorOrFreeze (ctx);
514
+ auto other = args[1 ].ITensorOrFreeze (ctx);
515
+
516
+ auto less = add_elementwise (
517
+ ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n) + " _less" );
518
+ TRTORCH_CHECK (less, " Unable to create Less layer from node: " << *n);
519
+
520
+ auto equal = add_elementwise (
521
+ ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n) + " _equal" );
522
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
523
+
524
+ auto or_op = ctx->net ->addElementWise (
525
+ *less->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
526
+
527
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
528
+ or_op->setName (util::node_info (n).c_str ());
529
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
530
+
531
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
532
+ return true ;
533
+ }})
534
+ .pattern({" aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
535
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
536
+ auto self = args[0 ].ITensorOrFreeze (ctx);
537
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
538
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
539
+
540
+ auto less = add_elementwise (
541
+ ctx, nvinfer1::ElementWiseOperation::kLESS , self, other, util::node_info (n) + " _less" );
542
+ TRTORCH_CHECK (less, " Unable to create Less layer from node: " << *n);
543
+
544
+ auto equal = add_elementwise (
545
+ ctx, nvinfer1::ElementWiseOperation::kEQUAL , self, other, util::node_info (n) + " _equal" );
546
+ TRTORCH_CHECK (equal, " Unable to create Equal layer from node: " << *n);
547
+
548
+ auto or_op = ctx->net ->addElementWise (
549
+ *less->getOutput (0 ), *equal->getOutput (0 ), nvinfer1::ElementWiseOperation::kOR );
550
+
551
+ TRTORCH_CHECK (or_op, " Unable to create Or layer from node: " << *n);
552
+ or_op->setName (util::node_info (n).c_str ());
553
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], or_op->getOutput (0 ));
554
+
383
555
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
384
556
return true ;
385
557
}});
0 commit comments