Skip to content

Commit a187bb9

Browse files
committed
feat: support many elementwise dynamo converters
add output_dtypes in test
1 parent 0e5a497 commit a187bb9

16 files changed

+973
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,3 +843,256 @@ def aten_ops_isinf(
843843
name,
844844
args[0],
845845
)
846+
847+
848+
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
849+
def aten_ops_add(
850+
network: TRTNetwork,
851+
target: Target,
852+
args: Tuple[Argument, ...],
853+
kwargs: Dict[str, Argument],
854+
name: str,
855+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
856+
return impl.elementwise.add(
857+
network,
858+
target,
859+
SourceIR.ATEN,
860+
name,
861+
args[0],
862+
args[1],
863+
)
864+
865+
866+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
867+
def aten_ops_mul(
868+
network: TRTNetwork,
869+
target: Target,
870+
args: Tuple[Argument, ...],
871+
kwargs: Dict[str, Argument],
872+
name: str,
873+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
874+
return impl.elementwise.mul(
875+
network,
876+
target,
877+
SourceIR.ATEN,
878+
name,
879+
args[0],
880+
args[1],
881+
)
882+
883+
884+
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
885+
def aten_ops_max(
886+
network: TRTNetwork,
887+
target: Target,
888+
args: Tuple[Argument, ...],
889+
kwargs: Dict[str, Argument],
890+
name: str,
891+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
892+
return impl.elementwise.max(
893+
network,
894+
target,
895+
SourceIR.ATEN,
896+
name,
897+
args[0],
898+
args[1],
899+
)
900+
901+
902+
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
903+
def aten_ops_min(
904+
network: TRTNetwork,
905+
target: Target,
906+
args: Tuple[Argument, ...],
907+
kwargs: Dict[str, Argument],
908+
name: str,
909+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
910+
return impl.elementwise.min(
911+
network,
912+
target,
913+
SourceIR.ATEN,
914+
name,
915+
args[0],
916+
args[1],
917+
)
918+
919+
920+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
921+
def aten_ops_sub(
922+
network: TRTNetwork,
923+
target: Target,
924+
args: Tuple[Argument, ...],
925+
kwargs: Dict[str, Argument],
926+
name: str,
927+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
928+
return impl.elementwise.sub(
929+
network,
930+
target,
931+
SourceIR.ATEN,
932+
name,
933+
args[0],
934+
args[1],
935+
)
936+
937+
938+
# TODO: keep this or line 54...?
939+
# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
940+
# def aten_ops_div(
941+
# network: TRTNetwork,
942+
# target: Target,
943+
# args: Tuple[Argument, ...],
944+
# kwargs: Dict[str, Argument],
945+
# name: str,
946+
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
947+
# return impl.elementwise.div(
948+
# network,
949+
# target,
950+
# SourceIR.ATEN,
951+
# name,
952+
# args[0],
953+
# args[1],
954+
# )
955+
956+
957+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
958+
def aten_ops_pow(
959+
network: TRTNetwork,
960+
target: Target,
961+
args: Tuple[Argument, ...],
962+
kwargs: Dict[str, Argument],
963+
name: str,
964+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
965+
return impl.elementwise.pow(
966+
network,
967+
target,
968+
SourceIR.ATEN,
969+
name,
970+
args[0],
971+
args[1],
972+
)
973+
974+
975+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
976+
def aten_ops_floor_div(
977+
network: TRTNetwork,
978+
target: Target,
979+
args: Tuple[Argument, ...],
980+
kwargs: Dict[str, Argument],
981+
name: str,
982+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
983+
return impl.elementwise.floor_divide(
984+
network,
985+
target,
986+
SourceIR.ATEN,
987+
name,
988+
args[0],
989+
args[1],
990+
)
991+
992+
993+
@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
994+
def aten_ops_logical_and(
995+
network: TRTNetwork,
996+
target: Target,
997+
args: Tuple[Argument, ...],
998+
kwargs: Dict[str, Argument],
999+
name: str,
1000+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1001+
return impl.elementwise.logical_and(
1002+
network,
1003+
target,
1004+
SourceIR.ATEN,
1005+
name,
1006+
args[0],
1007+
args[1],
1008+
)
1009+
1010+
1011+
@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
1012+
def aten_ops_logical_or(
1013+
network: TRTNetwork,
1014+
target: Target,
1015+
args: Tuple[Argument, ...],
1016+
kwargs: Dict[str, Argument],
1017+
name: str,
1018+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1019+
return impl.elementwise.logical_or(
1020+
network,
1021+
target,
1022+
SourceIR.ATEN,
1023+
name,
1024+
args[0],
1025+
args[1],
1026+
)
1027+
1028+
1029+
@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
1030+
def aten_ops_logical_xor(
1031+
network: TRTNetwork,
1032+
target: Target,
1033+
args: Tuple[Argument, ...],
1034+
kwargs: Dict[str, Argument],
1035+
name: str,
1036+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1037+
return impl.elementwise.logical_xor(
1038+
network,
1039+
target,
1040+
SourceIR.ATEN,
1041+
name,
1042+
args[0],
1043+
args[1],
1044+
)
1045+
1046+
1047+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1048+
def aten_ops_equal(
1049+
network: TRTNetwork,
1050+
target: Target,
1051+
args: Tuple[Argument, ...],
1052+
kwargs: Dict[str, Argument],
1053+
name: str,
1054+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1055+
return impl.elementwise.eq(
1056+
network,
1057+
target,
1058+
SourceIR.ATEN,
1059+
name,
1060+
args[0],
1061+
args[1],
1062+
)
1063+
1064+
1065+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1066+
def aten_ops_greater(
1067+
network: TRTNetwork,
1068+
target: Target,
1069+
args: Tuple[Argument, ...],
1070+
kwargs: Dict[str, Argument],
1071+
name: str,
1072+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1073+
return impl.elementwise.gt(
1074+
network,
1075+
target,
1076+
SourceIR.ATEN,
1077+
name,
1078+
args[0],
1079+
args[1],
1080+
)
1081+
1082+
1083+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1084+
def aten_ops_less(
1085+
network: TRTNetwork,
1086+
target: Target,
1087+
args: Tuple[Argument, ...],
1088+
kwargs: Dict[str, Argument],
1089+
name: str,
1090+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1091+
return impl.elementwise.lt(
1092+
network,
1093+
target,
1094+
SourceIR.ATEN,
1095+
name,
1096+
args[0],
1097+
args[1],
1098+
)

0 commit comments

Comments
 (0)