Skip to content

[MLIR] Fix incorrect slice contiguity inference in vector::isContiguousSlice #142422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 23, 2025
58 changes: 31 additions & 27 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,41 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
/// Return true if `vectorType` is a contiguous slice of `memrefType`,
/// in the sense that it can be read/written from/to a contiguous area
/// of the memref.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
/// The leading unit dimensions of the vector type are ignored as they
/// are not relevant to the result. Let N be the number of the vector
/// dimensions after ignoring a leading sequence of unit ones.
///
/// There might be some restriction on the leading dim of `VectorType`:
/// For `vectorType` to be a contiguous slice of `memrefType`
/// a) the N trailing dimensions of `memrefType` must be contiguous, and
/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
/// must match.
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
/// of `memrefType` then the leading dim of `vectorType` can be
/// arbitrary.
///
/// Ex. 1.1 contiguous slice, perfect match
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
/// corresponding trailing dim in `memrefType` then the remaining
/// leading dims of `vectorType` have to be 1 (the first non-matching
/// dim can be arbitrary).
/// Examples:
///
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
/// Ex.1 contiguous slice, perfect match
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
/// Ex.3 non-contiguous slice, 2 != 3
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
/// dimension
/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
/// two dimensions, and it isn't
/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

/// Returns an iterator for all positions in the leading dimensions of `vType`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
Expand Down Expand Up @@ -630,7 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();

int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
int64_t firstDimToCollapse =
sourceType.getRank() -
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

// 1. Collapse the source memref
Value collapsedSource =
Expand Down Expand Up @@ -722,7 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();

int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
int64_t firstDimToCollapse =
sourceType.getRank() -
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

// 1. Collapse the source memref
Value collapsedSource =
Expand Down
25 changes: 8 additions & 17 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;

ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();
// Ignore a leading sequence of adjacent unit dimensions in the vector.
ArrayRef<int64_t> vectorShape =
vectorType.getShape().drop_while([](auto v) { return v == 1; });
auto vecRank = vectorShape.size();

if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;

// Extract the trailing dims and strides of the input memref
// Extract the trailing dims of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);

// Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
memrefShape.rbegin(), memrefShape.rend());
if (firstNonMatchingDim.first == vectorShape.rend())
return true;

// One non-matching dim is still fine, however the remaining leading dims of
// `vectorType` need to be 1.
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
vectorShape.rend());

return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
// Compare the dims of `vectorType` against `memrefType`.
// All of the dimensions, except the first must match.
return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
}

std::optional<StaticTileOffsetRange>
Expand Down
Loading
Loading