Skip to content

Commit 6870a50

Browse files
committed
lowerParallel is also called on unit-size, one-sided reduction dims
See: https://gist.github.com/bjacob/d8be8ec7e70ed0be4b3a5794ced2a7e8 Differential Revision: https://reviews.llvm.org/D129096
1 parent 3968936 commit 6870a50

File tree

3 files changed

+103
-30
lines changed

3 files changed

+103
-30
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,12 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
527527
vector::VectorTransformsOptions vectorTransformOptions;
528528
FilterConstraintType filter;
529529
// Lower one parallel dimension.
530-
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
531-
int64_t rhsIndex, PatternRewriter &rewriter) const;
530+
FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
531+
int64_t rhsIndex,
532+
PatternRewriter &rewriter) const;
532533
// Lower one reduction dimension.
533-
Value lowerReduction(vector::ContractionOp op,
534-
PatternRewriter &rewriter) const;
534+
FailureOr<Value> lowerReduction(vector::ContractionOp op,
535+
PatternRewriter &rewriter) const;
535536
};
536537

537538
} // namespace vector

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
17941794
if (!batchDimMap.empty()) {
17951795
int64_t lhsIndex = batchDimMap[0].first;
17961796
int64_t rhsIndex = batchDimMap[0].second;
1797-
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1797+
auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
1798+
if (failed(newOp))
1799+
return failure();
1800+
rewriter.replaceOp(op, newOp.value());
17981801
return success();
17991802
}
18001803

@@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
18121815
VectorType lhsType = op.getLhsType();
18131816
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
18141817
if (lhsContractingDimSet.count(lhsIndex) == 0) {
1815-
rewriter.replaceOp(
1816-
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1818+
auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
1819+
if (failed(newOp))
1820+
return failure();
1821+
rewriter.replaceOp(op, newOp.value());
18171822
return success();
18181823
}
18191824
}
@@ -1822,26 +1827,33 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
18221827
VectorType rhsType = op.getRhsType();
18231828
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
18241829
if (rhsContractingDimSet.count(rhsIndex) == 0) {
1825-
rewriter.replaceOp(
1826-
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1830+
auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
1831+
if (failed(newOp))
1832+
return failure();
1833+
rewriter.replaceOp(op, newOp.value());
18271834
return success();
18281835
}
18291836
}
18301837

18311838
// Lower the first remaining reduction dimension.
18321839
if (!contractingDimMap.empty()) {
1833-
rewriter.replaceOp(op, lowerReduction(op, rewriter));
1840+
auto newOp = lowerReduction(op, rewriter);
1841+
if (failed(newOp))
1842+
return failure();
1843+
rewriter.replaceOp(op, newOp.value());
18341844
return success();
18351845
}
18361846

18371847
return failure();
18381848
}
18391849

18401850
// Lower one parallel dimension.
1851+
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
18411852
// TODO: consider reusing existing contract unrolling
1842-
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1843-
int64_t lhsIndex, int64_t rhsIndex,
1844-
PatternRewriter &rewriter) const {
1853+
FailureOr<Value>
1854+
ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
1855+
int64_t rhsIndex,
1856+
PatternRewriter &rewriter) const {
18451857
VectorType lhsType = op.getLhsType();
18461858
VectorType rhsType = op.getRhsType();
18471859
VectorType resType = op.getResultType().cast<VectorType>();
@@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
18511863
int64_t dimSize = -1;
18521864
if (lhsIndex >= 0) {
18531865
iterIndex = iMap[0].getDimPosition(lhsIndex);
1854-
assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1855-
"parallel index should be free in LHS or batch in LHS/RHS");
1866+
if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1867+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1868+
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1869+
<< " to map to the same dimension";
1870+
});
18561871
dimSize = lhsType.getDimSize(lhsIndex);
1857-
} else {
1858-
assert(rhsIndex >= 0 && "missing parallel index");
1872+
} else if (rhsIndex >= 0) {
18591873
iterIndex = iMap[1].getDimPosition(rhsIndex);
18601874
dimSize = rhsType.getDimSize(rhsIndex);
18611875
}
1862-
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1863-
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1864-
assert(lookup.has_value() && "parallel index not listed in reduction");
1865-
int64_t resIndex = lookup.getValue();
1876+
if (iterIndex < 0)
1877+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1878+
diag << "expected either lhsIndex=" << lhsIndex
1879+
<< " or rhsIndex=" << rhsIndex << " to be nonnegative";
1880+
});
1881+
// getValueOr(-1) means that we tolerate a dimension not appearing
1882+
// in the result map. That can't happen for actual parallel iterators, but
1883+
// the caller ContractionOpLowering::matchAndRewrite is currently calling
1884+
// lowerParallel also for the case of unit-size reduction dims appearing only
1885+
// on one of LHS or RHS, not both. At the moment, such cases are created by
1886+
// CastAwayContractionLeadingOneDim, so we need to either support that or
1887+
// modify that pattern.
1888+
int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
1889+
if (resIndex == -1 && dimSize != 1)
1890+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1891+
diag << "expected the dimension for iterIndex=" << iterIndex
1892+
<< " to either appear in the result map, or to be a unit dimension";
1893+
});
18661894
// Construct new iterator types and affine map array attribute.
18671895
std::array<AffineMap, 3> lowIndexingMaps = {
18681896
adjustMap(iMap[0], iterIndex, rewriter),
@@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
18881916
}
18891917

18901918
// Lower one reduction dimension.
1891-
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1892-
PatternRewriter &rewriter) const {
1919+
FailureOr<Value>
1920+
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1921+
PatternRewriter &rewriter) const {
18931922
auto loc = op.getLoc();
18941923
VectorType lhsType = op.getLhsType();
18951924
VectorType rhsType = op.getRhsType();
18961925
Type resType = op.getResultType();
1897-
assert(!resType.isa<VectorType>());
1926+
if (resType.isa<VectorType>())
1927+
return rewriter.notifyMatchFailure(op,
1928+
"did not expect a VectorType result");
18981929
bool isInt = resType.isa<IntegerType>();
18991930
// Use iterator index 0.
19001931
int64_t iterIndex = 0;
19011932
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
19021933
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
19031934
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1904-
assert(lookupLhs.has_value() && "missing LHS parallel index");
1905-
assert(lookupRhs.has_value() && "missing RHS parallel index");
1935+
if (!lookupLhs.hasValue())
1936+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1937+
diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1938+
});
1939+
if (!lookupRhs.hasValue())
1940+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1941+
diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1942+
});
19061943
int64_t lhsIndex = lookupLhs.getValue();
19071944
int64_t rhsIndex = lookupRhs.getValue();
19081945
int64_t dimSize = lhsType.getDimSize(lhsIndex);
1909-
assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1946+
if (dimSize != rhsType.getDimSize(rhsIndex))
1947+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1948+
diag << "expect LHS dimension " << lhsIndex
1949+
<< " to have the same size as RHS dimension " << rhsIndex;
1950+
});
19101951
// Base case.
19111952
if (lhsType.getRank() == 1) {
1912-
assert(rhsType.getRank() == 1 && "corrupt contraction");
1953+
if (rhsType.getRank() != 1)
1954+
return rewriter.notifyMatchFailure(
1955+
op, "When LHS has rank 1, expected also RHS to have rank 1");
19131956
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
19141957
auto kind = vector::CombiningKind::ADD;
19151958
if (auto acc = op.getAcc())
1916-
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
1917-
return rewriter.create<vector::ReductionOp>(loc, kind, m);
1959+
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1960+
.getResult();
1961+
return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
19181962
}
19191963
// Construct new iterator types and affine map array attribute.
19201964
std::array<AffineMap, 3> lowIndexingMaps = {

mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,34 @@ func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x
858858
return %0 : vector<2x1x7xi1>
859859
}
860860

861+
// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
862+
// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
863+
// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
864+
// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
865+
// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
866+
// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
867+
// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
868+
// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
869+
// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
870+
// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
871+
// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
872+
// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
873+
// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
874+
// CHECK: return %[[S]] : vector<2xi32>
875+
876+
func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
877+
%res = vector.contract {
878+
indexing_maps = [
879+
affine_map<(d0, d1, d2) -> (d0, d2)>,
880+
affine_map<(d0, d1, d2) -> (d1, d2)>,
881+
affine_map<(d0, d1, d2) -> (d1)>
882+
],
883+
iterator_types = ["reduction", "parallel", "reduction"],
884+
kind = #vector.kind<add>
885+
} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
886+
return %res : vector<2xi32>
887+
}
888+
861889
#matmat_accesses_0 = [
862890
affine_map<(m, n, k) -> (m, k)>,
863891
affine_map<(m, n, k) -> (k, n)>,

0 commit comments

Comments
 (0)