[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) #73523


Merged

Conversation

banach-space
Contributor

@banach-space banach-space commented Nov 27, 2023

Updates patterns for flattening vector.transfer_read by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:

  %2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>

Previously only the following case would be considered for collapsing:

  %2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>

Also adds some new comments and renames the firstContiguousInnerDim variable
to firstDimToCollapse (the latter better matches its actual meaning).

Similar updates for vector.transfer_write will be implemented in a
follow-up patch.
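
After the rewrite, the first example above becomes, roughly, the following (a
sketch based on the CHECK lines of the new test added in this patch; SSA value
names are illustrative):

```mlir
// Sketch only - mirrors the new test in vector-transfer-flatten.mlir. The
// source memref is collapsed and the non-zero indices are linearised into a
// single %offset feeding a 1-D transfer_read.
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c43 = arith.constant 43 : index
%c0_i32 = arith.constant 0 : i32
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
%mul0 = arith.muli %arg0, %c43 : index
%mul1 = arith.muli %arg1, %c4 : index
%offset = arith.addi %mul1, %mul0 : index
%read = vector.transfer_read %collapsed[%c0, %offset], %c0_i32 {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%2 = vector.shape_cast %read : vector<12xi32> to vector<1x2x6xi32>
```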

@banach-space banach-space changed the title andrzej/extend collapse pattern [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) Nov 27, 2023
@llvmbot
Member

llvmbot commented Nov 27, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Updates patterns for flattening vector.transfer_read by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:

  %2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>

Previously only the following case would be considered for collapsing:

  %2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>

The pattern itself, FlattenContiguousRowMajorTransferReadPattern, was
a bit refactored too:

  • added comments,
  • renamed firstContiguousInnerDim to firstDimToCollapse (the latter
    better matches the meaning and is already used consistently in the
    helper methods).

A similar update for vector.transfer_write will be implemented in a
follow-up patch.


Full diff: https://github.com/llvm/llvm-project/pull/73523.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+125-27)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+110-14)
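
For contrast, a minimal negative example (taken from the new tests in the diff
below; the function name is illustrative) that the new isContiguousSlice
helper rejects, so the pattern leaves the op unchanged:

```mlir
// Not flattened: vector<2x1x2x2xi8> is not a contiguous slice of the source.
// Past the first mismatching dim, all leading vector dims must be 1, but the
// leading dim here is 2 (cf. the isContiguousSlice doc comment).
func.func @non_contiguous_read(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
  return %v : vector<2x1x2x2xi8>
}
```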
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..951a378b84cf0e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,76 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
-                                              ArrayRef<int64_t> targetShape) {
-  auto shape = memrefType.getShape();
-  SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+///   vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimensions of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+///   vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType,
+                              VectorType vectorType) {
+
+  ArrayRef<int64_t> targetShape = vectorType.getShape();
+  auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+  // Not used
   int64_t offset;
+  SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
+
+  // A non-unit stride in the trailing dimension means that this memref is
+  // not contiguous.
   if (strides.back() != 1)
     return false;
-  strides.pop_back();
+
+  // Do all but the leading dim of `vectorType` and the trailing dims of
+  // `memrefType` match?
+  bool allTrailingDimsMatch = true;
+
+  // The trailing dimension of `memrefType` after collapsing/flattening the
+  // current dim. This will be a product of the leading dims, hence initialising
+  // to 1.
   int64_t flatDim = 1;
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+  strides.pop_back();
+  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+           targetShapeTrailingDims, memrefType.getShape(), strides))) {
     flatDim *= memrefDim;
-    if (flatDim != memrefStride || targetDim != memrefDim)
+    // If the memref stride does not match the flattened dim, then this
+    // memref is not contiguous.
+    if (flatDim != memrefStride)
+      return false;
+
+    // If a non-matching dim was found, then the remaining dims of `VectorType`
+    // should be 1.
+    if (!allTrailingDimsMatch && (targetDim != 1))
       return false;
+
+    allTrailingDimsMatch = (targetDim == memrefDim);
   }
-  return true;
+
+  return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -529,6 +579,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -562,18 +614,16 @@ class FlattenContiguousRowMajorTransferReadPattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -581,26 +631,76 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //   memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //   memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      int64_t outputRank = transferReadOp.getIndices().size();
+      Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        auto sourceDimSize =
+            rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+        offset = rewriter.create<arith::AddIOp>(
+            loc,
+            rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+                                           sourceDimSize),
+            offset);
+      }
+      collapsedIndices.push_back(offset);
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
@@ -628,9 +728,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d055a..8369069e31ab7c6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
     %c0 = arith.constant 0 : index
     %cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
     return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,76 @@ func.func @transfer_read_flattenable_with_offset(
 
 // -----
 
-func.func @transfer_write_flattenable_with_offset(
+// The shapes of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+                     %idx_1: index,
+                     %idx_2: index,
+                     %m_in: memref<1x43x4x6xi32>,
+                     %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[VAL_4:.*]] = arith.constant 43 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +95,7 @@ func.func @transfer_write_flattenable_with_offset(
     return
 }
 
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +104,46 @@ func.func @transfer_write_flattenable_with_offset(
 
 // -----
 
+func.func @transfer_write_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:                                            %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
       vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
       return
 }
 
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK-SAME:       %[[VEC:.+]]: vector<i8>
-// CHECK:          vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK:          return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 
@@ -54,11 +153,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
       return %0 : vector<i8>
 }
 
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK:            %[[CST:.+]] = arith.constant 0 : i8
-// CHECK:            %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK:            return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 

@banach-space
Contributor Author

banach-space commented Nov 27, 2023

Depends on #73522 - please only review the top commit 🙏🏻.

@banach-space banach-space force-pushed the andrzej/extend_collapse_pattern branch from cc810a7 to fecd909 on November 27, 2023 14:51

github-actions bot commented Nov 27, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@bjacob
Contributor

bjacob commented Nov 27, 2023

I'm not a legitimate reviewer here but I'll vouch for the usefulness of this change. Flattening patterns are often key to consistently good codegen of transfer ops and this PR seems to remove an unnecessary limitation.

[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N)

Refactor to use makeComposedFoldedAffineApply
@banach-space banach-space force-pushed the andrzej/extend_collapse_pattern branch from fecd909 to b27c49d on December 4, 2023 10:36
Contributor

@hanhanW hanhanW left a comment


Thanks for pushing on this and being patient with my review comments!

Contributor Author

@banach-space banach-space left a comment


I am about to send a small update - it addresses comments from @hanhanW and also restricts the "rewrite" added here. If there are no new comments, I will land it tomorrow.

Thank you for taking a look :)

@banach-space
Contributor Author

> Thanks for (...) being patient with my review comments!

It works both ways - thank you for bearing with me :) And for excellent comments - they really helped to improve this patch (same comment for my previous PR) 🙏🏻

Contributor

@hanhanW hanhanW left a comment


SG, thank you!

@banach-space banach-space merged commit 2eb9e33 into llvm:main Dec 5, 2023
@banach-space banach-space deleted the andrzej/extend_collapse_pattern branch December 5, 2023 08:40