Skip to content

Commit 1b077ca

Browse files
committed
fix bugs
1 parent 6ef3b21 commit 1b077ca

File tree

5 files changed

+234
-64
lines changed

5 files changed

+234
-64
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,12 @@ def aten_ops_cumsum(
754754
)
755755

756756

757-
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
757+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
758758
@enforce_tensor_types(
759759
{
760760
0: (TRTTensor,),
761761
}
762-
) # type: ignore[misc]
762+
)
763763
def aten_ops_tile(
764764
ctx: ConversionContext,
765765
target: Target,
@@ -777,7 +777,7 @@ def aten_ops_tile(
777777
)
778778

779779

780-
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
780+
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
781781
@enforce_tensor_types(
782782
{
783783
0: (TRTTensor,),
@@ -1702,29 +1702,63 @@ def aten_ops_logical_xor(
17021702

17031703

17041704
def bitwise_type_validator(node: Node) -> bool:
1705-
targets = [
1705+
supported_type = [torch.bool, bool]
1706+
1707+
tensor_targets = [
17061708
torch.ops.aten.bitwise_and.Tensor,
17071709
torch.ops.aten.bitwise_or.Tensor,
17081710
torch.ops.aten.bitwise_xor.Tensor,
17091711
]
1710-
if node.target not in targets:
1711-
return False
1712+
scalar_targets = [
1713+
torch.ops.aten.bitwise_and.Scalar,
1714+
torch.ops.aten.bitwise_or.Scalar,
1715+
torch.ops.aten.bitwise_xor.Scalar,
1716+
]
1717+
scalar_tensor_targets = [
1718+
torch.ops.aten.bitwise_and.Scalar_Tensor,
1719+
torch.ops.aten.bitwise_or.Scalar_Tensor,
1720+
torch.ops.aten.bitwise_xor.Scalar_Tensor,
1721+
]
17121722

1713-
lhs_val = node.args[0]
1714-
rhs_val = node.args[1]
1715-
lhs_meta = lhs_val.meta.get("tensor_meta")
1716-
rhs_meta = rhs_val.meta.get("tensor_meta")
1723+
if node.target in tensor_targets:
1724+
lhs_val = node.args[0]
1725+
rhs_val = node.args[1]
1726+
lhs_meta = lhs_val.meta.get("tensor_meta")
1727+
rhs_meta = rhs_val.meta.get("tensor_meta")
1728+
if lhs_meta is None or rhs_meta is None:
1729+
return False
1730+
return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
17171731

1718-
if lhs_meta is None or rhs_meta is None:
1719-
return False
1732+
elif node.target in scalar_targets:
1733+
lhs_val = node.args[0]
1734+
rhs_val = node.args[1]
1735+
lhs_meta = lhs_val.meta.get("tensor_meta")
1736+
if lhs_meta is None:
1737+
return False
1738+
return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool)
17201739

1721-
supported_type = [torch.bool, bool]
1722-
return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
1740+
elif node.target in scalar_tensor_targets:
1741+
lhs_val = node.args[0]
1742+
rhs_val = node.args[1]
1743+
rhs_meta = rhs_val.meta.get("tensor_meta")
1744+
if rhs_meta is None:
1745+
return False
1746+
return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type
1747+
1748+
else:
1749+
return False
17231750

17241751

1725-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc]
1726-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar) # type: ignore[misc]
1727-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor) # type: ignore[misc]
1752+
@dynamo_tensorrt_converter(
1753+
torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
1754+
)
1755+
@dynamo_tensorrt_converter(
1756+
torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
1757+
)
1758+
@dynamo_tensorrt_converter(
1759+
torch.ops.aten.bitwise_and.Scalar_Tensor,
1760+
capability_validator=bitwise_type_validator,
1761+
)
17281762
def aten_ops_bitwise_and(
17291763
ctx: ConversionContext,
17301764
target: Target,
@@ -1742,9 +1776,15 @@ def aten_ops_bitwise_and(
17421776
)
17431777

17441778

1745-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc]
1746-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar) # type: ignore[misc]
1747-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor) # type: ignore[misc]
1779+
@dynamo_tensorrt_converter(
1780+
torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
1781+
)
1782+
@dynamo_tensorrt_converter(
1783+
torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
1784+
)
1785+
@dynamo_tensorrt_converter(
1786+
torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
1787+
)
17481788
def aten_ops_bitwise_or(
17491789
ctx: ConversionContext,
17501790
target: Target,
@@ -1762,9 +1802,16 @@ def aten_ops_bitwise_or(
17621802
)
17631803

17641804

1765-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator) # type: ignore[misc]
1766-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar) # type: ignore[misc]
1767-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor) # type: ignore[misc]
1805+
@dynamo_tensorrt_converter(
1806+
torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
1807+
)
1808+
@dynamo_tensorrt_converter(
1809+
torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
1810+
)
1811+
@dynamo_tensorrt_converter(
1812+
torch.ops.aten.bitwise_xor.Scalar_Tensor,
1813+
capability_validator=bitwise_type_validator,
1814+
)
17681815
def aten_ops_bitwise_xor(
17691816
ctx: ConversionContext,
17701817
target: Target,
@@ -1793,12 +1840,14 @@ def bitwise_not_type_validator(node: Node) -> bool:
17931840
return val_meta.dtype in supported_type
17941841

17951842

1796-
@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator) # type: ignore[misc]
1843+
@dynamo_tensorrt_converter(
1844+
torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
1845+
)
17971846
@enforce_tensor_types(
17981847
{
17991848
0: (TRTTensor,),
18001849
}
1801-
) # type: ignore[misc]
1850+
)
18021851
def aten_ops_bitwise_not(
18031852
ctx: ConversionContext,
18041853
target: Target,
@@ -1815,13 +1864,13 @@ def aten_ops_bitwise_not(
18151864
)
18161865

18171866

1818-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
1819-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
1867+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1868+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
18201869
@enforce_tensor_types(
18211870
{
18221871
0: (TRTTensor,),
18231872
}
1824-
) # type: ignore[misc]
1873+
)
18251874
def aten_ops_eq(
18261875
ctx: ConversionContext,
18271876
target: Target,
@@ -1839,13 +1888,13 @@ def aten_ops_eq(
18391888
)
18401889

18411890

1842-
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) # type: ignore[misc]
1843-
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) # type: ignore[misc]
1891+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
1892+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
18441893
@enforce_tensor_types(
18451894
{
18461895
0: (TRTTensor,),
18471896
}
1848-
) # type: ignore[misc]
1897+
)
18491898
def aten_ops_ne(
18501899
ctx: ConversionContext,
18511900
target: Target,
@@ -1863,13 +1912,13 @@ def aten_ops_ne(
18631912
)
18641913

18651914

1866-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc]
1867-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc]
1915+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1916+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
18681917
@enforce_tensor_types(
18691918
{
18701919
0: (TRTTensor,),
18711920
}
1872-
) # type: ignore[misc]
1921+
)
18731922
def aten_ops_gt(
18741923
ctx: ConversionContext,
18751924
target: Target,
@@ -1887,13 +1936,13 @@ def aten_ops_gt(
18871936
)
18881937

18891938

1890-
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) # type: ignore[misc]
1891-
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) # type: ignore[misc]
1939+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
1940+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
18921941
@enforce_tensor_types(
18931942
{
18941943
0: (TRTTensor,),
18951944
}
1896-
) # type: ignore[misc]
1945+
)
18971946
def aten_ops_ge(
18981947
ctx: ConversionContext,
18991948
target: Target,
@@ -1911,13 +1960,13 @@ def aten_ops_ge(
19111960
)
19121961

19131962

1914-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc]
1915-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc]
1963+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1964+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
19161965
@enforce_tensor_types(
19171966
{
19181967
0: (TRTTensor,),
19191968
}
1920-
) # type: ignore[misc]
1969+
)
19211970
def aten_ops_lt(
19221971
ctx: ConversionContext,
19231972
target: Target,
@@ -1935,13 +1984,13 @@ def aten_ops_lt(
19351984
)
19361985

19371986

1938-
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) # type: ignore[misc]
1939-
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) # type: ignore[misc]
1987+
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
1988+
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
19401989
@enforce_tensor_types(
19411990
{
19421991
0: (TRTTensor,),
19431992
}
1944-
) # type: ignore[misc]
1993+
)
19451994
def aten_ops_le(
19461995
ctx: ConversionContext,
19471996
target: Target,
@@ -2191,14 +2240,14 @@ def aten_ops_argmax(
21912240
)
21922241

21932242

2194-
@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
2243+
@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)
21952244
@enforce_tensor_types(
21962245
{
21972246
0: (TRTTensor,),
21982247
1: (np.ndarray, torch.Tensor, TRTTensor),
21992248
2: (np.ndarray, torch.Tensor, TRTTensor),
22002249
}
2201-
) # type: ignore[misc]
2250+
)
22022251
def aten_ops_addmm(
22032252
ctx: ConversionContext,
22042253
target: Target,

0 commit comments

Comments
 (0)