Skip to content

[MLIR] Fix incorrect slice contiguity inference in vector::isContiguousSlice #142422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 23, 2025

Conversation

momchil-velikov
Copy link
Collaborator

@momchil-velikov momchil-velikov commented Jun 2, 2025

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., vector<1x1x...x1xd0xd1x...xdn-1xT>. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how FlattenContiguousRowMajorTransfer{Read,Write}Pattern flattens transfer_read and transfer_write ops.

The patterns used to collapse a number of dimensions equal to the vector rank which missed some opportunities when the leading unit dimensions of the vector span non-contiguous dimensions of the memref.
Now that the contiguity of the slice is determined correctly, there is a choice how many dimensions of the
memref to collapse, ranging from
a) the number of vector dimensions after ignoring the leading unit dimensions, up to
b) the maximum number of contiguous memref dimensions

This patch makes a choice to do minimal memref collapsing. The rationale behind this decision is that
this way the least amount of information is discarded.

(It follows that in some cases where the patterns used to trigger and collapse some memref dimensions, after this patch the patterns may collapse less dimensions).

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Momchil Velikov (momchil-velikov)

Changes

Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., vector&lt;1x1x...x1xd0xd1x...xdn-1xT&gt;. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank.

This affects how FlattenContiguousRowMajorTransfer{Read,Write}Pattern flattens transfer_read and transfer_write ops. The pattern used to collapse a number of dimensions equal the vector rank, which may be is incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.


Full diff: https://github.com/llvm/llvm-project/pull/142422.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+28-26)
  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+6-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+8-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+78-30)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..a1c8ec2db056a 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,8 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` element `s0` is asserted to be kDynamic or non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2  contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.1 contiguous slice, perfect match
+///     vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///     vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.3 non-contiguous slice, 2 != 3
+///     vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///        2 != 3 (allowed)
+///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
+///         2 != 3 (allowed)
+///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///        dimension
+///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+///   Ex.8 non-contiguous slice, memref needs to be contiguous one the last
+///        two dimensions, and it isn't
+///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
 /// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..bb719d46a215a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,8 +69,10 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
-         "sizes must be nonnegative");
+  assert(sizes.size() == 0 ||
+         ((sizes[0] == ShapedType::kDynamic || sizes[0] >= 0) &&
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
+             "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..709716365f825 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determinine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determinine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..5f8f3b6adf9db 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (vectorType.isScalable())
     return false;
 
-  ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  auto vecRank = vectorType.getRank();
+  // Ignore a leading contiguous sequence of unit dimensions in the vector.
+  ArrayRef<int64_t> vectorShape =
+      vectorType.getShape().drop_while([](auto v) { return v == 1; });
+  auto vecRank = vectorShape.size();
 
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
-  // Extract the trailing dims and strides of the input memref
+  // Extract the trailing dims of the input memref
   auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Compare the dims of `vectorType` against `memrefType` (in reverse).
-  // In the most basic case, all dims will match.
-  auto firstNonMatchingDim =
-      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
-                    memrefShape.rbegin(), memrefShape.rend());
-  if (firstNonMatchingDim.first == vectorShape.rend())
-    return true;
-
-  // One non-matching dim is still fine, however the remaining leading dims of
-  // `vectorType` need to be 1.
-  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
-                                   vectorShape.rend());
-
-  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+  // Compare the dims of `vectorType` against `memrefType`.
+  // All of the dimensions, except the first must match.
+  return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
 }
 
 std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..594f7ce371347 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
   return %res : vector<8x4xi8>
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
 // CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
 // CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:      : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME:      : memref<?x?xi8, {{.+}}>, vector<32xi8>
 // CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:         return %[[RES]] : vector<8x4xi8>
 
@@ -210,13 +213,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[IDX_2:arg1]]
 // CHECK-SAME:    %[[MEM:arg2]]
 // CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
-// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL:   [[0, 1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
 // CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
 
@@ -397,11 +399,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -449,16 +450,18 @@ func.func @transfer_write_leading_dynamic_dims(
   return
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
 // CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:      [%[[ARG2]], %[[COLLAPSED_IDX]]]
 // CHECK-SAME:      {in_bounds = [true]}
-// CHECK-SAME:      : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME:      : vector<32xi8>, memref<?x?xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
 //       CHECK-128B:   memref.collapse_shape
@@ -488,14 +491,13 @@ func.func @transfer_write_dynamic_to_flatten(
 // CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
 // CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
 
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
-// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL:   [[0, 1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -573,8 +575,12 @@ func.func @negative_out_of_bound_transfer_read(
     memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
   return %res : vector<5x4x3x2xi8>
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
@@ -585,5 +591,47 @@ func.func @negative_out_of_bound_transfer_write(
     vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME:    %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME:    %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK:       vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:    : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT:    vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from 1d025f3 to 1fe6866 Compare June 3, 2025 17:06
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from d9a470e to efac666 Compare June 3, 2025 17:06
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from 1fe6866 to c97c6c1 Compare June 4, 2025 11:13
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch 2 times, most recently from c981866 to eb31b1e Compare June 5, 2025 11:25
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch 2 times, most recently from e3b6e21 to ab4681a Compare June 5, 2025 16:50
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from eb31b1e to 21266b0 Compare June 5, 2025 16:50
Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Other than my question about the change to first dimension of the memref that gets collapsed, my comments are all quite minor.

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 21266b0 to f60e73d Compare June 6, 2025 13:02
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from ab4681a to 1cee446 Compare June 6, 2025 13:02
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from f60e73d to 1210d59 Compare June 6, 2025 15:27
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 1210d59 to 3b17c94 Compare June 9, 2025 16:42
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from d3458a6 to 20b495d Compare June 9, 2025 16:42
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % some minor suggestions, thanks!

Could you add a short note in the summary about the source of changes in the tests? In particular the fact that this PR reduces the number of the collapsed dimensions? That wasn't clear to me.

@momchil-velikov
Copy link
Collaborator Author

Commit message updated.

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM; thanks!

Base automatically changed from users/momchil-velikov/memref-contig to main June 23, 2025 08:28
…ousSlice`

Previously, slices were sometimes marked as non-contiguous when
they were actually contiguous. This occurred when the vector type had
leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>``.
In such cases, only the trailing n dimensions of the memref need to be
contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write`` ops. The pattern used
to collapse a number of dimensions equal the vector rank, which
may be is incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref
dimensions as are actually contiguous.
Even though it's possible in principle, the affected patterns need
strides to be determined statically.
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 358f5da to 1740d45 Compare June 23, 2025 09:10
@momchil-velikov momchil-velikov merged commit c7d9b6e into main Jun 23, 2025
7 checks passed
@momchil-velikov momchil-velikov deleted the users/momchil-velikov/memref-contig-slice branch June 23, 2025 13:12
miguelcsx pushed a commit to miguelcsx/llvm-project that referenced this pull request Jun 23, 2025
…ousSlice` (llvm#142422)

Previously, slices were sometimes marked as non-contiguous when they
were actually contiguous. This occurred when the vector type had leading
unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such
cases, only the trailing `n` dimensions of the memref need to be
contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops.

The patterns used to collapse a number of dimensions equal to the vector
rank which missed some opportunities when the leading unit dimensions of
the vector span non-contiguous dimensions of the memref.
Now that the contiguity of the slice is determined correctly, there is a
choice how many dimensions of the
memref to collapse, ranging from 
a) the number of vector dimensions after ignoring the leading unit
dimensions, up to
  b) the maximum number of contiguous memref dimensions

This patch makes a choice to do minimal memref collapsing. The rationale
behind this decision is that
this way the least amount of information is discarded.

(It follows that in some cases where the patterns used to trigger and
collapse some memref dimensions, after this patch the patterns may
collapse less dimensions).
Jaddyen pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 23, 2025
…ousSlice` (llvm#142422)

Previously, slices were sometimes marked as non-contiguous when they
were actually contiguous. This occurred when the vector type had leading
unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such
cases, only the trailing `n` dimensions of the memref need to be
contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops.

The patterns used to collapse a number of dimensions equal to the vector
rank which missed some opportunities when the leading unit dimensions of
the vector span non-contiguous dimensions of the memref.
Now that the contiguity of the slice is determined correctly, there is a
choice how many dimensions of the
memref to collapse, ranging from 
a) the number of vector dimensions after ignoring the leading unit
dimensions, up to
  b) the maximum number of contiguous memref dimensions

This patch makes a choice to do minimal memref collapsing. The rationale
behind this decision is that
this way the least amount of information is discarded.

(It follows that in some cases where the patterns used to trigger and
collapse some memref dimensions, after this patch the patterns may
collapse less dimensions).
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…ousSlice` (llvm#142422)

Previously, slices were sometimes marked as non-contiguous when they
were actually contiguous. This occurred when the vector type had leading
unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such
cases, only the trailing `n` dimensions of the memref need to be
contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops.

The patterns used to collapse a number of dimensions equal to the vector
rank which missed some opportunities when the leading unit dimensions of
the vector span non-contiguous dimensions of the memref.
Now that the contiguity of the slice is determined correctly, there is a
choice how many dimensions of the
memref to collapse, ranging from 
a) the number of vector dimensions after ignoring the leading unit
dimensions, up to
  b) the maximum number of contiguous memref dimensions

This patch makes a choice to do minimal memref collapsing. The rationale
behind this decision is that
this way the least amount of information is discarded.

(It follows that in some cases where the patterns used to trigger and
collapse some memref dimensions, after this patch the patterns may
collapse less dimensions).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants