@@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1794
1794
if (!batchDimMap.empty ()) {
1795
1795
int64_t lhsIndex = batchDimMap[0 ].first ;
1796
1796
int64_t rhsIndex = batchDimMap[0 ].second ;
1797
- rewriter.replaceOp (op, lowerParallel (op, lhsIndex, rhsIndex, rewriter));
1797
+ auto newOp = lowerParallel (op, lhsIndex, rhsIndex, rewriter);
1798
+ if (failed (newOp))
1799
+ return failure ();
1800
+ rewriter.replaceOp (op, newOp.value ());
1798
1801
return success ();
1799
1802
}
1800
1803
@@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1812
1815
VectorType lhsType = op.getLhsType ();
1813
1816
for (int64_t lhsIndex = 0 , e = lhsType.getRank (); lhsIndex < e; ++lhsIndex) {
1814
1817
if (lhsContractingDimSet.count (lhsIndex) == 0 ) {
1815
- rewriter.replaceOp (
1816
- op, lowerParallel (op, lhsIndex, /* rhsIndex=*/ -1 , rewriter));
1818
+ auto newOp = lowerParallel (op, lhsIndex, /* rhsIndex=*/ -1 , rewriter);
1819
+ if (failed (newOp))
1820
+ return failure ();
1821
+ rewriter.replaceOp (op, newOp.value ());
1817
1822
return success ();
1818
1823
}
1819
1824
}
@@ -1822,26 +1827,33 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1822
1827
VectorType rhsType = op.getRhsType ();
1823
1828
for (int64_t rhsIndex = 0 , e = rhsType.getRank (); rhsIndex < e; ++rhsIndex) {
1824
1829
if (rhsContractingDimSet.count (rhsIndex) == 0 ) {
1825
- rewriter.replaceOp (
1826
- op, lowerParallel (op, /* lhsIndex=*/ -1 , rhsIndex, rewriter));
1830
+ auto newOp = lowerParallel (op, /* lhsIndex=*/ -1 , rhsIndex, rewriter);
1831
+ if (failed (newOp))
1832
+ return failure ();
1833
+ rewriter.replaceOp (op, newOp.value ());
1827
1834
return success ();
1828
1835
}
1829
1836
}
1830
1837
1831
1838
// Lower the first remaining reduction dimension.
1832
1839
if (!contractingDimMap.empty ()) {
1833
- rewriter.replaceOp (op, lowerReduction (op, rewriter));
1840
+ auto newOp = lowerReduction (op, rewriter);
1841
+ if (failed (newOp))
1842
+ return failure ();
1843
+ rewriter.replaceOp (op, newOp.value ());
1834
1844
return success ();
1835
1845
}
1836
1846
1837
1847
return failure ();
1838
1848
}
1839
1849
1840
1850
// Lower one parallel dimension.
1851
+ // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1841
1852
// TODO: consider reusing existing contract unrolling
1842
- Value ContractionOpLowering::lowerParallel (vector::ContractionOp op,
1843
- int64_t lhsIndex, int64_t rhsIndex,
1844
- PatternRewriter &rewriter) const {
1853
+ FailureOr<Value>
1854
+ ContractionOpLowering::lowerParallel (vector::ContractionOp op, int64_t lhsIndex,
1855
+ int64_t rhsIndex,
1856
+ PatternRewriter &rewriter) const {
1845
1857
VectorType lhsType = op.getLhsType ();
1846
1858
VectorType rhsType = op.getRhsType ();
1847
1859
VectorType resType = op.getResultType ().cast <VectorType>();
@@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1851
1863
int64_t dimSize = -1 ;
1852
1864
if (lhsIndex >= 0 ) {
1853
1865
iterIndex = iMap[0 ].getDimPosition (lhsIndex);
1854
- assert ((rhsIndex < 0 || iterIndex == iMap[1 ].getDimPosition (rhsIndex)) &&
1855
- " parallel index should be free in LHS or batch in LHS/RHS" );
1866
+ if (rhsIndex >= 0 && iterIndex != iMap[1 ].getDimPosition (rhsIndex))
1867
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1868
+ diag << " expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1869
+ << " to map to the same dimension" ;
1870
+ });
1856
1871
dimSize = lhsType.getDimSize (lhsIndex);
1857
- } else {
1858
- assert (rhsIndex >= 0 && " missing parallel index" );
1872
+ } else if (rhsIndex >= 0 ) {
1859
1873
iterIndex = iMap[1 ].getDimPosition (rhsIndex);
1860
1874
dimSize = rhsType.getDimSize (rhsIndex);
1861
1875
}
1862
- assert (iterIndex >= 0 && " parallel index not listed in operand mapping" );
1863
- Optional<int64_t > lookup = getResultIndex (iMap[2 ], iterIndex);
1864
- assert (lookup.has_value () && " parallel index not listed in reduction" );
1865
- int64_t resIndex = lookup.getValue ();
1876
+ if (iterIndex < 0 )
1877
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1878
+ diag << " expected either lhsIndex=" << lhsIndex
1879
+ << " or rhsIndex=" << rhsIndex << " to be nonnegative" ;
1880
+ });
1881
+ // getValueOr(-1) means that we tolerate a dimension not appearing
1882
+ // in the result map. That can't happen for actual parallel iterators, but
1883
+ // the caller ContractionOpLowering::matchAndRewrite is currently calling
1884
+ // lowerParallel also for the case of unit-size reduction dims appearing only
1885
+ // on one of LHS or RHS, not both. At the moment, such cases are created by
1886
+ // CastAwayContractionLeadingOneDim, so we need to either support that or
1887
+ // modify that pattern.
1888
+ int64_t resIndex = getResultIndex (iMap[2 ], iterIndex).getValueOr (-1 );
1889
+ if (resIndex == -1 && dimSize != 1 )
1890
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1891
+ diag << " expected the dimension for iterIndex=" << iterIndex
1892
+ << " to either appear in the result map, or to be a unit dimension" ;
1893
+ });
1866
1894
// Construct new iterator types and affine map array attribute.
1867
1895
std::array<AffineMap, 3 > lowIndexingMaps = {
1868
1896
adjustMap (iMap[0 ], iterIndex, rewriter),
@@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1888
1916
}
1889
1917
1890
1918
// Lower one reduction dimension.
1891
- Value ContractionOpLowering::lowerReduction (vector::ContractionOp op,
1892
- PatternRewriter &rewriter) const {
1919
+ FailureOr<Value>
1920
+ ContractionOpLowering::lowerReduction (vector::ContractionOp op,
1921
+ PatternRewriter &rewriter) const {
1893
1922
auto loc = op.getLoc ();
1894
1923
VectorType lhsType = op.getLhsType ();
1895
1924
VectorType rhsType = op.getRhsType ();
1896
1925
Type resType = op.getResultType ();
1897
- assert (!resType.isa <VectorType>());
1926
+ if (resType.isa <VectorType>())
1927
+ return rewriter.notifyMatchFailure (op,
1928
+ " did not expect a VectorType result" );
1898
1929
bool isInt = resType.isa <IntegerType>();
1899
1930
// Use iterator index 0.
1900
1931
int64_t iterIndex = 0 ;
1901
1932
SmallVector<AffineMap, 4 > iMap = op.getIndexingMaps ();
1902
1933
Optional<int64_t > lookupLhs = getResultIndex (iMap[0 ], iterIndex);
1903
1934
Optional<int64_t > lookupRhs = getResultIndex (iMap[1 ], iterIndex);
1904
- assert (lookupLhs.has_value () && " missing LHS parallel index" );
1905
- assert (lookupRhs.has_value () && " missing RHS parallel index" );
1935
+ if (!lookupLhs.hasValue ())
1936
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1937
+ diag << " expected iterIndex=" << iterIndex << " to map to a LHS dimension" ;
1938
+ });
1939
+ if (!lookupRhs.hasValue ())
1940
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1941
+ diag << " expected iterIndex=" << iterIndex << " to map to a RHS dimension" ;
1942
+ });
1906
1943
int64_t lhsIndex = lookupLhs.getValue ();
1907
1944
int64_t rhsIndex = lookupRhs.getValue ();
1908
1945
int64_t dimSize = lhsType.getDimSize (lhsIndex);
1909
- assert (dimSize == rhsType.getDimSize (rhsIndex) && " corrupt shape" );
1946
+ if (dimSize != rhsType.getDimSize (rhsIndex))
1947
+ return rewriter.notifyMatchFailure (op, [&](Diagnostic &diag) {
1948
+ diag << " expect LHS dimension " << lhsIndex
1949
+ << " to have the same size as RHS dimension " << rhsIndex;
1950
+ });
1910
1951
// Base case.
1911
1952
if (lhsType.getRank () == 1 ) {
1912
- assert (rhsType.getRank () == 1 && " corrupt contraction" );
1953
+ if (rhsType.getRank () != 1 )
1954
+ return rewriter.notifyMatchFailure (
1955
+ op, " When LHS has rank 1, expected also RHS to have rank 1" );
1913
1956
Value m = createMul (loc, op.getLhs (), op.getRhs (), isInt, rewriter);
1914
1957
auto kind = vector::CombiningKind::ADD;
1915
1958
if (auto acc = op.getAcc ())
1916
- return rewriter.create <vector::ReductionOp>(loc, kind, m, acc);
1917
- return rewriter.create <vector::ReductionOp>(loc, kind, m);
1959
+ return rewriter.create <vector::ReductionOp>(loc, kind, m, acc)
1960
+ .getResult ();
1961
+ return rewriter.create <vector::ReductionOp>(loc, kind, m).getResult ();
1918
1962
}
1919
1963
// Construct new iterator types and affine map array attribute.
1920
1964
std::array<AffineMap, 3 > lowIndexingMaps = {
0 commit comments