Skip to content

Commit a4075ca

Browse files
committed
add overloads, update tests, and fix a bug
1 parent da66673 commit a4075ca

16 files changed

+537
-149
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 81 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Tuple, Union
33

4-
import tensorrt as trt
54
import torch
65
from torch.fx.node import Argument, Node, Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
87
from torch_tensorrt.dynamo.conversion import impl
9-
from torch_tensorrt.dynamo.conversion.converter_utils import (
10-
cast_int_int_div_trt_tensor,
11-
cast_trt_tensor,
12-
)
13-
from torch_tensorrt.fx.converters import acc_ops_converters
148
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
159

1610
from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
4842
)
4943

5044

51-
@dynamo_tensorrt_converter(torch.ops.aten.div.default) # type: ignore[misc]
52-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
53-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc]
54-
def aten_ops_div(
55-
network: TRTNetwork,
56-
target: Target,
57-
args: Tuple[Argument, ...],
58-
kwargs: Dict[str, Argument],
59-
name: str,
60-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
61-
kwargs_new = {
62-
"input": args[0],
63-
"other": args[1],
64-
}
65-
# If both are TRTTensor, both are cast to float32
66-
if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
67-
kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
68-
network,
69-
kwargs_new["input"],
70-
kwargs_new["other"],
71-
name,
72-
)
73-
# If one is TRTTensor, it is cast to float32
74-
elif isinstance(args[0], TRTTensor) and (
75-
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
76-
):
77-
kwargs_new["input"] = cast_trt_tensor(
78-
network, kwargs_new["input"], trt.float32, name, target
79-
)
80-
elif isinstance(args[1], TRTTensor) and (
81-
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
82-
):
83-
kwargs_new["other"] = cast_trt_tensor(
84-
network, kwargs_new["other"], trt.float32, name, target
85-
)
86-
rounding_mode = kwargs.get("rounding_mode")
87-
if rounding_mode is None:
88-
return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
89-
elif rounding_mode == "floor":
90-
return acc_ops_converters.acc_ops_floor_div(
91-
network, target, None, kwargs_new, name
92-
)
93-
elif rounding_mode == "trunc":
94-
return impl.elementwise.trunc_div(
95-
network, target, SourceIR.ATEN, name, args[0], args[1]
96-
)
97-
else:
98-
raise RuntimeError(
99-
f"Target {target} does not support rounding mode {rounding_mode}"
100-
)
101-
102-
10345
def embedding_param_validator(embedding_node: Node) -> bool:
10446
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
10547
sparse = args_bounds_check(embedding_node.args, 4)
@@ -846,24 +788,39 @@ def aten_ops_isinf(
846788

847789

848790
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
791+
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
849792
def aten_ops_add(
850793
network: TRTNetwork,
851794
target: Target,
852795
args: Tuple[Argument, ...],
853796
kwargs: Dict[str, Argument],
854797
name: str,
855798
) -> Union[TRTTensor, Sequence[TRTTensor]]:
799+
other = args[1]
800+
alpha = kwargs.get("alpha", 1)
801+
802+
if alpha != 1:
803+
other = impl.elementwise.mul(
804+
network,
805+
target,
806+
SourceIR.ATEN,
807+
name,
808+
other,
809+
alpha,
810+
)
811+
856812
return impl.elementwise.add(
857813
network,
858814
target,
859815
SourceIR.ATEN,
860816
name,
861817
args[0],
862-
args[1],
818+
other,
863819
)
864820

865821

866822
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
823+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
867824
def aten_ops_mul(
868825
network: TRTNetwork,
869826
target: Target,
@@ -918,43 +875,86 @@ def aten_ops_min(
918875

919876

920877
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
878+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
921879
def aten_ops_sub(
922880
network: TRTNetwork,
923881
target: Target,
924882
args: Tuple[Argument, ...],
925883
kwargs: Dict[str, Argument],
926884
name: str,
927885
) -> Union[TRTTensor, Sequence[TRTTensor]]:
886+
other = args[1]
887+
alpha = kwargs.get("alpha", 1)
888+
889+
if alpha != 1:
890+
other = impl.elementwise.mul(
891+
network,
892+
target,
893+
SourceIR.ATEN,
894+
name,
895+
other,
896+
alpha,
897+
)
898+
928899
return impl.elementwise.sub(
929900
network,
930901
target,
931902
SourceIR.ATEN,
932903
name,
933904
args[0],
934-
args[1],
905+
other,
935906
)
936907

937908

938-
# TODO: keep this or line 54...?
939-
# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
940-
# def aten_ops_div(
941-
# network: TRTNetwork,
942-
# target: Target,
943-
# args: Tuple[Argument, ...],
944-
# kwargs: Dict[str, Argument],
945-
# name: str,
946-
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
947-
# return impl.elementwise.div(
948-
# network,
949-
# target,
950-
# SourceIR.ATEN,
951-
# name,
952-
# args[0],
953-
# args[1],
954-
# )
909+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
910+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
911+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
912+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
913+
def aten_ops_div(
914+
network: TRTNetwork,
915+
target: Target,
916+
args: Tuple[Argument, ...],
917+
kwargs: Dict[str, Argument],
918+
name: str,
919+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
920+
rounding_mode = kwargs.get("rounding_mode")
921+
922+
if rounding_mode is None:
923+
return impl.elementwise.div(
924+
network,
925+
target,
926+
SourceIR.ATEN,
927+
name,
928+
args[0],
929+
args[1],
930+
)
931+
elif rounding_mode == "floor":
932+
return impl.elementwise.floor_divide(
933+
network,
934+
target,
935+
SourceIR.ATEN,
936+
name,
937+
args[0],
938+
args[1],
939+
)
940+
elif rounding_mode == "trunc":
941+
return impl.elementwise.trunc_div(
942+
network,
943+
target,
944+
SourceIR.ATEN,
945+
name,
946+
args[0],
947+
args[1],
948+
)
949+
else:
950+
raise RuntimeError(
951+
f"Target {target} does not support rounding mode {rounding_mode}"
952+
)
955953

956954

957955
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
956+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
957+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
958958
def aten_ops_pow(
959959
network: TRTNetwork,
960960
target: Target,
@@ -973,6 +973,7 @@ def aten_ops_pow(
973973

974974

975975
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
976+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
976977
def aten_ops_floor_div(
977978
network: TRTNetwork,
978979
target: Target,
@@ -1045,6 +1046,7 @@ def aten_ops_logical_xor(
10451046

10461047

10471048
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1049+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
10481050
def aten_ops_equal(
10491051
network: TRTNetwork,
10501052
target: Target,
@@ -1063,6 +1065,7 @@ def aten_ops_equal(
10631065

10641066

10651067
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1068+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
10661069
def aten_ops_greater(
10671070
network: TRTNetwork,
10681071
target: Target,
@@ -1081,6 +1084,7 @@ def aten_ops_greater(
10811084

10821085

10831086
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1087+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
10841088
def aten_ops_less(
10851089
network: TRTNetwork,
10861090
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def cast_trt_tensor(
9797

9898
if input_val.dtype != trt_dtype:
9999
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
100-
target_str = ConverterRegistry.qualified_name_or_str(target)
100+
target_str = ConverterRegistry.qualified_name_or_str(ConverterRegistry, target)
101101
target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
102102

103103
identity_layer = network.add_identity(input_val)

tests/py/dynamo/conversion/test_add_aten.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,79 @@
11
import torch
22
import torch.nn as nn
3-
from .harness import DispatchTestCase
43
from parameterized import parameterized
54
from torch.testing._internal.common_utils import run_tests
65
from torch_tensorrt import Input
76

7+
from .harness import DispatchTestCase
8+
89

910
class TestAddConverter(DispatchTestCase):
1011
@parameterized.expand(
1112
[
12-
("2d_dim_alpha", (2, 1), 2),
13-
("3d_dim_alpha", (2, 1, 2), 2),
13+
("2d", (2, 1)),
14+
("3d", (2, 1, 2)),
15+
]
16+
)
17+
def test_add_tensor(self, _, shape):
18+
class add(nn.Module):
19+
def forward(self, lhs_val, rhs_val):
20+
return torch.add(lhs_val, rhs_val)
21+
22+
inputs = [torch.randn(shape), torch.randn(shape)]
23+
self.run_test(
24+
add(),
25+
inputs,
26+
expected_ops={torch.ops.aten.add.Tensor},
27+
)
28+
29+
@parameterized.expand(
30+
[
31+
("2d", (2, 1), 1),
32+
("3d", (2, 1, 2), 2.0),
1433
]
1534
)
16-
def test_add(self, _, x, alpha):
35+
def test_add_tensor_alpha(self, _, shape, alpha):
1736
class add(nn.Module):
1837
def forward(self, lhs_val, rhs_val):
19-
return lhs_val + rhs_val
38+
return torch.add(lhs_val, rhs_val, alpha=alpha)
39+
40+
inputs = [torch.randn(shape), torch.randn(shape)]
41+
self.run_test(
42+
add(),
43+
inputs,
44+
expected_ops={torch.ops.aten.add.Tensor},
45+
)
46+
47+
@parameterized.expand(
48+
[
49+
("2d", (2, 1), 1.0),
50+
("3d", (2, 1, 2), 2),
51+
]
52+
)
53+
def test_add_scalar(self, _, shape, scalar):
54+
class add(nn.Module):
55+
def forward(self, lhs_val):
56+
return torch.add(lhs_val, scalar)
57+
58+
inputs = [torch.randn(shape)]
59+
self.run_test(
60+
add(),
61+
inputs,
62+
expected_ops={torch.ops.aten.add.Tensor},
63+
)
64+
65+
@parameterized.expand(
66+
[
67+
("2d", (2, 1), 1.0, 1.0),
68+
("3d", (2, 1, 2), 2, 2),
69+
]
70+
)
71+
def test_add_scalar_alpha(self, _, shape, scalar, alpha):
72+
class add(nn.Module):
73+
def forward(self, lhs_val):
74+
return torch.add(lhs_val, scalar, alpha=alpha)
2075

21-
inputs = [torch.randn(x) + 1, torch.randn(x) + 1]
76+
inputs = [torch.randn(shape)]
2277
self.run_test(
2378
add(),
2479
inputs,

0 commit comments

Comments
 (0)