Skip to content

Commit 9a054ce

Browse files
committed
feat: support many elementwise dynamo converters
1 parent e58f319 commit 9a054ce

16 files changed

+974
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,256 @@ def aten_ops_expand(
438438
args[0],
439439
args[1],
440440
)
441+
442+
443+
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
444+
def aten_ops_add(
445+
network: TRTNetwork,
446+
target: Target,
447+
args: Tuple[Argument, ...],
448+
kwargs: Dict[str, Argument],
449+
name: str,
450+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
451+
return impl.elementwise.add(
452+
network,
453+
target,
454+
SourceIR.ATEN,
455+
name,
456+
args[0],
457+
args[1],
458+
)
459+
460+
461+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
462+
def aten_ops_mul(
463+
network: TRTNetwork,
464+
target: Target,
465+
args: Tuple[Argument, ...],
466+
kwargs: Dict[str, Argument],
467+
name: str,
468+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
469+
return impl.elementwise.mul(
470+
network,
471+
target,
472+
SourceIR.ATEN,
473+
name,
474+
args[0],
475+
args[1],
476+
)
477+
478+
479+
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
480+
def aten_ops_max(
481+
network: TRTNetwork,
482+
target: Target,
483+
args: Tuple[Argument, ...],
484+
kwargs: Dict[str, Argument],
485+
name: str,
486+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
487+
return impl.elementwise.max(
488+
network,
489+
target,
490+
SourceIR.ATEN,
491+
name,
492+
args[0],
493+
args[1],
494+
)
495+
496+
497+
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
498+
def aten_ops_min(
499+
network: TRTNetwork,
500+
target: Target,
501+
args: Tuple[Argument, ...],
502+
kwargs: Dict[str, Argument],
503+
name: str,
504+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
505+
return impl.elementwise.min(
506+
network,
507+
target,
508+
SourceIR.ATEN,
509+
name,
510+
args[0],
511+
args[1],
512+
)
513+
514+
515+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
516+
def aten_ops_sub(
517+
network: TRTNetwork,
518+
target: Target,
519+
args: Tuple[Argument, ...],
520+
kwargs: Dict[str, Argument],
521+
name: str,
522+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
523+
return impl.elementwise.sub(
524+
network,
525+
target,
526+
SourceIR.ATEN,
527+
name,
528+
args[0],
529+
args[1],
530+
)
531+
532+
533+
# TODO: keep this or line 54...?
534+
# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
535+
# def aten_ops_div(
536+
# network: TRTNetwork,
537+
# target: Target,
538+
# args: Tuple[Argument, ...],
539+
# kwargs: Dict[str, Argument],
540+
# name: str,
541+
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
542+
# return impl.elementwise.div(
543+
# network,
544+
# target,
545+
# SourceIR.ATEN,
546+
# name,
547+
# args[0],
548+
# args[1],
549+
# )
550+
551+
552+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
553+
def aten_ops_pow(
554+
network: TRTNetwork,
555+
target: Target,
556+
args: Tuple[Argument, ...],
557+
kwargs: Dict[str, Argument],
558+
name: str,
559+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
560+
return impl.elementwise.pow(
561+
network,
562+
target,
563+
SourceIR.ATEN,
564+
name,
565+
args[0],
566+
args[1],
567+
)
568+
569+
570+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
571+
def aten_ops_floor_div(
572+
network: TRTNetwork,
573+
target: Target,
574+
args: Tuple[Argument, ...],
575+
kwargs: Dict[str, Argument],
576+
name: str,
577+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
578+
return impl.elementwise.floor_divide(
579+
network,
580+
target,
581+
SourceIR.ATEN,
582+
name,
583+
args[0],
584+
args[1],
585+
)
586+
587+
588+
@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
589+
def aten_ops_logical_and(
590+
network: TRTNetwork,
591+
target: Target,
592+
args: Tuple[Argument, ...],
593+
kwargs: Dict[str, Argument],
594+
name: str,
595+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
596+
return impl.elementwise.logical_and(
597+
network,
598+
target,
599+
SourceIR.ATEN,
600+
name,
601+
args[0],
602+
args[1],
603+
)
604+
605+
606+
@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
607+
def aten_ops_logical_or(
608+
network: TRTNetwork,
609+
target: Target,
610+
args: Tuple[Argument, ...],
611+
kwargs: Dict[str, Argument],
612+
name: str,
613+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
614+
return impl.elementwise.logical_or(
615+
network,
616+
target,
617+
SourceIR.ATEN,
618+
name,
619+
args[0],
620+
args[1],
621+
)
622+
623+
624+
@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
625+
def aten_ops_logical_xor(
626+
network: TRTNetwork,
627+
target: Target,
628+
args: Tuple[Argument, ...],
629+
kwargs: Dict[str, Argument],
630+
name: str,
631+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
632+
return impl.elementwise.logical_xor(
633+
network,
634+
target,
635+
SourceIR.ATEN,
636+
name,
637+
args[0],
638+
args[1],
639+
)
640+
641+
642+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
643+
def aten_ops_equal(
644+
network: TRTNetwork,
645+
target: Target,
646+
args: Tuple[Argument, ...],
647+
kwargs: Dict[str, Argument],
648+
name: str,
649+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
650+
return impl.elementwise.eq(
651+
network,
652+
target,
653+
SourceIR.ATEN,
654+
name,
655+
args[0],
656+
args[1],
657+
)
658+
659+
660+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
661+
def aten_ops_greater(
662+
network: TRTNetwork,
663+
target: Target,
664+
args: Tuple[Argument, ...],
665+
kwargs: Dict[str, Argument],
666+
name: str,
667+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
668+
return impl.elementwise.gt(
669+
network,
670+
target,
671+
SourceIR.ATEN,
672+
name,
673+
args[0],
674+
args[1],
675+
)
676+
677+
678+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
679+
def aten_ops_less(
680+
network: TRTNetwork,
681+
target: Target,
682+
args: Tuple[Argument, ...],
683+
kwargs: Dict[str, Argument],
684+
name: str,
685+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
686+
return impl.elementwise.lt(
687+
network,
688+
target,
689+
SourceIR.ATEN,
690+
name,
691+
args[0],
692+
args[1],
693+
)

0 commit comments

Comments
 (0)