Commit da66673

feat: support many elementwise dynamo converters
add output_dtypes in test add util func and fix bugs
1 parent 178f4f3 commit da66673

18 files changed: +1014 -5 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 253 additions & 0 deletions
@@ -845,6 +845,259 @@ def aten_ops_isinf(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+def aten_ops_add(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.add(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+def aten_ops_mul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+def aten_ops_max(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.max(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+def aten_ops_min(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.min(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+def aten_ops_sub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.sub(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+# TODO: keep this or line 54...?
+# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+# def aten_ops_div(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     return impl.elementwise.div(
+#         network,
+#         target,
+#         SourceIR.ATEN,
+#         name,
+#         args[0],
+#         args[1],
+#     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+def aten_ops_pow(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.pow(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+def aten_ops_floor_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.floor_divide(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+def aten_ops_logical_and(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_and(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
+def aten_ops_logical_or(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_or(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
+def aten_ops_logical_xor(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_xor(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+def aten_ops_equal(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.eq(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+def aten_ops_greater(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.gt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+def aten_ops_less(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.lt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 def conv_param_validator(conv_node: Node) -> bool:
     return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))

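Each converter above has the same shape: it registers an ATen overload and forwards (network, target, SourceIR.ATEN, name, args[0], args[1]) to the matching impl.elementwise helper. A minimal sketch of how one of these might be exercised through the dynamo converter test harness, per the output_dtypes addition the commit message mentions; DispatchTestCase, run_test, expected_ops, and the import path are assumptions based on the repository's test pattern, not part of this diff:

import torch
import torch.nn as nn
from harness import DispatchTestCase  # assumed test util within tests/py/dynamo

class TestLogicalAndConverter(DispatchTestCase):
    def test_logical_and(self):
        class LogicalAnd(nn.Module):
            def forward(self, lhs, rhs):
                return torch.ops.aten.logical_and.default(lhs, rhs)

        # Bool inputs, since logical_and operates on boolean tensors.
        inputs = [torch.randn(2, 3) > 0, torch.randn(2, 3) > 0]
        # output_dtypes tells the harness the TRT engine output is bool;
        # logical ops do not preserve the input dtype.
        self.run_test(
            LogicalAnd(),
            inputs,
            expected_ops={torch.ops.aten.logical_and.default},
            output_dtypes=[torch.bool],
        )
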
py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 14 additions & 0 deletions
@@ -188,3 +188,17 @@ def extend_attr_to_tuple(
     if isinstance(val, list):
         val = tuple(val)
     return val
+
+
+def trt_cast_int_to_float(network: TRTNetwork, name: str, tensor: TRTTensor):
+    if tensor.dtype == trt.int8 or tensor.dtype == trt.int32:
+        return cast_trt_tensor(network, tensor, trt.float32, name)
+
+    return tensor
+
+
+def trt_cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
+    if tensor.dtype != trt.bool:
+        return cast_trt_tensor(network, tensor, trt.bool, name)
+
+    return tensor
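
A hedged sketch of how a converter might call these helpers; the SQRT layer and the function name below are illustrative assumptions, not code from this commit:

import tensorrt as trt
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.dynamo.conversion.converter_utils import trt_cast_int_to_float

def sqrt_float_only(network: TRTNetwork, name: str, tensor: TRTTensor) -> TRTTensor:
    # Promote int8/int32 inputs to float32 first; TensorRT unary float ops
    # such as SQRT do not accept integer tensors.
    tensor = trt_cast_int_to_float(network, f"{name}_cast", tensor)
    layer = network.add_unary(tensor, trt.UnaryOperation.SQRT)
    layer.name = name
    return layer.get_output(0)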

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 23 additions & 4 deletions
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target

@@ -24,12 +25,30 @@ def get_python_op_from_trt_elementwise_op(
         return operator.add
     elif trt_op == trt.ElementWiseOperation.PROD:
         return operator.mul
+    elif trt_op == trt.ElementWiseOperation.MAX:
+        return lambda a, b: max(a, b)
+    elif trt_op == trt.ElementWiseOperation.MIN:
+        return lambda a, b: min(a, b)
     elif trt_op == trt.ElementWiseOperation.SUB:
         return operator.sub
     elif trt_op == trt.ElementWiseOperation.DIV:
         return operator.truediv
+    elif trt_op == trt.ElementWiseOperation.POW:
+        return operator.pow
     elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
         return operator.floordiv
+    elif trt_op == trt.ElementWiseOperation.AND:
+        return lambda a, b: a and b
+    elif trt_op == trt.ElementWiseOperation.OR:
+        return lambda a, b: a or b
+    elif trt_op == trt.ElementWiseOperation.XOR:
+        return lambda a, b: (a or b) and not (a and b)
+    elif trt_op == trt.ElementWiseOperation.EQUAL:
+        return operator.eq
+    elif trt_op == trt.ElementWiseOperation.GREATER:
+        return operator.gt
+    elif trt_op == trt.ElementWiseOperation.LESS:
+        return operator.lt
     else:
         raise RuntimeError(f"{trt_op} is not supported yet!")
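
This mapping backs the branch in convert_binary_elementwise (visible in the next hunk) where neither operand is a TRTTensor: the op is then evaluated directly in Python rather than emitting a TensorRT layer. A small illustration, with example values:

import operator
import tensorrt as trt
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
    get_python_op_from_trt_elementwise_op,
)

py_op = get_python_op_from_trt_elementwise_op(trt.ElementWiseOperation.FLOOR_DIV)
assert py_op is operator.floordiv
assert py_op(7, 2) == 3  # both operands are scalars, so no TRT layer is needed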

@@ -75,10 +94,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
+        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
+        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:

@@ -103,9 +122,9 @@
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+        lhs_val = np.array([lhs_val], dtype=rhs_dtype)
 
     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
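
With Frameworks.NUMPY, a Python scalar operand is now materialized as a one-element numpy array in the TRT tensor's dtype (previously a torch tensor) before becoming a constant layer. A small illustration of the promotion; the add_constant step is paraphrased from the surrounding converter, not shown in this hunk:

import numpy as np

lhs_dtype = np.float32  # e.g. unified_dtype_converter(trt.float32, Frameworks.NUMPY)
rhs_val = 2             # Python int scalar operand from the graph

# The scalar adopts the TRT tensor's dtype so the subsequent constant layer
# produces a matching ITensor:
rhs_arr = np.array([rhs_val], dtype=lhs_dtype)  # -> array([2.], dtype=float32)
# A converter would then register it via network.add_constant(rhs_arr.shape, rhs_arr)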
