
Commit 03f0ebb

feat: support ne, ge, and le converters
1 parent 7011809 · commit 03f0ebb

8 files changed: +374 −57 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 66 additions & 9 deletions
@@ -1701,9 +1701,9 @@ def aten_ops_logical_xor(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
-def aten_ops_equal(
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
+def aten_ops_eq(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1720,9 +1720,28 @@ def aten_ops_equal(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
-def aten_ops_greater(
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)  # type: ignore[misc]
+def aten_ops_ne(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.ne(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)  # type: ignore[misc]
+def aten_ops_gt(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1739,9 +1758,28 @@ def aten_ops_greater(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
-def aten_ops_less(
+@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)  # type: ignore[misc]
+def aten_ops_ge(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.ge(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)  # type: ignore[misc]
+def aten_ops_lt(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1758,6 +1796,25 @@ def aten_ops_less(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)  # type: ignore[misc]
+def aten_ops_le(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.le(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 def conv_param_validator(conv_node: Node) -> bool:
     return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])

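Taken together, the hunks above register the `aten.ne`, `aten.ge`, and `aten.le` Tensor and Scalar overloads alongside the renamed `eq`/`gt`/`lt` converters, so these comparison nodes are now claimed by TensorRT conversion instead of falling back to PyTorch. A minimal sketch of exercising one of the new ops through the Dynamo frontend; the module, input shapes, and compile settings here are illustrative assumptions, not part of the commit:

import torch
import torch_tensorrt


class GreaterEqual(torch.nn.Module):  # illustrative module, not from the commit
    def forward(self, lhs, rhs):
        # Traces to torch.ops.aten.ge.Tensor, which aten_ops_ge now converts.
        return lhs >= rhs


inputs = [torch.randn(5, 3).cuda(), torch.randn(5, 3).cuda()]
# Assumed invocation of the Dynamo path; exact settings depend on your build.
trt_mod = torch_tensorrt.compile(GreaterEqual().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(*inputs))  # boolean tensor matching lhs >= rhs
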
py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 54 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import numpy as np
 import tensorrt as trt
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -441,6 +442,23 @@ def eq(
     )
 
 
+def ne(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, float],
+    rhs_val: Union[TRTTensor, int, float],
+) -> TRTTensor:
+    return impl.unary.logical_not(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_logical_not",
+        eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
+    )
+
+
 def gt(
     ctx: ConversionContext,
     target: Target,
@@ -460,6 +478,24 @@ def gt(
     )
 
 
+def ge(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, float],
+    rhs_val: Union[TRTTensor, int, float],
+) -> TRTTensor:
+    return logical_or(
+        ctx,
+        target,
+        source_ir,
+        name,
+        gt(ctx, target, source_ir, f"{name}_gt", lhs_val, rhs_val),
+        eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
+    )
+
+
 def lt(
     ctx: ConversionContext,
     target: Target,
@@ -477,3 +513,21 @@ def lt(
         lhs_val,
         rhs_val,
     )
+
+
+def le(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, float],
+    rhs_val: Union[TRTTensor, int, float],
+) -> TRTTensor:
+    return logical_or(
+        ctx,
+        target,
+        source_ir,
+        name,
+        lt(ctx, target, source_ir, f"{name}_lt", lhs_val, rhs_val),
+        eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val),
+    )

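Note that none of the three new helpers introduces a new TensorRT layer type: `ne` is built as `logical_not(eq(lhs, rhs))`, while `ge` and `le` OR the strict comparison (`gt`/`lt`) with `eq` through the existing `logical_or` helper. A quick plain-PyTorch check of those identities, purely for illustration and not part of the commit:

import torch

a = torch.randint(0, 3, (5, 3))
b = torch.randint(0, 3, (5, 3))

# ne == not(eq), ge == gt or eq, le == lt or eq -- the compositions used above
assert torch.equal(torch.ne(a, b), torch.logical_not(torch.eq(a, b)))
assert torch.equal(torch.ge(a, b), torch.logical_or(torch.gt(a, b), torch.eq(a, b)))
assert torch.equal(torch.le(a, b), torch.logical_or(torch.lt(a, b), torch.eq(a, b)))
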
tests/py/dynamo/conversion/test_equal_aten.py renamed to tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 21 additions & 20 deletions
@@ -2,64 +2,65 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
 
 class TestEqualConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d", (2, 1)),
-            ("3d", (2, 1, 2)),
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
         ]
     )
-    def test_equal_tensor(self, _, shape):
-        class equal(nn.Module):
+    def test_eq_tensor(self, _, shape):
+        class eq(nn.Module):
             def forward(self, lhs_val, rhs_val):
                 return torch.ops.aten.eq.Tensor(lhs_val, rhs_val)
 
-        inputs = [torch.randn(shape), torch.randn(shape)]
+        inputs = [
+            torch.randint(0, 3, shape, dtype=torch.int32),
+            torch.randint(0, 3, shape, dtype=torch.int32),
+        ]
         self.run_test(
-            equal(),
+            eq(),
             inputs,
             output_dtypes=[torch.bool],
         )
 
     @parameterized.expand(
         [
-            ("2d", (2, 1), 1),
-            ("3d", (2, 1, 2), 2.0),
+            ("2d", (5, 3), 1),
+            ("3d", (5, 3, 2), 2.0),
         ]
     )
-    def test_equal_tensor_scalar(self, _, shape, scalar):
-        class equal(nn.Module):
+    def test_eq_tensor_scalar(self, _, shape, scalar):
+        class eq(nn.Module):
             def forward(self, lhs_val):
                 return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(scalar))
 
-        inputs = [torch.randn(shape)]
+        inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
         self.run_test(
-            equal(),
+            eq(),
             inputs,
             output_dtypes=[torch.bool],
         )
 
     @parameterized.expand(
         [
-            ("2d", (2, 1), 1),
-            ("3d", (2, 1, 2), 2.0),
+            ("2d", (5, 3), 1),
+            ("3d", (5, 3, 2), 2.0),
        ]
    )
-    def test_equal_scalar(self, _, shape, scalar):
-        class equal(nn.Module):
+    def test_eq_scalar(self, _, shape, scalar):
+        class eq(nn.Module):
             def forward(self, lhs_val):
                 return torch.ops.aten.eq.Scalar(lhs_val, scalar)
 
-        inputs = [torch.randn(shape)]
+        inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
         self.run_test(
-            equal(),
+            eq(),
             inputs,
-            # expected_ops={torch.ops.aten.eq.Scalar},
             output_dtypes=[torch.bool],
         )

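One detail worth calling out in the updated tests: the inputs change from `torch.randn` floats to `torch.randint(0, 3, ...)` int32 tensors. Independent random floats are essentially never exactly equal, so an equality-style test on them would only ever see the all-False case, whereas integers drawn from a small range yield a mix of equal and unequal elements. A small illustration, not part of the commit:

import torch

torch.manual_seed(0)
float_hits = (torch.randn(5, 3) == torch.randn(5, 3)).sum().item()
int_hits = (torch.randint(0, 3, (5, 3)) == torch.randint(0, 3, (5, 3))).sum().item()
print(float_hits, int_hits)  # floats: almost certainly 0 matches; ints: typically several
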
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestGtConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
+        ]
+    )
+    def test_ge_tensor(self, _, shape):
+        class ge(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.ge.Tensor(lhs_val, rhs_val)
+
+        inputs = [
+            torch.randint(0, 3, shape, dtype=torch.int32),
+            torch.randint(0, 3, shape, dtype=torch.int32),
+        ]
+        self.run_test(
+            ge(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+    @parameterized.expand(
+        [
+            ("2d", (5, 3), 1),
+            ("3d", (5, 3, 2), 2.0),
+        ]
+    )
+    def test_ge_tensor_scalar(self, _, shape, scalar):
+        class ge(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.ge.Tensor(lhs_val, torch.tensor(scalar))
+
+        inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
+        self.run_test(
+            ge(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+    @parameterized.expand(
+        [
+            ("2d", (5, 3), 1),
+            ("3d", (5, 3, 2), 2.0),
+        ]
+    )
+    def test_ge_scalar(self, _, shape, scalar):
+        class ge(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.ge.Scalar(lhs_val, scalar)
+
+        inputs = [torch.randint(0, 3, shape, dtype=torch.int32)]
+        self.run_test(
+            ge(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
