Skip to content

[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) #95745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 22, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Jun 17, 2024

The main goal of this and subsequent PRs is to unify and categorize
tests in:

  • vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

  1. For consistency with other tests,
    @transfer_read_flattenable_with_dynamic_dims_and_indices is renamed
    as @transfer_read_leading_dynamic_dims. It is also moved near other
    tests for xfer_read, variable names are updated to match other
    xfer_read tests

  2. @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim
    is renamed as @negative_transfer_read_dynamic_dim_to_flatten to
    better highlight that it's a negative test and to contrast it with
    @transfer_read_leading_dynamic_dims (and to emphasise the
    difference between the two).

  3. Similar changes for tests for xfer_write.

  4. Make sure that we consistently use %idx_N (as opposed to %idxN).

Follow-up for #95743 and #95744

@llvmbot
Copy link
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n)
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n)
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n)

Patch is 35.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95745.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+23-5)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+304-183)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
 /// the trailing dimension of the vector read is smaller than the provided
 /// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
-        dyn_cast<MemRefType>(collapsedSource.getType());
+        cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_write has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
 public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices =
-        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
-                            transferWriteOp.getIndices(), firstDimToCollapse);
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
+    // 1. Collapse the source memref
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
+    // 2.2 New indices
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
+
+    // 3. Create new vector.transfer_write that writes to the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_write with the new one writing the
+    // collapsed shape
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..d38cb4a56d722 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
     %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-    return %v : vector<1x1x2x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+  return %v : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,135 +78,195 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x43x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %idx0 : index,
+    %idx1 : index) -> vector<2x2xf32> {
+
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+/// The leading dynamic shapes don't affect whether this example is flattenable
+/// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_read_leading_dynamic_dims(
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) -> vector<8x4xi8> {
+
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x?x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+    %arg : memref<i8>) -> vector<i8> {
+
+  %cst = arith.constant 0 : i8
+  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+  return %0 : vector<i8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
 //   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_write_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<5x4x...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n)
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n)
  • [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n)

Patch is 35.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95745.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+23-5)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+304-183)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
 /// the trailing dimension of the vector read is smaller than the provided
 /// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
-        dyn_cast<MemRefType>(collapsedSource.getType());
+        cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_write has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
 public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices =
-        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
-                            transferWriteOp.getIndices(), firstDimToCollapse);
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
+    // 1. Collapse the source memref
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
+    // 2.2 New indices
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
+
+    // 3. Create new vector.transfer_write that writes to the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_write with the new one writing the
+    // collapsed shape
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..d38cb4a56d722 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
     %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-    return %v : vector<1x1x2x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+  return %v : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,135 +78,195 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x43x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %idx0 : index,
+    %idx1 : index) -> vector<2x2xf32> {
+
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+/// The leading dynamic shapes don't affect whether this example is flattenable
+/// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_read_leading_dynamic_dims(
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) -> vector<8x4xi8> {
+
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x?x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+    %arg : memref<i8>) -> vector<i8> {
+
+  %cst = arith.constant 0 : i8
+  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+  return %0 : vector<i8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
 //   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_write_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<5x4x...
[truncated]

@banach-space banach-space changed the title andrzej/refactor xfer flatten 3 [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) Jun 17, 2024
@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_3 branch from b9ee24d to f39bfc7 Compare June 17, 2024 07:32
@banach-space banach-space requested a review from nujaa June 17, 2024 12:01
Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, Just a few NITs since we are refactoring. 👌


func.func @transfer_read_leading_dynamic_dims(
%arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
%idx_1 : index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: As part of unifying variable names, transfer_read_dims_mismatch_non_contiguous_non_zero_indices's indices args start from 0 and do not have an _. idx0 : index

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out! In fact, ATM there are

  • 4 instances of idx1, and
  • 12 instances of idx_1

in this file. Let me make sure that we are indeed consistent, but I'll use idx_1 rather than idx1 - the former is already more common :)

Great work noticing this 😅 🙏🏻

// CHECK-128B: memref.collapse_shape

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diff-ed these tests with the ones added above. No diff but variable names. 👍

@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_3 branch from f39bfc7 to 9062edd Compare June 21, 2024 11:50
@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_3 branch 2 times, most recently from 1155fe6 to 7413388 Compare July 21, 2024 17:41
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Jul 21, 2024
@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_3 branch from 7413388 to 9219d85 Compare July 22, 2024 07:07
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
/// The leading dynamic shapes don't affect whether this example is flattenable
/// or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, not sure if on purpose to highlight comments in opposition to checks with 3 slashes /// (which I think makes sense). Yet, other comments such as l.108 or l.167 do not implement this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, it wasn't intentional.

For context, some people use different format for comments in tests (e.g. /// instead of //) as apparently their editors handle that better (i.e. use different highlighting for comment and RUN/CHECK lines).

In any case, we should always prioritise consistency - I've just updated these from /// to //.

@nujaa
Copy link
Contributor

nujaa commented Jul 22, 2024

LGTM, thanks.

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`,
   i.e. move it near other tests for xfer_read, unify variable names to
   match other xfer_read tests, highlight what makes this a positive
   test to better contrast it with
   `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`

2. Similar changes for
   `@transfer_write_flattenable_with_dynamic_dims_and_indices`.

Depends on llvm#95743 and llvm#95744

**Only review the top top commit**
…fc) (3/n)

Replace /// with //, improve indentation
@banach-space banach-space force-pushed the andrzej/refactor_xfer_flatten_3 branch from bd7e8d8 to 1d942de Compare July 22, 2024 09:28
@banach-space banach-space merged commit 85e7428 into llvm:main Jul 22, 2024
7 checks passed
@banach-space banach-space deleted the andrzej/refactor_xfer_flatten_3 branch July 22, 2024 17:37
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. For consistency with other tests,
   `@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed
   as `@transfer_read_leading_dynamic_dims`. It is also moved near other
   tests for `xfer_read`, variable names are updated to match other
   `xfer_read` tests

2. `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`
   is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to
   better highlight that it's a negative test and to contrast it with
   `@transfer_read_leading_dynamic_dims` (and to emphasise the
   difference between the two).

3. Similar changes for tests for `xfer_write`.

4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`).

Follow-up for #95743 and #95744
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants