Skip to content

Commit fdf73e9

Browse files
authored
[mlir][memref] Remove incorrect memref.transpose fold (#79809)
This folded casts into `memref.transpose` without updating the result type of the transpose op, which resulted in IR that failed to verify for statically sized memrefs.

i.e.

```mlir
%cast = memref.cast %0 : memref<?x4xf32> to memref<?x?xf32>
%transpose = memref.transpose %cast : memref<?x?xf32> to memref<?x?xf32>
```

would fold to:

```mlir
// Fails verification:
%transpose = memref.transpose %cast : memref<?x4xf32> to memref<?x?xf32>
```
1 parent fa10121 commit fdf73e9

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3227,8 +3227,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
32273227
// result types are identical already.
32283228
if (getPermutation().isIdentity() && getType() == getIn().getType())
32293229
return getIn();
3230-
if (succeeded(foldMemRefCast(*this)))
3231-
return getResult();
32323230
// Fold two consecutive memref.transpose Ops into one by composing their
32333231
// permutation maps.
32343232
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,18 @@ func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3
10231023
// CHECK: return %[[arg0]]
10241024
return %1 : memref<1x2x3x4x5xf32>
10251025
}
1026+
1027+
// -----
1028+
1029+
#transpose_map = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
1030+
1031+
// CHECK-LABEL: func @cannot_fold_transpose_cast(
1032+
// CHECK-SAME: %[[arg0:.*]]: memref<?x4xf32>
1033+
func.func @cannot_fold_transpose_cast(%arg0: memref<?x4xf32>) -> memref<?x?xf32, #transpose_map> {
1034+
// CHECK: %[[CAST:.*]] = memref.cast %[[arg0]] : memref<?x4xf32> to memref<?x?xf32>
1035+
%cast = memref.cast %arg0 : memref<?x4xf32> to memref<?x?xf32>
1036+
// CHECK: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #{{.*}}>
1037+
%transpose = memref.transpose %cast (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32, #transpose_map>
1038+
// CHECK: return %[[TRANSPOSE]]
1039+
return %transpose : memref<?x?xf32, #transpose_map>
1040+
}

0 commit comments

Comments
 (0)