[mlir][memref] Remove incorrect memref.transpose fold #79809

Merged: 2 commits, Jan 30, 2024

Conversation

MacDue
Member

@MacDue MacDue commented Jan 29, 2024

This folded casts into `memref.transpose` without updating the result
type of the transpose op, which resulted in IR that failed to verify
for statically sized memrefs.

i.e.

```mlir
%cast = memref.cast %0 : memref<?x4xf32> to memref<?x?xf32>
%transpose = memref.transpose %cast : memref<?x?xf32> to memref<?x?xf32>
```

would fold to:

```mlir
// Fails verification:
%transpose = memref.transpose %0 : memref<?x4xf32> to memref<?x?xf32>
```
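
To make the failure mode concrete, here is a minimal, self-contained C++ sketch of the check that rejects the folded IR. It is a toy model, not MLIR's implementation: `Dim`, `Shape`, `transposeShape`, and `verifyTranspose` are hypothetical names introduced for illustration, and the model compares shapes only, ignoring the strided layout that real memref types also carry.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Toy stand-in for a memref shape: std::nullopt models a dynamic `?` dimension.
// (Hypothetical types; real memref types also carry a layout, ignored here.)
using Dim = std::optional<long>;
using Shape = std::vector<Dim>;

// Shape obtained by permuting `in` so that result dim i is input dim perm[i].
Shape transposeShape(const Shape &in, const std::vector<std::size_t> &perm) {
  assert(perm.size() == in.size() && "permutation rank must match shape rank");
  Shape out;
  out.reserve(in.size());
  for (std::size_t i : perm)
    out.push_back(in[i]);
  return out;
}

// Hypothetical verifier: the declared result shape must equal the shape
// inferred from the operand, dimension by dimension (static vs. dynamic too).
bool verifyTranspose(const Shape &operand,
                     const std::vector<std::size_t> &perm,
                     const Shape &declaredResult) {
  return transposeShape(operand, perm) == declaredResult;
}
```

In this model, a `?x?` operand with a declared `?x?` result verifies. Swapping the operand for the `?x4` cast source while leaving the result as `?x?` does not, because the inferred result is `4x?`. A fold that replaces the operand therefore has to update the result type as well, or not fire at all.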
@llvmbot
Member

llvmbot commented Jan 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Benjamin Maxwell (MacDue)

Full diff: https://github.com/llvm/llvm-project/pull/79809.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (-2)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+13)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e01..8b5765b7f8dba2a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3227,8 +3227,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   // result types are identical already.
   if (getPermutation().isIdentity() && getType() == getIn().getType())
     return getIn();
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   // Fold two consecutive memref.transpose Ops into one by composing their
   // permutation maps.
   if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034e4..61790bbc8a96ed6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1023,3 +1023,16 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
   // CHECK: return %[[arg0]]
   return %1 : memref<1x2x3x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @cannot_fold_transpose_cast(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?x4xf32, strided<[?, ?], offset: ?>>
+func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+    // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    %cast = memref.cast %arg0 : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK: return %[[TRANSPOSE]]
+    return %transpose : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}

Collaborator

@c-rhodes c-rhodes left a comment

Thanks for the fix, Ben. Makes sense to me; I just left one minor nit/question, otherwise LGTM. Cheers.

Member

@ubfx ubfx left a comment

I agree it makes sense to remove this fold. I just wonder what the original reason was for having it... Maybe it used to make sense before we had a verifier.

@MacDue
Member Author

MacDue commented Jan 30, 2024

> I just wonder what the original reason was for having it... Maybe it used to make sense before we had a verifier.

I'm not sure... but given there are no tests covering it, I think it's safe to remove.

@MacDue MacDue merged commit fdf73e9 into llvm:main Jan 30, 2024
@MacDue MacDue deleted the rm_fold branch January 30, 2024 15:58