[mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) #73522

Merged
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -42,6 +42,39 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
///
/// There may be restrictions on the leading dims of `vectorType`:
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
/// of `memrefType` then the leading dim of `vectorType` can be
/// arbitrary.
///
/// Ex. 1.1 contiguous slice, perfect match
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
/// corresponding trailing dim in `memrefType` then the remaining
/// leading dims of `vectorType` have to be 1 (the first non-matching
/// dim can be arbitrary).
///
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
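
To make the case analysis above concrete, here is a minimal, self-contained C++ sketch (not part of this patch) that mirrors the documented conditions on plain integer arrays. `isContiguousSliceRef` and the hard-coded shapes and strides are hypothetical stand-ins for the MLIR `MemRefType`/`VectorType` queries:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference model of vector::isContiguousSlice. `memShape` and
// `strides` stand in for the memref's shape/strides, `vecShape` for the
// vector's shape. Scalable vectors (which the real function rejects) are
// not modeled.
static bool isContiguousSliceRef(std::vector<int64_t> memShape,
                                 std::vector<int64_t> strides,
                                 std::vector<int64_t> vecShape) {
  size_t n = vecShape.size();
  assert(n >= 1 && n <= memShape.size());
  // Only the n trailing memref dims/strides are relevant.
  memShape.erase(memShape.begin(), memShape.end() - n);
  strides.erase(strides.begin(), strides.end() - n);

  // Cond 1: a contiguous memref has a unit trailing stride.
  if (strides.back() != 1)
    return false;

  // Cond 2: every other stride must equal the product of the dims it skips.
  int64_t flat = 1;
  for (size_t i = n - 1; i > 0; --i) {
    flat *= memShape[i];
    if (strides[i - 1] != flat)
      return false;
  }

  // Compare dims back-to-front; find the first mismatch, if any.
  size_t i = n;
  while (i > 0 && vecShape[i - 1] == memShape[i - 1])
    --i;
  if (i == 0)
    return true; // Case 1: all trailing dims match.
  // Case 2: the first non-matching dim is arbitrary, but every leading dim
  // before it must be 1.
  --i;
  while (i > 0) {
    if (vecShape[--i] != 1)
      return false;
  }
  return true;
}

int main() {
  std::vector<int64_t> mem = {5, 4, 3, 2}, str = {24, 6, 2, 1};
  assert(isContiguousSliceRef(mem, str, {4, 3, 2}));     // Ex 1.1
  assert(isContiguousSliceRef(mem, str, {2, 3, 2}));     // Ex 1.2
  assert(!isContiguousSliceRef(mem, str, {2, 2, 2}));    // Ex 2.1
  assert(isContiguousSliceRef(mem, str, {1, 2, 2}));     // Ex 2.2
  assert(isContiguousSliceRef(mem, str, {1, 1, 2, 2}));  // Ex 2.3
  assert(!isContiguousSliceRef(mem, str, {2, 1, 2, 2})); // Ex 2.4
}
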
@@ -487,28 +487,6 @@ class TransferWriteDropUnitDimsPattern

} // namespace

/// Return true if the memref type has its inner dimension matching the given
/// shape. Otherwise return false.
static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
ArrayRef<int64_t> targetShape) {
auto shape = memrefType.getShape();
SmallVector<int64_t> strides;
int64_t offset;
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
return false;
if (strides.back() != 1)
return false;
strides.pop_back();
int64_t flatDim = 1;
for (auto [targetDim, memrefDim, memrefStride] :
llvm::reverse(llvm::zip(targetShape, shape, strides))) {
flatDim *= memrefDim;
if (flatDim != memrefStride || targetDim != memrefDim)
return false;
}
return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
@@ -568,9 +546,7 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
if (!hasMatchingInnerContigousShape(
sourceType,
vectorType.getShape().take_back(vectorType.getRank() - 1)))
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
@@ -628,9 +604,7 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
if (!hasMatchingInnerContigousShape(
sourceType,
vectorType.getShape().take_back(vectorType.getRank() - 1)))
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
44 changes: 44 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,3 +249,47 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
// between parallel, reduction and possibly other cases.
return ratio.has_value();
}

bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;

ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();

// Extract the trailing dims and strides of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);
int64_t offset;
SmallVector<int64_t> stridesFull;
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
return false;
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);

// Cond 1: A contiguous memref will always have a unit trailing stride.
if (strides.back() != 1)
return false;

// Cond 2: Strides of a contiguous memref have to match the flattened dims.
strides = strides.drop_back(1);
SmallVector<int64_t> flattenedDims;
for (size_t i = 1; i < memrefShape.size(); i++)
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));

if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
return false;

// Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
memrefShape.rbegin(), memrefShape.rend());
if (firstNonMatchingDim.first == vectorShape.rend())
return true;

// One non-matching dim is still fine; however, the remaining leading dims of
// `vectorType` need to be 1.
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
vectorShape.rend());

return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}
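
A note on the reverse-iterator idiom in Cond 3: `std::mismatch` walks both shapes back-to-front to find the first non-matching trailing dim, after which only leading 1s are allowed. Here is a minimal standalone illustration (not part of the patch) using Ex 2.3 from the header docs and the same four-iterator overload of `std::mismatch` as above; all variable names are local to the sketch:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> vecShape = {1, 1, 2, 2};
  std::vector<int> memShape = {5, 4, 3, 2};

  // Walk both shapes back-to-front: the trailing 2s match, then 2 vs. 3
  // is the first mismatch.
  auto firstNonMatchingDim =
      std::mismatch(vecShape.rbegin(), vecShape.rend(),
                    memShape.rbegin(), memShape.rend());
  assert(*firstNonMatchingDim.first == 2 && *firstNonMatchingDim.second == 3);

  // Skip the first non-matching dim (it may be arbitrary); every remaining
  // leading vector dim must be 1 -- here <1x1>, so the slice is contiguous.
  bool leadingOnes =
      std::all_of(std::next(firstNonMatchingDim.first), vecShape.rend(),
                  [](int d) { return d == 1; });
  assert(leadingOnes);
}
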
96 changes: 82 additions & 14 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s

func.func @transfer_read_flattenable_with_offset(
func.func @transfer_read_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
return %v : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_flattenable_with_offset
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,15 +18,53 @@ func.func @transfer_read_flattenable_with_offset(

// -----

func.func @transfer_write_flattenable_with_offset(
// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable".

func.func @transfer_read_dims_mismatch_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
return %v : vector<1x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>

// -----

func.func @transfer_read_dims_mismatch_non_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
return %v : vector<2x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// -----

func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}

// CHECK-LABEL: func @transfer_write_flattenable_with_offset
// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +35,73 @@ func.func @transfer_write_flattenable_with_offset(

// -----

func.func @transfer_write_dims_mismatch_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
// CHECK: return
// CHECK: }

// -----

func.func @transfer_write_dims_mismatch_non_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}

// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// -----

func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
return
}

// CHECK-LABEL: func @transfer_write_0d
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
// CHECK-SAME: %[[VEC:.+]]: vector<i8>
// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
// CHECK: return
// CHECK-LABEL: func.func @transfer_write_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// -----

@@ -54,11 +124,9 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
return %0 : vector<i8>
}

// CHECK-LABEL: func @transfer_read_0d
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
// CHECK: %[[CST:.+]] = arith.constant 0 : i8
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
// CHECK: return %[[READ]]
// CHECK-LABEL: func.func @transfer_read_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// -----
