@@ -754,12 +754,12 @@ def aten_ops_cumsum(
754
754
)
755
755
756
756
757
- @dynamo_tensorrt_converter (torch .ops .aten .tile .default ) # type: ignore[misc]
757
+ @dynamo_tensorrt_converter (torch .ops .aten .tile .default )
758
758
@enforce_tensor_types (
759
759
{
760
760
0 : (TRTTensor ,),
761
761
}
762
- ) # type: ignore[misc]
762
+ )
763
763
def aten_ops_tile (
764
764
ctx : ConversionContext ,
765
765
target : Target ,
@@ -777,7 +777,7 @@ def aten_ops_tile(
777
777
)
778
778
779
779
780
- @dynamo_tensorrt_converter (torch .ops .aten .permute .default ) # type: ignore[misc]
780
+ @dynamo_tensorrt_converter (torch .ops .aten .permute .default )
781
781
@enforce_tensor_types (
782
782
{
783
783
0 : (TRTTensor ,),
@@ -1702,29 +1702,63 @@ def aten_ops_logical_xor(
1702
1702
1703
1703
1704
1704
def bitwise_type_validator(node: Node) -> bool:
    """Capability validator for the aten bitwise and/or/xor converters.

    Returns True only when *node* targets one of the supported bitwise
    overloads AND its operands are boolean: tensor operands must carry FX
    ``tensor_meta`` with a ``torch.bool`` dtype, scalar operands must be
    Python ``bool``. Nodes with missing tensor metadata, non-boolean
    operands, or any other target are rejected.
    """
    boolean_types = (torch.bool, bool)

    def _meta_is_bool(arg):
        # A tensor operand qualifies only if FX shape propagation recorded
        # metadata for it and that metadata reports a boolean dtype.
        meta = arg.meta.get("tensor_meta")
        return meta is not None and meta.dtype in boolean_types

    aten = torch.ops.aten
    tensor_overloads = (
        aten.bitwise_and.Tensor,
        aten.bitwise_or.Tensor,
        aten.bitwise_xor.Tensor,
    )
    scalar_overloads = (
        aten.bitwise_and.Scalar,
        aten.bitwise_or.Scalar,
        aten.bitwise_xor.Scalar,
    )
    scalar_tensor_overloads = (
        aten.bitwise_and.Scalar_Tensor,
        aten.bitwise_or.Scalar_Tensor,
        aten.bitwise_xor.Scalar_Tensor,
    )

    if node.target in tensor_overloads:
        # Tensor/Tensor: both sides need boolean tensor metadata.
        # Evaluate both eagerly (no short-circuit) to mirror the original's
        # access pattern on each argument's meta dict.
        lhs_ok = _meta_is_bool(node.args[0])
        rhs_ok = _meta_is_bool(node.args[1])
        return lhs_ok and rhs_ok

    if node.target in scalar_overloads:
        # Tensor/Scalar: boolean tensor on the left, Python bool on the right.
        return _meta_is_bool(node.args[0]) and isinstance(node.args[1], bool)

    if node.target in scalar_tensor_overloads:
        # Scalar/Tensor: Python bool on the left, boolean tensor on the right.
        return isinstance(node.args[0], bool) and _meta_is_bool(node.args[1])

    # Any other target is not a bitwise op this validator supports.
    return False
1723
1750
1724
1751
1725
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1726
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar ) # type: ignore[misc]
1727
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar_Tensor ) # type: ignore[misc]
1752
+ @dynamo_tensorrt_converter (
1753
+ torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator
1754
+ )
1755
+ @dynamo_tensorrt_converter (
1756
+ torch .ops .aten .bitwise_and .Scalar , capability_validator = bitwise_type_validator
1757
+ )
1758
+ @dynamo_tensorrt_converter (
1759
+ torch .ops .aten .bitwise_and .Scalar_Tensor ,
1760
+ capability_validator = bitwise_type_validator ,
1761
+ )
1728
1762
def aten_ops_bitwise_and (
1729
1763
ctx : ConversionContext ,
1730
1764
target : Target ,
@@ -1742,9 +1776,15 @@ def aten_ops_bitwise_and(
1742
1776
)
1743
1777
1744
1778
1745
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1746
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar ) # type: ignore[misc]
1747
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar_Tensor ) # type: ignore[misc]
1779
+ @dynamo_tensorrt_converter (
1780
+ torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator
1781
+ )
1782
+ @dynamo_tensorrt_converter (
1783
+ torch .ops .aten .bitwise_or .Scalar , capability_validator = bitwise_type_validator
1784
+ )
1785
+ @dynamo_tensorrt_converter (
1786
+ torch .ops .aten .bitwise_or .Scalar_Tensor , capability_validator = bitwise_type_validator
1787
+ )
1748
1788
def aten_ops_bitwise_or (
1749
1789
ctx : ConversionContext ,
1750
1790
target : Target ,
@@ -1762,9 +1802,16 @@ def aten_ops_bitwise_or(
1762
1802
)
1763
1803
1764
1804
1765
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1766
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar ) # type: ignore[misc]
1767
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar_Tensor ) # type: ignore[misc]
1805
+ @dynamo_tensorrt_converter (
1806
+ torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator
1807
+ )
1808
+ @dynamo_tensorrt_converter (
1809
+ torch .ops .aten .bitwise_xor .Scalar , capability_validator = bitwise_type_validator
1810
+ )
1811
+ @dynamo_tensorrt_converter (
1812
+ torch .ops .aten .bitwise_xor .Scalar_Tensor ,
1813
+ capability_validator = bitwise_type_validator ,
1814
+ )
1768
1815
def aten_ops_bitwise_xor (
1769
1816
ctx : ConversionContext ,
1770
1817
target : Target ,
@@ -1793,12 +1840,14 @@ def bitwise_not_type_validator(node: Node) -> bool:
1793
1840
return val_meta .dtype in supported_type
1794
1841
1795
1842
1796
- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator ) # type: ignore[misc]
1843
+ @dynamo_tensorrt_converter (
1844
+ torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator
1845
+ )
1797
1846
@enforce_tensor_types (
1798
1847
{
1799
1848
0 : (TRTTensor ,),
1800
1849
}
1801
- ) # type: ignore[misc]
1850
+ )
1802
1851
def aten_ops_bitwise_not (
1803
1852
ctx : ConversionContext ,
1804
1853
target : Target ,
@@ -1815,13 +1864,13 @@ def aten_ops_bitwise_not(
1815
1864
)
1816
1865
1817
1866
1818
- @dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor ) # type: ignore[misc]
1819
- @dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar ) # type: ignore[misc]
1867
+ @dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor )
1868
+ @dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar )
1820
1869
@enforce_tensor_types (
1821
1870
{
1822
1871
0 : (TRTTensor ,),
1823
1872
}
1824
- ) # type: ignore[misc]
1873
+ )
1825
1874
def aten_ops_eq (
1826
1875
ctx : ConversionContext ,
1827
1876
target : Target ,
@@ -1839,13 +1888,13 @@ def aten_ops_eq(
1839
1888
)
1840
1889
1841
1890
1842
- @dynamo_tensorrt_converter (torch .ops .aten .ne .Tensor ) # type: ignore[misc]
1843
- @dynamo_tensorrt_converter (torch .ops .aten .ne .Scalar ) # type: ignore[misc]
1891
+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Tensor )
1892
+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Scalar )
1844
1893
@enforce_tensor_types (
1845
1894
{
1846
1895
0 : (TRTTensor ,),
1847
1896
}
1848
- ) # type: ignore[misc]
1897
+ )
1849
1898
def aten_ops_ne (
1850
1899
ctx : ConversionContext ,
1851
1900
target : Target ,
@@ -1863,13 +1912,13 @@ def aten_ops_ne(
1863
1912
)
1864
1913
1865
1914
1866
- @dynamo_tensorrt_converter (torch .ops .aten .gt .Tensor ) # type: ignore[misc]
1867
- @dynamo_tensorrt_converter (torch .ops .aten .gt .Scalar ) # type: ignore[misc]
1915
+ @dynamo_tensorrt_converter (torch .ops .aten .gt .Tensor )
1916
+ @dynamo_tensorrt_converter (torch .ops .aten .gt .Scalar )
1868
1917
@enforce_tensor_types (
1869
1918
{
1870
1919
0 : (TRTTensor ,),
1871
1920
}
1872
- ) # type: ignore[misc]
1921
+ )
1873
1922
def aten_ops_gt (
1874
1923
ctx : ConversionContext ,
1875
1924
target : Target ,
@@ -1887,13 +1936,13 @@ def aten_ops_gt(
1887
1936
)
1888
1937
1889
1938
1890
- @dynamo_tensorrt_converter (torch .ops .aten .ge .Tensor ) # type: ignore[misc]
1891
- @dynamo_tensorrt_converter (torch .ops .aten .ge .Scalar ) # type: ignore[misc]
1939
+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Tensor )
1940
+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Scalar )
1892
1941
@enforce_tensor_types (
1893
1942
{
1894
1943
0 : (TRTTensor ,),
1895
1944
}
1896
- ) # type: ignore[misc]
1945
+ )
1897
1946
def aten_ops_ge (
1898
1947
ctx : ConversionContext ,
1899
1948
target : Target ,
@@ -1911,13 +1960,13 @@ def aten_ops_ge(
1911
1960
)
1912
1961
1913
1962
1914
- @dynamo_tensorrt_converter (torch .ops .aten .lt .Tensor ) # type: ignore[misc]
1915
- @dynamo_tensorrt_converter (torch .ops .aten .lt .Scalar ) # type: ignore[misc]
1963
+ @dynamo_tensorrt_converter (torch .ops .aten .lt .Tensor )
1964
+ @dynamo_tensorrt_converter (torch .ops .aten .lt .Scalar )
1916
1965
@enforce_tensor_types (
1917
1966
{
1918
1967
0 : (TRTTensor ,),
1919
1968
}
1920
- ) # type: ignore[misc]
1969
+ )
1921
1970
def aten_ops_lt (
1922
1971
ctx : ConversionContext ,
1923
1972
target : Target ,
@@ -1935,13 +1984,13 @@ def aten_ops_lt(
1935
1984
)
1936
1985
1937
1986
1938
- @dynamo_tensorrt_converter (torch .ops .aten .le .Tensor ) # type: ignore[misc]
1939
- @dynamo_tensorrt_converter (torch .ops .aten .le .Scalar ) # type: ignore[misc]
1987
+ @dynamo_tensorrt_converter (torch .ops .aten .le .Tensor )
1988
+ @dynamo_tensorrt_converter (torch .ops .aten .le .Scalar )
1940
1989
@enforce_tensor_types (
1941
1990
{
1942
1991
0 : (TRTTensor ,),
1943
1992
}
1944
- ) # type: ignore[misc]
1993
+ )
1945
1994
def aten_ops_le (
1946
1995
ctx : ConversionContext ,
1947
1996
target : Target ,
@@ -2191,14 +2240,14 @@ def aten_ops_argmax(
2191
2240
)
2192
2241
2193
2242
2194
- @dynamo_tensorrt_converter (torch .ops .aten .addmm .default ) # type: ignore[misc]
2243
+ @dynamo_tensorrt_converter (torch .ops .aten .addmm .default )
2195
2244
@enforce_tensor_types (
2196
2245
{
2197
2246
0 : (TRTTensor ,),
2198
2247
1 : (np .ndarray , torch .Tensor , TRTTensor ),
2199
2248
2 : (np .ndarray , torch .Tensor , TRTTensor ),
2200
2249
}
2201
- ) # type: ignore[misc]
2250
+ )
2202
2251
def aten_ops_addmm (
2203
2252
ctx : ConversionContext ,
2204
2253
target : Target ,
0 commit comments