[mlir][memref] Remove incorrect memref.transpose fold #79809
This folded casts into `memref.transpose` without updating the result type of the transpose op, which resulted in IR that failed to verify for statically sized memrefs.

For example:

```mlir
%cast = memref.cast %0 : memref<?x4xf32> to memref<?x?xf32>
%transpose = memref.transpose %cast : memref<?x?xf32> to memref<?x?xf32>
```

would fold to:

```mlir
// Fails verification:
%transpose = memref.transpose %cast : memref<?x4xf32> to memref<?x?xf32>
```
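To make the failure mode concrete, here is a minimal toy model in plain C++ (not MLIR's actual verifier). It assumes the verifier compares the declared result shape against the shape inferred by permuting the input shape. The names `Shape`, `kDynamic`, `inferTransposedShape`, and `verifyTranspose` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for a memref shape; kDynamic plays the role of `?`.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Result dimension i is taken from input dimension perm[i].
Shape inferTransposedShape(const Shape &in, const std::vector<unsigned> &perm) {
  assert(in.size() == perm.size() && "permutation rank must match shape rank");
  Shape out;
  out.reserve(perm.size());
  for (unsigned srcDim : perm)
    out.push_back(in[srcDim]);
  return out;
}

// Hypothetical verifier: the declared result shape must match the shape
// inferred from the input exactly (a static 4 does not match a dynamic `?`).
bool verifyTranspose(const Shape &in, const std::vector<unsigned> &perm,
                     const Shape &declaredResult) {
  return inferTransposedShape(in, perm) == declaredResult;
}
```

In this model, transposing `?x?` to `?x?` with the permutation `(d0, d1) -> (d1, d0)` verifies. Swapping the operand for the cast's source (`?x4`) while keeping the declared `?x?` result does not, because the inferred result becomes `4x?`. That is the mismatch the removed fold introduced.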
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

Author: Benjamin Maxwell (MacDue)

Changes

This folded casts into `memref.transpose` without updating the result type of the transpose op, which resulted in IR that failed to verify for statically sized memrefs.

Full diff: https://github.com/llvm/llvm-project/pull/79809.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e01..8b5765b7f8dba2a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3227,8 +3227,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   // result types are identical already.
   if (getPermutation().isIdentity() && getType() == getIn().getType())
     return getIn();
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   // Fold two consecutive memref.transpose Ops into one by composing their
   // permutation maps.
   if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index eccfc485b2034e4..61790bbc8a96ed6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1023,3 +1023,16 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
   // CHECK: return %[[arg0]]
   return %1 : memref<1x2x3x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @cannot_fold_transpose_cast(
+// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32, strided<[?, ?], offset: ?>>
+func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ // CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %cast = memref.cast %arg0 : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ %transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: return %[[TRANSPOSE]]
+ return %transpose : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
Thanks for the fix, Ben. Makes sense to me; I just left one minor nit/question, otherwise LGTM. Cheers.
I agree it makes sense to remove this fold. I just wonder what the original reason was for having it... Maybe it used to make sense before we had a verifier.
I'm not sure... but given there are no tests covering it, I think it's safe to remove.