Skip to content

Commit eb480ee

Browse files
committed
feat: support ne, ge, and le converters
1 parent 59a4910 commit eb480ee

File tree

8 files changed

+374
-57
lines changed

8 files changed

+374
-57
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,9 +1647,9 @@ def aten_ops_logical_xor(
16471647
)
16481648

16491649

1650-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1651-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
1652-
def aten_ops_equal(
1650+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
1651+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
1652+
def aten_ops_eq(
16531653
ctx: ConversionContext,
16541654
target: Target,
16551655
args: Tuple[Argument, ...],
@@ -1666,9 +1666,28 @@ def aten_ops_equal(
16661666
)
16671667

16681668

1669-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1670-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
1671-
def aten_ops_greater(
1669+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc]
1670+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc]
1671+
def aten_ops_ne(
1672+
ctx: ConversionContext,
1673+
target: Target,
1674+
args: Tuple[Argument, ...],
1675+
kwargs: Dict[str, Argument],
1676+
name: str,
1677+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1678+
return impl.elementwise.ne(
1679+
ctx,
1680+
target,
1681+
SourceIR.ATEN,
1682+
name,
1683+
args[0],
1684+
args[1],
1685+
)
1686+
1687+
1688+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc]
1689+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc]
1690+
def aten_ops_gt(
16721691
ctx: ConversionContext,
16731692
target: Target,
16741693
args: Tuple[Argument, ...],
@@ -1685,9 +1704,28 @@ def aten_ops_greater(
16851704
)
16861705

16871706

1688-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1689-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
1690-
def aten_ops_less(
1707+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc]
1708+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc]
1709+
def aten_ops_ge(
1710+
ctx: ConversionContext,
1711+
target: Target,
1712+
args: Tuple[Argument, ...],
1713+
kwargs: Dict[str, Argument],
1714+
name: str,
1715+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1716+
return impl.elementwise.ge(
1717+
ctx,
1718+
target,
1719+
SourceIR.ATEN,
1720+
name,
1721+
args[0],
1722+
args[1],
1723+
)
1724+
1725+
1726+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc]
1727+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc]
1728+
def aten_ops_lt(
16911729
ctx: ConversionContext,
16921730
target: Target,
16931731
args: Tuple[Argument, ...],
@@ -1704,6 +1742,25 @@ def aten_ops_less(
17041742
)
17051743

17061744

1745+
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc]
1746+
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc]
1747+
def aten_ops_le(
1748+
ctx: ConversionContext,
1749+
target: Target,
1750+
args: Tuple[Argument, ...],
1751+
kwargs: Dict[str, Argument],
1752+
name: str,
1753+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1754+
return impl.elementwise.le(
1755+
ctx,
1756+
target,
1757+
SourceIR.ATEN,
1758+
name,
1759+
args[0],
1760+
args[1],
1761+
)
1762+
1763+
17071764
def conv_param_validator(conv_node: Node) -> bool:
17081765
return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
17091766

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import tensorrt as trt
5+
import torch_tensorrt.dynamo.conversion.impl as impl
56
from torch.fx.node import Target
67
from torch_tensorrt.dynamo._SourceIR import SourceIR
78
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -441,6 +442,23 @@ def eq(
441442
)
442443

443444

445+
def ne(
446+
ctx: ConversionContext,
447+
target: Target,
448+
source_ir: Optional[SourceIR],
449+
name: str,
450+
lhs_val: Union[TRTTensor, int, float],
451+
rhs_val: Union[TRTTensor, int, float],
452+
) -> TRTTensor:
453+
return impl.unary.logical_not(
454+
ctx,
455+
target,
456+
source_ir,
457+
f"{name}_logical_not",
458+
eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
459+
)
460+
461+
444462
def gt(
445463
ctx: ConversionContext,
446464
target: Target,
@@ -460,6 +478,24 @@ def gt(
460478
)
461479

462480

481+
def ge(
482+
ctx: ConversionContext,
483+
target: Target,
484+
source_ir: Optional[SourceIR],
485+
name: str,
486+
lhs_val: Union[TRTTensor, int, float],
487+
rhs_val: Union[TRTTensor, int, float],
488+
) -> TRTTensor:
489+
return logical_or(
490+
ctx,
491+
target,
492+
source_ir,
493+
name,
494+
gt(ctx, target, source_ir, f"{name}_gt", lhs_val, rhs_val),
495+
eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
496+
)
497+
498+
463499
def lt(
464500
ctx: ConversionContext,
465501
target: Target,
@@ -477,3 +513,21 @@ def lt(
477513
lhs_val,
478514
rhs_val,
479515
)
516+
517+
518+
def le(
519+
ctx: ConversionContext,
520+
target: Target,
521+
source_ir: Optional[SourceIR],
522+
name: str,
523+
lhs_val: Union[TRTTensor, int, float],
524+
rhs_val: Union[TRTTensor, int, float],
525+
) -> TRTTensor:
526+
return logical_or(
527+
ctx,
528+
target,
529+
source_ir,
530+
name,
531+
lt(ctx, target, source_ir, f"{name}_lt", lhs_val, rhs_val),
532+
eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
533+
)

tests/py/dynamo/conversion/test_equal_aten.py renamed to tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,65 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5-
from torch_tensorrt import Input
65

76
from .harness import DispatchTestCase
87

98

109
class TestEqualConverter(DispatchTestCase):
1110
@parameterized.expand(
1211
[
13-
("2d", (2, 1)),
14-
("3d", (2, 1, 2)),
12+
("2d", (5, 3)),
13+
("3d", (5, 3, 2)),
1514
]
1615
)
17-
def test_equal_tensor(self, _, shape):
18-
class equal(nn.Module):
16+
def test_eq_tensor(self, _, shape):
17+
class eq(nn.Module):
1918
def forward(self, lhs_val, rhs_val):
2019
return torch.ops.aten.eq.Tensor(lhs_val, rhs_val)
2120

22-
inputs = [torch.randn(shape), torch.randn(shape)]
21+
inputs = [
22+
torch.randint(0, 3, shape, dtype=torch.int32),
23+
torch.randint(0, 3, shape, dtype=torch.int32),
24+
]
2325
self.run_test(
24-
equal(),
26+
eq(),
2527
inputs,
2628
output_dtypes=[torch.bool],
2729
)
2830

2931
@parameterized.expand(
3032
[
31-
("2d", (2, 1), 1),
32-
("3d", (2, 1, 2), 2.0),
33+
("2d", (5, 3), 1),
34+
("3d", (5, 3, 2), 2.0),
3335
]
3436
)
35-
def test_equal_tensor_scalar(self, _, shape, scalar):
36-
class equal(nn.Module):
37+
def test_eq_tensor_scalar(self, _, shape, scalar):
38+
class eq(nn.Module):
3739
def forward(self, lhs_val):
3840
return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(scalar))
3941

40-
inputs = [torch.randn(shape)]
42+
inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
4143
self.run_test(
42-
equal(),
44+
eq(),
4345
inputs,
4446
output_dtypes=[torch.bool],
4547
)
4648

4749
@parameterized.expand(
4850
[
49-
("2d", (2, 1), 1),
50-
("3d", (2, 1, 2), 2.0),
51+
("2d", (5, 3), 1),
52+
("3d", (5, 3, 2), 2.0),
5153
]
5254
)
53-
def test_equal_scalar(self, _, shape, scalar):
54-
class equal(nn.Module):
55+
def test_eq_scalar(self, _, shape, scalar):
56+
class eq(nn.Module):
5557
def forward(self, lhs_val):
5658
return torch.ops.aten.eq.Scalar(lhs_val, scalar)
5759

58-
inputs = [torch.randn(shape)]
60+
inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
5961
self.run_test(
60-
equal(),
62+
eq(),
6163
inputs,
62-
# expected_ops={torch.ops.aten.eq.Scalar},
6364
output_dtypes=[torch.bool],
6465
)
6566

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestGeConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("2d", (5, 3)),
13+
("3d", (5, 3, 2)),
14+
]
15+
)
16+
def test_ge_tensor(self, _, shape):
17+
class ge(nn.Module):
18+
def forward(self, lhs_val, rhs_val):
19+
return torch.ops.aten.ge.Tensor(lhs_val, rhs_val)
20+
21+
inputs = [
22+
torch.randint(0, 3, shape, dtype=torch.int32),
23+
torch.randint(0, 3, shape, dtype=torch.int32),
24+
]
25+
self.run_test(
26+
ge(),
27+
inputs,
28+
output_dtypes=[torch.bool],
29+
)
30+
31+
@parameterized.expand(
32+
[
33+
("2d", (5, 3), 1),
34+
("3d", (5, 3, 2), 2.0),
35+
]
36+
)
37+
def test_ge_tensor_scalar(self, _, shape, scalar):
38+
class ge(nn.Module):
39+
def forward(self, lhs_val):
40+
return torch.ops.aten.ge.Tensor(lhs_val, torch.tensor(scalar))
41+
42+
inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
43+
self.run_test(
44+
ge(),
45+
inputs,
46+
output_dtypes=[torch.bool],
47+
)
48+
49+
@parameterized.expand(
50+
[
51+
("2d", (5, 3), 1),
52+
("3d", (5, 3, 2), 2.0),
53+
]
54+
)
55+
def test_ge_scalar(self, _, shape, scalar):
56+
class ge(nn.Module):
57+
def forward(self, lhs_val):
58+
return torch.ops.aten.ge.Scalar(lhs_val, scalar)
59+
60+
inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
61+
self.run_test(
62+
ge(),
63+
inputs,
64+
output_dtypes=[torch.bool],
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
run_tests()

0 commit comments

Comments
 (0)