Skip to content

Commit 6ef3b21

Browse files
committed
feat: support more elementwise and unary converters
1 parent ac1e7db commit 6ef3b21

File tree

7 files changed

+292
-0
lines changed

7 files changed

+292
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,6 +1701,120 @@ def aten_ops_logical_xor(
17011701
)
17021702

17031703

1704+
def bitwise_type_validator(node: Node) -> bool:
    """Gate the Tensor-overload bitwise converters on boolean operands.

    The elementwise bitwise converters lower to logical ops, which are only
    valid for bool tensors, so conversion is accepted only when both operand
    nodes carry a bool ``tensor_meta`` dtype.
    """
    supported_targets = (
        torch.ops.aten.bitwise_and.Tensor,
        torch.ops.aten.bitwise_or.Tensor,
        torch.ops.aten.bitwise_xor.Tensor,
    )
    if node.target not in supported_targets:
        return False

    lhs_meta = node.args[0].meta.get("tensor_meta")
    rhs_meta = node.args[1].meta.get("tensor_meta")
    if lhs_meta is None or rhs_meta is None:
        # Without shape/dtype metadata we cannot prove the inputs are bool.
        return False

    bool_types = (torch.bool, bool)
    return lhs_meta.dtype in bool_types and rhs_meta.dtype in bool_types
1723+
1724+
1725+
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor)  # type: ignore[misc]
def aten_ops_bitwise_and(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Dispatch aten::bitwise_and overloads to the elementwise implementation."""
    lhs_val, rhs_val = args[0], args[1]
    return impl.elementwise.bitwise_and(
        ctx, target, SourceIR.ATEN, name, lhs_val, rhs_val
    )
1743+
1744+
1745+
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor)  # type: ignore[misc]
def aten_ops_bitwise_or(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Dispatch aten::bitwise_or overloads to the elementwise implementation."""
    lhs_val, rhs_val = args[0], args[1]
    return impl.elementwise.bitwise_or(
        ctx, target, SourceIR.ATEN, name, lhs_val, rhs_val
    )
1763+
1764+
1765+
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor)  # type: ignore[misc]
def aten_ops_bitwise_xor(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Dispatch aten::bitwise_xor overloads to the elementwise implementation."""
    lhs_val, rhs_val = args[0], args[1]
    return impl.elementwise.bitwise_xor(
        ctx, target, SourceIR.ATEN, name, lhs_val, rhs_val
    )
1783+
1784+
1785+
def bitwise_not_type_validator(node: Node) -> bool:
    """Accept aten::bitwise_not only when its single input is a bool tensor.

    bitwise_not is lowered to logical_not, which is valid only for bool
    inputs, so conversion is rejected when the operand's ``tensor_meta``
    is missing or reports a non-bool dtype.
    """
    operand_meta = node.args[0].meta.get("tensor_meta")
    if operand_meta is None:
        return False

    bool_types = (torch.bool, bool)
    return operand_meta.dtype in bool_types
1794+
1795+
1796+
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_bitwise_not(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Dispatch aten::bitwise_not to the unary bitwise_not implementation."""
    input_val = args[0]
    return impl.unary.bitwise_not(ctx, target, SourceIR.ATEN, name, input_val)
1816+
1817+
17041818
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
17051819
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
17061820
@enforce_tensor_types(

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,39 @@ def logical_xor(
423423
)
424424

425425

426+
def bitwise_and(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
) -> TRTTensor:
    """Elementwise bitwise AND; for bool operands this is logical AND."""
    layer_name = f"{name}_logical_and"
    return logical_and(ctx, target, source_ir, layer_name, lhs_val, rhs_val)
435+
436+
437+
def bitwise_or(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
) -> TRTTensor:
    """Elementwise bitwise OR; for bool operands this is logical OR."""
    layer_name = f"{name}_logical_or"
    return logical_or(ctx, target, source_ir, layer_name, lhs_val, rhs_val)
446+
447+
448+
def bitwise_xor(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    lhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
    rhs_val: Union[TRTTensor, int, bool, Sequence[Union[int, bool]]],
) -> TRTTensor:
    """Elementwise bitwise XOR; for bool operands this is logical XOR."""
    layer_name = f"{name}_logical_xor"
    return logical_xor(ctx, target, source_ir, layer_name, lhs_val, rhs_val)
457+
458+
426459
def eq(
427460
ctx: ConversionContext,
428461
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
import tensorrt as trt
4+
import torch_tensorrt.dynamo.conversion.impl as impl
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
67
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -336,6 +337,18 @@ def logical_not(
336337
)
337338

338339

340+
def bitwise_not(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    """Compute elementwise bitwise NOT of *input_val*.

    The converter is gated (by its capability validator) on bool inputs,
    for which bitwise NOT is identical to logical NOT. Delegate directly
    to ``logical_not``, which is defined earlier in this module, instead
    of routing through ``impl.unary.logical_not`` — the direct call avoids
    depending on the self-referential ``torch_tensorrt...impl`` package
    alias import.
    """
    return logical_not(ctx, target, source_ir, f"{name}_logical_not", input_val)
350+
351+
339352
def sign(
340353
ctx: ConversionContext,
341354
target: Target,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBitwiseAndConverter(DispatchTestCase):
    """Converter tests for aten.bitwise_and.Tensor on bool inputs."""

    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
        ]
    )
    def test_bitwise_and_tensor(self, _, shape):
        # Wrap the raw aten overload so tracing captures exactly the op
        # under test rather than the Python-level `&` operator.
        class BitwiseAnd(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

        # Random boolean operands of the parameterized shape.
        lhs = torch.randint(0, 2, shape, dtype=torch.bool)
        rhs = torch.randint(0, 2, shape, dtype=torch.bool)
        self.run_test(
            BitwiseAnd(),
            [lhs, rhs],
            enable_passes=True,
        )


if __name__ == "__main__":
    run_tests()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBitwiseNotConverter(DispatchTestCase):
    """Converter tests for aten.bitwise_not.default on bool inputs."""

    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
        ]
    )
    def test_bitwise_not_tensor(self, _, shape):
        # Wrap the raw aten overload so tracing captures exactly the op
        # under test rather than the Python-level `~` operator.
        class BitwiseNot(nn.Module):
            def forward(self, val):
                return torch.ops.aten.bitwise_not.default(val)

        # A single random boolean operand of the parameterized shape.
        operand = torch.randint(0, 2, shape, dtype=torch.bool)
        self.run_test(
            BitwiseNot(),
            [operand],
            enable_passes=True,
            output_dtypes=[torch.bool],
        )


if __name__ == "__main__":
    run_tests()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBitwiseOrConverter(DispatchTestCase):
    """Converter tests for aten.bitwise_or.Tensor on bool inputs."""

    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
        ]
    )
    def test_bitwise_or_tensor(self, _, shape):
        # Wrap the raw aten overload so tracing captures exactly the op
        # under test rather than the Python-level `|` operator.
        class BitwiseOr(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)

        # Random boolean operands of the parameterized shape.
        lhs = torch.randint(0, 2, shape, dtype=torch.bool)
        rhs = torch.randint(0, 2, shape, dtype=torch.bool)
        self.run_test(
            BitwiseOr(),
            [lhs, rhs],
            enable_passes=True,
        )


if __name__ == "__main__":
    run_tests()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBitwiseXorConverter(DispatchTestCase):
    """Converter tests for aten.bitwise_xor.Tensor on bool inputs."""

    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
        ]
    )
    def test_bitwise_xor_tensor(self, _, shape):
        # Wrap the raw aten overload so tracing captures exactly the op
        # under test rather than the Python-level `^` operator.
        class BitwiseXor(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val)

        # Random boolean operands of the parameterized shape.
        lhs = torch.randint(0, 2, shape, dtype=torch.bool)
        rhs = torch.randint(0, 2, shape, dtype=torch.bool)
        self.run_test(
            BitwiseXor(),
            [lhs, rhs],
            enable_passes=True,
        )


if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)