Skip to content

Commit f5b7b31

Browse files
chohk88 authored and peri044 committed
feat: support aten.isnan converter (#2711)
1 parent a9a6272 commit f5b7b31

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,23 @@ def aten_ops_isinf(
15331533
)
15341534

15351535

1536+
@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
def aten_ops_isnan(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.isnan.default`` to TensorRT layers.

    Thin dispatch shim: forwards the single tensor argument to
    ``impl.unary.isnan``, which produces the boolean NaN mask.
    """
    return impl.unary.isnan(ctx, target, SourceIR.ATEN, name, args[0])
1551+
1552+
15361553
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
15371554
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
15381555
def aten_ops_add(

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,3 +508,23 @@ def scalar_tensor(
508508
identity_layer = ctx.net.add_identity(tensor)
509509
set_layer_name(identity_layer, target, name, source_ir)
510510
return identity_layer.get_output(0)
511+
512+
513+
def isnan(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
) -> TRTTensor:
    """Return a boolean tensor that is True exactly where ``input`` is NaN.

    Relies on the IEEE-754 property that NaN is unequal to everything,
    including itself: ``input == input`` is False only at NaN positions,
    so negating that self-comparison yields the NaN mask.
    """
    # False at NaN elements, True everywhere else.
    self_equality = impl.elementwise.eq(
        ctx, target, source_ir, f"{name}_eq_nan", input, input
    )
    # Invert so NaN positions become True.
    return logical_not(ctx, target, source_ir, f"{name}_logical_not", self_equality)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestIsNanConverter(DispatchTestCase):
    """Tests for the ``aten.isnan.default`` TensorRT converter.

    Fix: the original defined an identical inner ``nn.Module`` wrapper
    three times, once per test method; it is now factored into a single
    shared nested class. Test method names and behavior are unchanged.
    """

    class IsNan(nn.Module):
        # Minimal module wrapping the op under test.
        def forward(self, input):
            return torch.ops.aten.isnan.default(input)

    @parameterized.expand(
        [
            (
                torch.tensor(
                    [
                        1.23,
                        float("nan"),
                        -4.56,
                        float("inf"),
                        float("-inf"),
                        -100.0,
                        float("nan"),
                        0.13,
                        -0.13,
                        3.14159265,
                    ]
                ),
            ),
        ]
    )
    def test_isnan_float(self, data):
        # Mixed NaN/inf/finite values: only the NaN entries map to True.
        self.run_test(
            self.IsNan(),
            [data],
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            (torch.full((2, 2), float("nan"), dtype=torch.float32),),
            (torch.full((3, 10, 5), float("nan"), dtype=torch.float32),),
            (torch.randn((5, 10, 5), dtype=torch.float32),),
        ]
    )
    def test_isnan_dim(self, data):
        # Multi-dimensional inputs, including all-NaN tensors.
        self.run_test(
            self.IsNan(),
            [data],
            output_dtypes=[torch.bool],
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int32, -10, 10),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_isnan_int(self, input_shape, dtype, low, high):
        # Integer tensors can never hold NaN, so the mask is all False.
        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            self.IsNan(),
            inputs,
            output_dtypes=[torch.bool],
        )
79+
80+
81+
# Run the converter test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)