
Commit 5c88118

feat: support more elementwise and unary converters
1 parent 52b89ed commit 5c88118


7 files changed

+292
-0
lines changed
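
For orientation, below is a minimal sketch, not part of this commit, of exercising the newly supported boolean ops end to end through the dynamo frontend. It assumes a working TensorRT and Torch-TensorRT installation with a CUDA device; the module and compile settings are illustrative only.

import torch
import torch_tensorrt


class BoolOps(torch.nn.Module):
    def forward(self, a, b):
        # These aten ops are what the new converters below map onto the
        # existing logical_and / logical_or / logical_xor / logical_not paths.
        x = torch.bitwise_and(a, b)
        y = torch.bitwise_or(a, b)
        return torch.bitwise_xor(x, torch.bitwise_not(y))


a = torch.randint(0, 2, (5, 3), dtype=torch.bool).cuda()
b = torch.randint(0, 2, (5, 3), dtype=torch.bool).cuda()

# ir="dynamo" routes compilation through the converter registry touched here.
trt_mod = torch_tensorrt.compile(BoolOps().eval().cuda(), ir="dynamo", inputs=[a, b])
print(trt_mod(a, b))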


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 114 additions & 0 deletions
@@ -1647,6 +1647,120 @@ def aten_ops_logical_xor(
 )
 
 
+def bitwise_type_validator(node: Node) -> bool:
+    targets = [
+        torch.ops.aten.bitwise_and.Tensor,
+        torch.ops.aten.bitwise_or.Tensor,
+        torch.ops.aten.bitwise_xor.Tensor,
+    ]
+    if node.target not in targets:
+        return False
+
+    lhs_val = node.args[0]
+    rhs_val = node.args[1]
+    lhs_meta = lhs_val.meta.get("tensor_meta")
+    rhs_meta = rhs_val.meta.get("tensor_meta")
+
+    if lhs_meta is None or rhs_meta is None:
+        return False
+
+    supported_type = [torch.bool, bool]
+    return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor)  # type: ignore[misc]
+def aten_ops_bitwise_and(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_and(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor)  # type: ignore[misc]
+def aten_ops_bitwise_or(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_or(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor)  # type: ignore[misc]
+def aten_ops_bitwise_xor(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.bitwise_xor(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+def bitwise_not_type_validator(node: Node) -> bool:
+    val = node.args[0]
+    val_meta = val.meta.get("tensor_meta")
+
+    if val_meta is None:
+        return False
+
+    supported_type = [torch.bool, bool]
+    return val_meta.dtype in supported_type
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_bitwise_not(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.bitwise_not(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
 @enforce_tensor_types(
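
A note on the gating above (our reading of the diff, not text from the commit): the Tensor overloads of bitwise_and/or/xor are only claimed for conversion when bitwise_type_validator finds bool tensor metadata on both operands; for any other dtype the validator returns False and the node is left to run in PyTorch. A small eager-mode illustration of the two cases, with arbitrary names and shapes:

import torch

# bool/bool operands: the case the validator accepts; the converter lowers the
# op onto the logical_and implementation.
a = torch.randint(0, 2, (4,), dtype=torch.bool)
b = torch.randint(0, 2, (4,), dtype=torch.bool)
print(torch.ops.aten.bitwise_and.Tensor(a, b))

# int/int operands: tensor_meta.dtype is not bool, so bitwise_type_validator
# returns False and this node would not be converted to TensorRT.
x = torch.randint(0, 8, (4,), dtype=torch.int32)
y = torch.randint(0, 8, (4,), dtype=torch.int32)
print(torch.ops.aten.bitwise_and.Tensor(x, y))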

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 33 additions & 0 deletions
@@ -423,6 +423,39 @@ def logical_xor(
 )
 
 
+def bitwise_and(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+) -> TRTTensor:
+    return logical_and(ctx, target, source_ir, f"{name}_logical_and", lhs_val, rhs_val)
+
+
+def bitwise_or(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+) -> TRTTensor:
+    return logical_or(ctx, target, source_ir, f"{name}_logical_or", lhs_val, rhs_val)
+
+
+def bitwise_xor(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
+) -> TRTTensor:
+    return logical_xor(ctx, target, source_ir, f"{name}_logical_xor", lhs_val, rhs_val)
+
+
 def eq(
     ctx: ConversionContext,
     target: Target,
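
For bool operands, the only dtype the new validators admit, the bitwise ops coincide elementwise with the logical ops they delegate to, which is why the three implementations above are thin wrappers. A quick eager-mode check, illustrative only:

import torch

a = torch.tensor([True, False, True, False])
b = torch.tensor([True, True, False, False])

# On bool tensors, bitwise_* and logical_* agree elementwise.
assert torch.equal(torch.bitwise_and(a, b), torch.logical_and(a, b))
assert torch.equal(torch.bitwise_or(a, b), torch.logical_or(a, b))
assert torch.equal(torch.bitwise_xor(a, b), torch.logical_xor(a, b))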

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 13 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import tensorrt as trt
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -336,6 +337,18 @@ def logical_not(
 )
 
 
+def bitwise_not(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    return impl.unary.logical_not(
+        ctx, target, source_ir, f"{name}_logical_not", input_val
+    )
+
+
 def sign(
     ctx: ConversionContext,
     target: Target,
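
The same reasoning applies to the unary case: for a bool input, bitwise_not is elementwise negation, so delegating to logical_not preserves behavior. A minimal sanity check, not part of the commit:

import torch

v = torch.tensor([True, False, True])
# bitwise_not on a bool tensor is plain elementwise negation.
assert torch.equal(torch.bitwise_not(v), torch.logical_not(v))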
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestBitwiseAndConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
+        ]
+    )
+    def test_bitwise_and_tensor(self, _, shape):
+        class bitwise_and(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)
+
+        inputs = [
+            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, shape, dtype=bool),
+        ]
+        self.run_test(
+            bitwise_and(),
+            inputs,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
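
For reference, the eager behavior that the harness compares the TensorRT output against can be reproduced standalone; the snippet below is a sketch independent of DispatchTestCase and assumes nothing beyond a PyTorch install:

import torch
import torch.nn as nn


class BitwiseAnd(nn.Module):
    def forward(self, lhs_val, rhs_val):
        return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)


lhs = torch.randint(0, 2, (5, 3), dtype=torch.bool)
rhs = torch.randint(0, 2, (5, 3), dtype=torch.bool)
print(BitwiseAnd()(lhs, rhs))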
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestBitwiseNotConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
+        ]
+    )
+    def test_bitwise_not_tensor(self, _, shape):
+        class bitwise_not(nn.Module):
+            def forward(self, val):
+                return torch.ops.aten.bitwise_not.default(val)
+
+        inputs = [
+            torch.randint(0, 2, shape, dtype=torch.bool),
+        ]
+        self.run_test(
+            bitwise_not(),
+            inputs,
+            enable_passes=True,
+            output_dtypes=[torch.bool],
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestBitwiseOrConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
+        ]
+    )
+    def test_bitwise_or_tensor(self, _, shape):
+        class bitwise_or(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)
+
+        inputs = [
+            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, shape, dtype=bool),
+        ]
+        self.run_test(
+            bitwise_or(),
+            inputs,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestBitwiseXorConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d", (5, 3)),
+            ("3d", (5, 3, 2)),
+        ]
+    )
+    def test_bitwise_xor_tensor(self, _, shape):
+        class bitwise_xor(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val)
+
+        inputs = [
+            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, shape, dtype=bool),
+        ]
+        self.run_test(
+            bitwise_xor(),
+            inputs,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
