Skip to content

Revert "[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes." #80712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2024

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Feb 5, 2024

Reverts #79752 because it is causing regressions in downstream projects.

@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Reverts llvm/llvm-project#79752 because it is causing regressions in downstream projects.


Full diff: https://github.com/llvm/llvm-project/pull/80712.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+12-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (-40)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8363e73857e5c..12aa11e9e33f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
-    if (!srcType)
+    if (!srcType || !srcType.hasStaticShape())
       return failure();
 
     if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1260,21 +1260,19 @@ class DropInnerMostUnitDimsTransferRead
                         targetType.getElementType());
 
     auto loc = readOp.getLoc();
-    SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, readOp.getSource());
-    SmallVector<OpFoldResult> offsets(srcType.getRank(),
-                                      rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(srcType.getRank(),
-                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+
     ArrayAttr inBoundsAttr =
         readOp.getInBounds()
             ? rewriter.getArrayAttr(
                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
+        strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
@@ -1320,7 +1318,7 @@ class DropInnerMostUnitDimsTransferWrite
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
-    if (!srcType)
+    if (!srcType || !srcType.hasStaticShape())
       return failure();
 
     if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -1343,23 +1341,20 @@ class DropInnerMostUnitDimsTransferWrite
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    Location loc = writeOp.getLoc();
-    SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
-    SmallVector<OpFoldResult> offsets(srcType.getRank(),
-                                      rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(srcType.getRank(),
-                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
     ArrayAttr inBoundsAttr =
         writeOp.getInBounds()
             ? rewriter.getArrayAttr(
                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
 
+    Location loc = writeOp.getLoc();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+        strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 3984f17f9e8cd..d6d69c8af8850 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -16,25 +16,6 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
 
 // -----
 
-func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
-  return %0 : vector<1x8x1xf32>
-}
-//      CHECK: func @contiguous_outer_dyn_inner_most_view(
-// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
-//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
-// CHECK-SAME:    memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
-//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
-// CHECK-SAME:    memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
-//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
-//      CHECK:   return %[[RESULT]]
-
-// -----
-
 func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32
@@ -138,27 +119,6 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
 
 // -----
 
-func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
-    {in_bounds = [true, true, true, true]}
-    : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
-  return
-}
-// CHECK:      func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
-// CHECK-SAME:     memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
-// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
-// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
-// CHECK-SAME:     [%[[IDX]], %[[C0]], %[[C0]]]
-
-// -----
-
 func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]

@hanhanW hanhanW merged commit b4c7152 into main Feb 5, 2024
@hanhanW hanhanW deleted the revert-79752-transfer-write-inner-unit-dims-dynamic branch February 5, 2024 17:32
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…ic shapes." (llvm#80712)

Reverts llvm#79752 because it is causing regressions in
downstream projects.
@dcaballe
Copy link
Contributor

dcaballe commented Feb 7, 2024

I just realized I need the xfer read part of this PR so I may help with the investigation :)

dcaballe added a commit to iree-org/llvm-project that referenced this pull request Feb 13, 2024
…mic shapes." (llvm#80712)

This reverts commit b4c7152.

Regression has been fixed.
dcaballe added a commit to dcaballe/llvm-project that referenced this pull request Feb 14, 2024
…mic shapes." (llvm#80712)

This reverts commit b4c7152.

Downstream regression due to another issue that this PR exposes.
We have identified the work-items to fix the new issue here: iree-org/iree#16406

Co-authored-by: Han-Chung Wang <[email protected]>
dcaballe added a commit that referenced this pull request Feb 14, 2024
…mic shapes." (#80712) (#81778)

This reverts commit b4c7152.

Downstream regression due to another issue that this PR exposes. We have identified the work-items to fix the new issue here: iree-org/iree#16406

Co-authored-by: Han-Chung Wang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants