Skip to content

Commit bedb959

Browse files
momchil-velikovAnthony Tran
authored andcommitted
[MLIR] Fix incorrect slice contiguity inference in vector::isContiguousSlice (llvm#142422)
Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing `n` dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The patterns used to collapse a number of dimensions equal to the vector rank which missed some opportunities when the leading unit dimensions of the vector span non-contiguous dimensions of the memref. Now that the contiguity of the slice is determined correctly, there is a choice how many dimensions of the memref to collapse, ranging from a) the number of vector dimensions after ignoring the leading unit dimensions, up to b) the maximum number of contiguous memref dimensions This patch makes a choice to do minimal memref collapsing. The rationale behind this decision is that this way the least amount of information is discarded. (It follows that in some cases where the patterns used to trigger and collapse some memref dimensions, after this patch the patterns may collapse less dimensions).
1 parent 1a1d9b4 commit bedb959

File tree

4 files changed

+260
-111
lines changed

4 files changed

+260
-111
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -47,37 +47,41 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
4747
/// on a 2D slice. Otherwise, returns a failure.
4848
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
4949

50-
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
50+
/// Return true if `vectorType` is a contiguous slice of `memrefType`,
51+
/// in the sense that it can be read/written from/to a contiguous area
52+
/// of the memref.
5153
///
52-
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
53-
/// checked (the other dims are not relevant). Note that for `vectorType` to be
54-
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
55-
/// to be contiguous - this is checked by looking at the corresponding strides.
54+
/// The leading unit dimensions of the vector type are ignored as they
55+
/// are not relevant to the result. Let N be the number of the vector
56+
/// dimensions after ignoring a leading sequence of unit ones.
5657
///
57-
/// There might be some restriction on the leading dim of `VectorType`:
58+
/// For `vectorType` to be a contiguous slice of `memrefType`
59+
/// a) the N trailing dimensions of `memrefType` must be contiguous, and
60+
/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
61+
/// must match.
5862
///
59-
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
60-
/// of `memrefType` then the leading dim of `vectorType` can be
61-
/// arbitrary.
62-
///
63-
/// Ex. 1.1 contiguous slice, perfect match
64-
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
65-
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
66-
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
67-
///
68-
/// Case 2. If an "internal" dim of `vectorType` does not match the
69-
/// corresponding trailing dim in `memrefType` then the remaining
70-
/// leading dims of `vectorType` have to be 1 (the first non-matching
71-
/// dim can be arbitrary).
63+
/// Examples:
7264
///
73-
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
74-
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
75-
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
76-
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
77-
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
78-
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
79-
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
80-
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
65+
/// Ex.1 contiguous slice, perfect match
66+
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
67+
/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
68+
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
69+
/// Ex.3 non-contiguous slice, 2 != 3
70+
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
71+
/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
72+
/// 2 != 3 (allowed)
73+
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
74+
/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
75+
/// 2 != 3 (allowed)
76+
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
77+
/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
78+
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
79+
/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
80+
/// dimension
81+
/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
82+
/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
83+
/// two dimensions, and it isn't
84+
/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
8185
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
8286

8387
/// Returns an iterator for all positions in the leading dimensions of `vType`

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
581581
}
582582

583583
namespace {
584-
585584
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
586585
/// memref.collapse_shape on the source so that the resulting
587586
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -630,7 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
630629
if (transferReadOp.getMask())
631630
return failure();
632631

633-
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
632+
// Determine the first memref dimension to collapse - just enough so we can
633+
// read a flattened vector.
634+
int64_t firstDimToCollapse =
635+
sourceType.getRank() -
636+
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
634637

635638
// 1. Collapse the source memref
636639
Value collapsedSource =
@@ -722,7 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
722725
if (transferWriteOp.getMask())
723726
return failure();
724727

725-
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
728+
// Determine the first memref dimension to collapse - just enough so we can
729+
// read a flattened vector.
730+
int64_t firstDimToCollapse =
731+
sourceType.getRank() -
732+
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
726733

727734
// 1. Collapse the source memref
728735
Value collapsedSource =

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
258258
if (vectorType.isScalable())
259259
return false;
260260

261-
ArrayRef<int64_t> vectorShape = vectorType.getShape();
262-
auto vecRank = vectorType.getRank();
261+
// Ignore a leading sequence of adjacent unit dimensions in the vector.
262+
ArrayRef<int64_t> vectorShape =
263+
vectorType.getShape().drop_while([](auto v) { return v == 1; });
264+
auto vecRank = vectorShape.size();
263265

264266
if (!memrefType.areTrailingDimsContiguous(vecRank))
265267
return false;
266268

267-
// Extract the trailing dims and strides of the input memref
269+
// Extract the trailing dims of the input memref
268270
auto memrefShape = memrefType.getShape().take_back(vecRank);
269271

270-
// Compare the dims of `vectorType` against `memrefType` (in reverse).
271-
// In the most basic case, all dims will match.
272-
auto firstNonMatchingDim =
273-
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
274-
memrefShape.rbegin(), memrefShape.rend());
275-
if (firstNonMatchingDim.first == vectorShape.rend())
276-
return true;
277-
278-
// One non-matching dim is still fine, however the remaining leading dims of
279-
// `vectorType` need to be 1.
280-
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
281-
vectorShape.rend());
282-
283-
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
272+
// Compare the dims of `vectorType` against `memrefType`.
273+
// All of the dimensions, except the first must match.
274+
return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
284275
}
285276

286277
std::optional<StaticTileOffsetRange>

0 commit comments

Comments
 (0)