Skip to content

Reapply "[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes." (#80712) #81778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2024

Conversation

dcaballe
Copy link
Contributor

This reverts commit b4c7152.

Downstream regression due to another issue that this PR exposes. We have identified the work-items to fix the new issue here: iree-org/iree#16406

…mic shapes." (llvm#80712)

This reverts commit b4c7152.

Downstream regression due to another issue that this PR exposes.
We have identified the work-items to fix the new issue here: iree-org/iree#16406

Co-authored-by: Han-Chung Wang <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This reverts commit b4c7152.

Downstream regression due to another issue that this PR exposes. We have identified the work-items to fix the new issue here: iree-org/iree#16406


Full diff: https://github.com/llvm/llvm-project/pull/81778.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+8-6)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+19)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 53ae138d1e43a0..74dd1dfaca0da8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1237,7 +1237,7 @@ class DropInnerMostUnitDimsTransferRead
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
-    if (!srcType || !srcType.hasStaticShape())
+    if (!srcType)
       return failure();
 
     if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1261,19 +1261,21 @@ class DropInnerMostUnitDimsTransferRead
                         targetType.getElementType());
 
     auto loc = readOp.getLoc();
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, readOp.getSource());
+    SmallVector<OpFoldResult> offsets(srcType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(srcType.getRank(),
+                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
-    SmallVector<int64_t> offsets(srcType.getRank(), 0);
-    SmallVector<int64_t> strides(srcType.getRank(), 1);
-
     ArrayAttr inBoundsAttr =
         readOp.getInBounds()
             ? rewriter.getArrayAttr(
                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
-        strides);
+        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 750879df129b14..3984f17f9e8cdb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
 
 // -----
 
+func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+  return %0 : vector<1x8x1xf32>
+}
+//      CHECK: func @contiguous_outer_dyn_inner_most_view(
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
+// CHECK-SAME:    memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK-SAME:    memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for tracking this down!!

@MacDue
Copy link
Member

MacDue commented Mar 5, 2024

Hey, this is breaking building dynamically sized matmuls for neon with IREE (see iree-org/iree#16475).

Reduced to this case that produces incorrect IR: https://godbolt.org/z/36Er81G7G

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants