[mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) #73522

Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Updates "flatten vector" patterns to support more cases, namely Ops that
read/write vectors with leading unit dims. For example:

  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>

Currently, this vector.transfer_read would not be flattened. With this
change, it will be transformed as follows:

  %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
    into memref<120xi8, strided<[1], offset: ?>>
  %0 = vector.transfer_read %collapse_shape[%c0] ... :
    memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
  %1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8>
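To see why the slice above is contiguous, it helps to enumerate the linearized offsets the transfer touches. The following is a small hand-written C++ illustration (not part of the PR); the strides come from the memref type above.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Row-major strides of memref<5x4x3x2xi8>.
  const int64_t strides[4] = {24, 6, 2, 1};
  // Shape of the transferred vector. For {1, 1, 2, 2} the printed offsets are
  // 0, 1, 2, 3 -- four adjacent bytes, hence a single vector<4xi8> transfer.
  // Try {2, 1, 2, 2}: the offsets jump from 3 to 24, so there is no
  // contiguous 1-D rewrite.
  const int64_t vecShape[4] = {1, 1, 2, 2};
  for (int64_t i0 = 0; i0 < vecShape[0]; ++i0)
    for (int64_t i1 = 0; i1 < vecShape[1]; ++i1)
      for (int64_t i2 = 0; i2 < vecShape[2]; ++i2)
        for (int64_t i3 = 0; i3 < vecShape[3]; ++i3)
          std::printf("%lld\n",
                      (long long)(i0 * strides[0] + i1 * strides[1] +
                                  i2 * strides[2] + i3 * strides[3]));
  return 0;
}
```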
Full diff: https://github.com/llvm/llvm-project/pull/73522.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..f04ca9f5c71e0dd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,76 @@ class TransferWriteDropUnitDimsPattern
} // namespace
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
- ArrayRef<int64_t> targetShape) {
- auto shape = memrefType.getShape();
- SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimensions of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType,
+ VectorType vectorType) {
+
+ ArrayRef<int64_t> targetShape = vectorType.getShape();
+ auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+ // Unused; required by getStridesAndOffset's signature.
int64_t offset;
+ SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
return false;
+
+ // Non-unit stride in the trailing dimension means that this memref is
+ // not contiguous.
if (strides.back() != 1)
return false;
- strides.pop_back();
+
+ // Do all but the leading dim of `vectorType` and the trailing dims of
+ // `memrefType` match?
+ bool allTrailingDimsMatch = true;
+
+ // The trailing dimension of `memrefType` after collapsing/flattening the
+ // current dim. This will be a product of the leading dims, hence initialising
+ // to 1.
int64_t flatDim = 1;
- for (auto [targetDim, memrefDim, memrefStride] :
- llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+ strides.pop_back();
+ for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+ targetShapeTrailingDims, memrefType.getShape(), strides))) {
flatDim *= memrefDim;
- if (flatDim != memrefStride || targetDim != memrefDim)
+ // If the memref stride does not match the flattened dim, then this
+ // memref is not contiguous.
+ if (flatDim != memrefStride)
+ return false;
+
+ // If a non-matching dim was found, then the remaining dims of `VectorType`
+ // should be 1.
+ if (!allTrailingDimsMatch && (targetDim != 1))
return false;
+
+ allTrailingDimsMatch = (targetDim == memrefDim);
}
- return true;
+
+ return allTrailingDimsMatch ? true : (targetShape[0] == 1);
}
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -568,9 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
@@ -628,9 +676,7 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d055a..08ce837be93ffd3 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,44 @@ func.func @transfer_read_flattenable_with_offset(
// -----
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +63,7 @@ func.func @transfer_write_flattenable_with_offset(
return
}
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +72,46 @@ func.func @transfer_write_flattenable_with_offset(
// -----
+func.func @transfer_write_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
return
}
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK-SAME: %[[VEC:.+]]: vector<i8>
-// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK: return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
@@ -54,11 +121,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
return %0 : vector<i8>
}
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK: %[[CST:.+]] = arith.constant 0 : i8
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK: return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
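The doc comment in the diff describes the contiguity condition in two cases. As a way to experiment with it outside of MLIR, here is a minimal standalone C++ sketch (hand-written for this summary, not the PR's implementation; it requires the whole memref to be contiguous, per the doc comment, and compares dims back to front):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: is `vectorShape` a contiguous slice of a memref with the given
// shape and strides? Mirrors the two documented cases: once a vector dim
// stops matching the corresponding (back-aligned) memref dim, that dim may
// be arbitrary, and every dim in front of it must be 1.
static bool isContiguousSliceSketch(const std::vector<int64_t> &memrefShape,
                                    const std::vector<int64_t> &strides,
                                    const std::vector<int64_t> &vectorShape) {
  // The memref must be contiguous: each stride equals the product of all
  // trailing dims (row-major layout).
  int64_t product = 1;
  for (size_t i = memrefShape.size(); i-- > 0;) {
    if (strides[i] != product)
      return false;
    product *= memrefShape[i];
  }
  bool stillMatching = true;
  for (size_t i = 0; i < vectorShape.size(); ++i) {
    int64_t vDim = vectorShape[vectorShape.size() - 1 - i];
    int64_t mDim = i < memrefShape.size()
                       ? memrefShape[memrefShape.size() - 1 - i]
                       : 1;
    if (stillMatching && vDim == mDim)
      continue;
    if (stillMatching) { // First mismatch: this dim may be anything.
      stillMatching = false;
      continue;
    }
    if (vDim != 1) // Dims in front of the first mismatch must be unit.
      return false;
  }
  return true;
}

int main() {
  // memref<5x4x3x2xi8> with row-major strides [24, 6, 2, 1].
  const std::vector<int64_t> shape{5, 4, 3, 2}, strides{24, 6, 2, 1};
  assert(isContiguousSliceSketch(shape, strides, {2, 4, 3, 2}));  // case 1
  assert(!isContiguousSliceSketch(shape, strides, {2, 4, 2, 2})); // case 1
  assert(isContiguousSliceSketch(shape, strides, {1, 1, 2, 2}));  // case 2
  assert(!isContiguousSliceSketch(shape, strides, {2, 1, 2, 2})); // case 2
  return 0;
}
```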
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from c20b0d8 to 902ccc3
QQ before I dig into the review: we have some patterns to remove unit dims from xfer ops. Have you tried running those before this flattening step? I think we have a pass in IREE that applies all these simplifications before trying to flatten. In general, we should aim for removing all these unit dims and avoid the complexity they introduce.
You are probably referring to the "rank reducing patterns": llvm-project/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir, lines 197 to 204 at 79b0330.
I have tried that and it didn't trigger for my example. But I need to dig deeper to understand why (and am happy to extend that as part of this work). However, note that my actual goal is #73523 (which builds on top of this one). I might be able to enable #73523 by updating "rank reducing patterns" instead, but this change feels beneficial on its own regardless. Most of my changes are comments (to help me understand what's going on) and tests to verify the functionality.
Hm, my example from IREE that I am trying to "fix" (there's a lot going on here and the pattern updated here is just one element of the puzzle):

func.func @original(%0: memref<1x1080x1962x2xi32>, %1: memref<1x43x2xi32>, %2: memref<1x1080x1920x2xi32>, %z: index, %y: index, %x: index) {
%cst = arith.constant dense<0> : vector<1x4x2xi32>
%c43 = arith.constant 43 : index
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c60 = arith.constant 60 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%subview = memref.subview %2[0, %z, %y, %z] [1, 60, 64, 2] [1, 1, 1, 1] : memref<1x1080x1920x2xi32> to memref<1x60x64x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>>
%subview_0 = memref.subview %0[0, %z, %y, %z] [1, 60, 106, 2] [1, 1, 1, 1] : memref<1x1080x1962x2xi32> to memref<1x60x106x2xi32, strided<[4237920, 3924, 2, 1], offset: ?>>
scf.for %arg0 = %c0 to %c60 step %c1 {
scf.for %arg1 = %c0 to %c64 step %c4 {
%subview_1 = memref.subview %subview[0, %arg0, %arg1, 0] [1, 1, 4, 2] [1, 1, 1, 1] : memref<1x60x64x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>> to memref<1x1x4x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>>
%6 = scf.for %arg2 = %c0 to %c43 step %c1 iter_args(%arg3 = %cst) -> (vector<1x4x2xi32>) {
%8 = arith.addi %arg2, %arg1 : index
%9 = vector.transfer_read %subview_0[%c0, %arg0, %8, %c0], %c0_i32 {in_bounds = [true, true]} : memref<1x60x106x2xi32, strided<[4237920, 3924, 2, 1], offset: ?>>, vector<4x2xi32>
%10 = vector.transfer_read %1[%c0, %arg2, %c0], %c0_i32 {in_bounds = [true]} : memref<1x43x2xi32>, vector<2xi32>
%11 = vector.broadcast %10 : vector<2xi32> to vector<1x4x2xi32>
%12 = vector.shape_cast %9 : vector<4x2xi32> to vector<8xi32>
%13 = vector.shape_cast %11 : vector<1x4x2xi32> to vector<8xi32>
%14 = arith.muli %12, %13 : vector<8xi32>
%15 = vector.shape_cast %arg3 : vector<1x4x2xi32> to vector<8xi32>
%16 = arith.addi %14, %15 : vector<8xi32>
%17 = vector.shape_cast %16 : vector<8xi32> to vector<1x4x2xi32>
scf.yield %17 : vector<1x4x2xi32>
}
%7 = vector.extract %6[0] : vector<4x2xi32> from vector<1x4x2xi32>
%subview_2 = memref.subview %subview_1[0, 0, 0, 0] [1, 1, 4, 2] [1, 1, 1, 1] : memref<1x1x4x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>> to memref<4x2xi32, affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 + s0)>>
vector.transfer_write %7, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x2xi32>, memref<4x2xi32, affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 + s0)>>
}
}
return
}

So the rank reducing patterns are failing here too. TBH, I'm trying to solve multiple issues here. This is an attempt to reduce the problem space.
Agreed, that's part of the plan. But removing unit dims is unlikely to be sufficient to fold these away. That's basically where the example above comes from. In general, I know that I will need multiple things to fix this 😂.
Force-pushed from 3109a42 to 9a3c60b
I've updated the comments - hopefully that helps clarify the logic. There's quite a lot of it, and this feels similar to the situation in the vectoriser where I tried to distinguish between "contiguous" and "gather" loads. I think that we might be missing some abstraction here, but it's not obvious to me what that could be.

Not yet.
LGTM cheers
This makes sense to me, approving conditioned on implementation refactoring.
Thanks for improving this!
Thanks for pushing on this!
Final refactor requested by Nicolas
Force-pushed from c0bf526 to d8e998b
Updates "flatten vector" patterns to support more cases, namely Ops that read/write vectors with leading unit dims. For example:

  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>

Currently, the vector.transfer_read above would not be flattened. With this change, it will be rewritten as follows:

  %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
    into memref<120xi8, strided<[1], offset: ?>>
  %0 = vector.transfer_read %collapse_shape[%c0] ... :
    memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
  %1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8>

hasMatchingInnerContigousShape is generalised and renamed as isContiguousSlice to better match the updated functionality. A few test names are updated to better highlight what case is being exercised.
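For downstream users, here is a sketch of how such flattening patterns are typically driven from a pass. This is an assumed usage example modelled on the upstream test pass behind -test-vector-transfer-flatten-patterns; check VectorRewritePatterns.h in your tree for the exact entry point and signature.

```cpp
// Hedged sketch: wire the flattening patterns into a greedy rewrite.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void flattenVectorTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers FlattenContiguousRowMajorTransfer{Read,Write}Pattern
  // (names as in the diff above).
  mlir::vector::populateFlattenVectorTransferPatterns(patterns);
  // Apply to a fixpoint; the LogicalResult only reports convergence.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```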