[MLIR] Fix incorrect slice contiguity inference in vector::isContiguousSlice
#142422
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Momchil Velikov (momchil-velikov)

Changes

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing `n` dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.

Full diff: https://github.com/llvm/llvm-project/pull/142422.diff

6 Files Affected:
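Purely for illustration, here is a minimal standalone sketch of the fixed check, operating on plain shape/stride vectors rather than the MLIR types. `isContiguousSliceSketch` and `numContiguousTrailingDims` are hypothetical names; the latter stands in for the stride inspection that `MemRefType` performs internally.

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: how many trailing dims have row-major
// (suffix-product) strides, i.e. are contiguous.
static int64_t numContiguousTrailingDims(const std::vector<int64_t> &shape,
                                         const std::vector<int64_t> &strides) {
  int64_t n = 0, expected = 1;
  for (int64_t i = (int64_t)shape.size() - 1; i >= 0; --i) {
    if (strides[i] != expected)
      break;
    expected *= shape[i];
    ++n;
  }
  return n;
}

// Sketch of the fixed contiguity check.
static bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                                    const std::vector<int64_t> &memStrides,
                                    const std::vector<int64_t> &vecShape) {
  // Ignore the leading unit dimensions of the vector.
  size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  size_t n = vecShape.size() - lead;
  if (n == 0)
    return true; // An all-unit vector reads a single element.
  // The n trailing memref dims must be contiguous...
  if (n > memShape.size() ||
      (size_t)numContiguousTrailingDims(memShape, memStrides) < n)
    return false;
  // ...and all remaining vector dims except the outermost must match them.
  return std::equal(vecShape.begin() + lead + 1, vecShape.end(),
                    memShape.end() - (n - 1));
}
```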
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..a1c8ec2db056a 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,8 @@ class ArrayAttr;
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` element `s0` is asserted to be kDynamic or non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
///
/// Return an empty vector if `sizes` is empty.
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+/// a) the N trailing dimensions of the latter must be contiguous, and
+/// b) the trailing N dimensions of `vectorType` and `memrefType`,
+/// except the first of them, must match.
///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-/// of `memrefType` then the leading dim of `vectorType` can be
-/// arbitrary.
-///
-/// Ex. 1.1 contiguous slice, perfect match
-/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-/// corresponding trailing dim in `memrefType` then the remaining
-/// leading dims of `vectorType` have to be 1 (the first non-matching
-/// dim can be arbitrary).
+/// Examples:
///
-/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.3 non-contiguous slice, 2 != 3
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.5 contiguous slice, leading two unit dims of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.6 non-contiguous slice, 2 != 3, no leading sequence of unit dims
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.7 contiguous slice, memref needs to be contiguous only on the last
+/// dimension
+/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last
+/// two dimensions, and it isn't
+/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
/// Returns an iterator for all positions in the leading dimensions of `vType`
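Assuming the sketch given after the PR summary above, the documented examples can be checked directly. The strides are the row-major suffix products of the memref shape, except in Ex.7/Ex.8, which use the `strided<[8, 4, 1]>` layout:

```c++
#include <cassert>

int main() {
  // Ex.2: contiguous; the outermost dim may differ (2 != 4).
  assert(isContiguousSliceSketch({5, 4, 3, 2}, {24, 6, 2, 1}, {2, 3, 2}));
  // Ex.3: non-contiguous; 2 != 3 with no leading unit dims.
  assert(!isContiguousSliceSketch({5, 4, 3, 2}, {24, 6, 2, 1}, {2, 2, 2}));
  // Ex.5: contiguous; the two leading unit dims are ignored.
  assert(isContiguousSliceSketch({5, 4, 3, 2}, {24, 6, 2, 1}, {1, 1, 2, 2}));
  // Ex.7: contiguous; only the last memref dim needs to be contiguous.
  assert(isContiguousSliceSketch({2, 2, 2}, {8, 4, 1}, {1, 1, 2}));
  // Ex.8: non-contiguous; the last two memref dims are not contiguous.
  assert(!isContiguousSliceSketch({2, 2, 2}, {8, 4, 1}, {1, 2, 2}));
}
```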
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..bb719d46a215a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,8 +69,10 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
- assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
- "sizes must be nonnegative");
+ assert(sizes.size() == 0 ||
+ ((sizes[0] == ShapedType::kDynamic || sizes[0] >= 0) &&
+ llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
+ "sizes must be nonnegative");
int64_t unit = 1;
return ::computeSuffixProductImpl(sizes, unit);
}
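The relaxation exempts only `s0` because, by construction, the leading size never appears in any entry of the suffix product. A hedged sketch, with `kDynamicSketch` standing in for `ShapedType::kDynamic`:

```c++
#include <cstdint>
#include <vector>

constexpr int64_t kDynamicSketch = INT64_MIN; // stand-in for ShapedType::kDynamic

// suffixProductSketch({s0, s1, ..., sn}) == {s1*...*sn, ..., sn, 1}.
std::vector<int64_t> suffixProductSketch(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> result(sizes.size());
  int64_t running = 1;
  for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; --i) {
    result[i] = running;
    running *= sizes[i]; // sizes[0] is only folded in after its slot is set
  }
  return result;
}

// suffixProductSketch({kDynamicSketch, 4, 6}) == {24, 6, 1}: the dynamic
// leading extent never reaches the output, so allowing it is safe.
```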
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..709716365f825 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determine the first memref dimension to collapse.
+ int64_t firstDimToCollapse =
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
// 1. Collapse the source memref
Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determine the first memref dimension to collapse.
+ int64_t firstDimToCollapse =
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
// 1. Collapse the source memref
Value collapsedSource =
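To see the effect of this change, consider the `memref<1x43x4x6xi32>` case from the tests below. This is a worked example under the assumption of an identity (fully contiguous) layout; the variable names are illustrative, not the MLIR API:

```c++
#include <cstdint>

int64_t sourceRank = 4;             // memref<1x43x4x6xi32>
int64_t maxCollapsableTrailing = 4; // identity layout: all 4 dims contiguous
// New collapse point: 4 - 4 == 0, i.e. reassociation [[0, 1, 2, 3]],
// yielding memref<1032xi32>.
int64_t firstDimToCollapse = sourceRank - maxCollapsableTrailing;
// Old collapse point: sourceRank - vectorRank == 4 - 3 == 1 for a
// vector<1x2x6xi32>, i.e. [[0], [1, 2, 3]] and memref<1x1032xi32>.
```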
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..5f8f3b6adf9db 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;
- ArrayRef<int64_t> vectorShape = vectorType.getShape();
- auto vecRank = vectorType.getRank();
+ // Ignore a leading contiguous sequence of unit dimensions in the vector.
+ ArrayRef<int64_t> vectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ auto vecRank = vectorShape.size();
if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
- // Extract the trailing dims and strides of the input memref
+ // Extract the trailing dims of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);
- // Compare the dims of `vectorType` against `memrefType` (in reverse).
- // In the most basic case, all dims will match.
- auto firstNonMatchingDim =
- std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
- memrefShape.rbegin(), memrefShape.rend());
- if (firstNonMatchingDim.first == vectorShape.rend())
- return true;
-
- // One non-matching dim is still fine, however the remaining leading dims of
- // `vectorType` need to be 1.
- SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
- vectorShape.rend());
-
- return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+ // Compare the dims of `vectorType` against `memrefType`.
+ // All of the dimensions except the first must match.
+ return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
}
std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..594f7ce371347 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
return %res : vector<8x4xi8>
}
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME: : memref<?x?xi8, {{.+}}>, vector<32xi8>
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[RES]] : vector<8x4xi8>
@@ -210,13 +213,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
-// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
@@ -397,11 +399,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -449,16 +450,18 @@ func.func @transfer_write_leading_dynamic_dims(
return
}
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: [%[[ARG2]], %[[COLLAPSED_IDX]]]
// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME: : vector<32xi8>, memref<?x?xi8, {{.+}}>
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-128B: memref.collapse_shape
@@ -488,14 +491,13 @@ func.func @transfer_write_dynamic_to_flatten(
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
-// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
@@ -573,8 +575,12 @@ func.func @negative_out_of_bound_transfer_read(
memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
return %res : vector<5x4x3x2xi8>
}
-// CHECK: func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
@@ -585,5 +591,47 @@ func.func @negative_out_of_bound_transfer_write(
vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}
-// CHECK: func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+ %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+ return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+ %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+ return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-128B-NOT: vector.shape_cast
+
Thanks! Other than my question about the change to first dimension of the memref that gets collapsed, my comments are all quite minor.
LGTM % some minor suggestions, thanks!
Could you add a short note in the summary about the source of changes in the tests? In particular the fact that this PR reduces the number of the collapsed dimensions? That wasn't clear to me.
Commit message updated.
LGTM; thanks!
[MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` -- Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
Even though it's possible in principle, the affected patterns need strides to be determined statically.
Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing `n` dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops.

The patterns used to collapse a number of dimensions equal to the vector rank, which missed some opportunities when the leading unit dimensions of the vector span non-contiguous dimensions of the memref.

Now that the contiguity of the slice is determined correctly, there is a choice of how many dimensions of the memref to collapse, ranging from
a) the number of vector dimensions after ignoring the leading unit dimensions, up to
b) the maximum number of contiguous memref dimensions.

This patch makes a choice to do minimal memref collapsing. The rationale behind this decision is that this way the least amount of information is discarded.

(It follows that in some cases where the patterns used to trigger and collapse some memref dimensions, after this patch the patterns may collapse fewer dimensions.)
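As a concrete illustration of the range described above, a sketch based on the `memref<1x43x4x6xi32>` / `vector<1x2x6xi32>` test case, assuming the identity layout:

```c++
#include <cstdint>

int64_t nonUnitVectorRank = 2;     // vector<1x2x6xi32> after dropping the leading 1
int64_t maxContiguousTrailing = 4; // identity layout: all memref dims contiguous
// Option (a): collapse the trailing 2 dims -> [[0], [1], [2, 3]] and
//             memref<1x43x24xi32>.
// Option (b): collapse all 4 dims -> [[0, 1, 2, 3]] and memref<1032xi32>
//             (the grouping seen in the updated tests).
// Per the note above, any number of collapsed trailing dims in [2, 4]
// yields a correct flattened transfer.
```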