
Commit f80fd27

feat: support aten.any related converters in dynamo
1 parent 4b608f0 commit f80fd27

File tree

3 files changed: +244 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
@@ -2605,6 +2605,26 @@ def aten_ops_remainder(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.any.default)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+def aten_ops_any(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.any(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, replacement=None),
+        args_bounds_check(args, 2, replacement=False),
+    )
+
 @dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
 @enforce_tensor_types(
     {
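For orientation, the three registered overloads correspond to the following eager-mode calls; all of them are routed to the single impl.reduce.any helper, with args_bounds_check supplying None / False when dim / keepdim are absent. A small illustrative sketch (not part of the commit):

import torch

x = torch.randn(2, 3, 4)

# aten.any.default: reduce over every element, 0-d bool result
torch.ops.aten.any.default(x)               # equivalent to torch.any(x)

# aten.any.dim: reduce over one dimension, with optional keepdim
torch.ops.aten.any.dim(x, 1, True)          # equivalent to torch.any(x, dim=1, keepdim=True)

# aten.any.dims: reduce over a list of dimensions
torch.ops.aten.any.dims(x, [0, 2], False)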

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 30 additions & 0 deletions
@@ -3,6 +3,7 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
@@ -208,3 +209,32 @@ def mean(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def any(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Union[int, Optional[Sequence[int]]] = None,
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool):
+        input_val = cast_trt_tensor(ctx, input_val, trt.int32, f"{name}_cast")
+
+    abs_out = impl.unary.abs(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_abs",
+        input_val,
+    )
+    if dim is None:
+        dim = []
+    elif isinstance(dim, int):
+        dim = [dim]
+
+    max_out = amax(ctx, target, source_ir, f"{name}_amax", abs_out, dim, keepdim)
+
+    return cast_trt_tensor(ctx, max_out, trt.bool, f"{name}_cast_to_bool")
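The helper lowers any() onto existing reduce building blocks rather than a dedicated TensorRT op: a bool input is first cast to int32 (presumably because the TensorRT reduce layer does not accept bool), |x| is reduced with amax over the requested dimensions, and the result is cast back to bool, so any nonzero maximum becomes True. A quick eager-mode mirror of that logic, for illustration only:

import torch

def any_via_amax(x, dim=None, keepdim=False):
    # Mirror of impl.reduce.any in eager PyTorch: bool -> int32, abs, amax, cast back to bool.
    if x.dtype == torch.bool:
        x = x.to(torch.int32)
    dims = list(range(x.ndim)) if dim is None else dim
    return torch.amax(x.abs(), dim=dims, keepdim=keepdim).to(torch.bool)

x = torch.randint(-2, 2, (3, 2, 4))
assert torch.equal(any_via_amax(x), torch.any(x))
assert torch.equal(any_via_amax(x, dim=1, keepdim=True), torch.any(x, dim=1, keepdim=True))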
Converter test file (new; path not shown in this view)

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestAnyConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "3d",
                (3, 2, 4),
            ),
            (
                "4d",
                (2, 3, 4, 5),
            ),
            ("5d", (6, 7, 5, 4, 5)),
        ]
    )
    def test_any_default_float_dtype(self, _, input_shape):
        class Any(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.default(x)

        inputs = [torch.randn(*input_shape)]
        self.run_test(Any(), inputs, output_dtypes=[torch.bool])

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True),
            ((2, 3, 4, 5), 3, True),
            ((2, 3, 4, 5), 2, False),
            ((6, 7, 5, 4, 5), 4, False),
            ((1, 5, 2, 1), -1, True),
        ]
    )
    def test_any_dim_float_dtype(self, input_shape, dim, keep_dims):
        class AnyDim(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dim(x, dim, keep_dims)

        inputs = [torch.randn(*input_shape)]
        self.run_test(AnyDim(), inputs, output_dtypes=[torch.bool])

    @parameterized.expand(
        [
            ((3, 2, 4), [1], True),
            ((2, 1, 4, 5), [0, 3], True),
            ((2, 3, 4, 5), [0, 1, 2, 3], False),
            ((6, 7, 5, 4, 5), [1, 3, 4], False),
        ]
    )
    def test_any_dims_tuple_float_dtype(self, input_shape, dims, keep_dims):
        class AnyDims(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dims(x, dims, keep_dims)

        inputs = [torch.randn(*input_shape)]
        self.run_test(AnyDims(), inputs, output_dtypes=[torch.bool])

    @parameterized.expand(
        [
            ((3, 2, 4), torch.int, 0, 5),
            ((2, 3, 4, 5), torch.int, -10, 10),
            ((2, 3, 4, 5), torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), torch.int32, -5, 5),
            ((1, 5, 2, 1), torch.int32, -5, 5),
        ]
    )
    def test_any_default_int_dtype(self, input_shape, dtype, low, high):
        class Any(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.default(x)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            Any(),
            inputs,
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True, torch.int, 0, 5),
            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
            ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
        ]
    )
    def test_any_dim_int_dtype(self, input_shape, dim, keep_dims, dtype, low, high):
        class AnyDim(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dim(x, dim, keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            AnyDim(),
            inputs,
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((3, 2, 4), [1], True, torch.int, 0, 5),
            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
            ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
        ]
    )
    def test_any_dims_tuple_int_dtype(
        self, input_shape, dims, keep_dims, dtype, low, high
    ):
        class AnyDims(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dims(x, dims, keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            AnyDims(),
            inputs,
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((2, 3, 4), torch.int, -5, 0),
            ((6, 7, 5, 4, 5), torch.int, -5, 5),
            ((1, 5, 2, 1), torch.int, -5, 5),
        ]
    )
    def test_any_default_bool_dtype(self, input_shape, dtype, low, high):
        class Any(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.default(x)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype).bool()]
        self.run_test(
            Any(),
            inputs,
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True, torch.int, 0, 5),
            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
            ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
        ]
    )
    def test_any_dim_bool_dtype(self, input_shape, dim, keep_dims, dtype, low, high):
        class AnyDim(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dim(x, dim, keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype).bool()]
        self.run_test(
            AnyDim(),
            inputs,
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((3, 2, 4), [1], True, torch.int, 0, 5),
            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
            ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
        ]
    )
    def test_any_dims_tuple_bool_dtype(
        self, input_shape, dims, keep_dims, dtype, low, high
    ):
        class AnyDims(nn.Module):
            def forward(self, x):
                return torch.ops.aten.any.dims(x, dims, keep_dims)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype).bool()]
        self.run_test(
            AnyDims(),
            inputs,
            output_dtypes=[torch.bool],
        )


if __name__ == "__main__":
    run_tests()
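The tests above exercise each overload through the dynamo conversion harness against float, int, and bool inputs. As a hypothetical end-to-end check (not part of the commit; it assumes a CUDA device and the public torch_tensorrt.compile dynamo path), the new converters would be picked up for a graph like the following:

import torch
import torch_tensorrt  # assumes a build that includes the converters added in this commit

class AnyReduce(torch.nn.Module):
    def forward(self, x):
        # Traces to aten.any.dim, which now has a TensorRT converter.
        return torch.any(x, dim=1, keepdim=True)

model = AnyReduce().eval().cuda()
x = torch.randn(2, 3, 4).cuda()

# min_block_size=1 so the single-op graph is not left to PyTorch as a fallback.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x], min_block_size=1)
assert torch.equal(trt_model(x), model(x))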
