Skip to content

Commit fc9b37d

Browse files
[mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))
This is a partial revert of D128615. to_memref(to_tensor(x)) can always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case because writes to the intermediary memref may go unnoticed. Differential Revision: https://reviews.llvm.org/D129354
1 parent e1272ab commit fc9b37d

File tree

3 files changed

+5
-17
lines changed

3 files changed

+5
-17
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
539539
}
540540

541541
namespace {
542-
/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
543-
struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
544-
using OpRewritePattern<ToTensorOp>::OpRewritePattern;
545-
546-
LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
547-
PatternRewriter &rewriter) const final {
548-
auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
549-
if (!toMemrefOp)
550-
return failure();
551-
rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
552-
return success();
553-
}
554-
};
555-
556542
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
557543
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
558544

@@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
571557

572558
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
573559
MLIRContext *context) {
574-
results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
560+
results.add<DimOfToTensorFolder>(context);
575561
}
576562

577563
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
787787
}
788788

789789
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
790-
// CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
790+
// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
791+
// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
791792
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
792793
}
793794

mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@
109109
// CHECK: scf.yield %[[VAL_84]] : f64
110110
// CHECK: }
111111
// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
112-
// CHECK: return %[[VAL_0]] : tensor<f64>
112+
// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
113+
// CHECK: return %[[VAL_87]] : tensor<f64>
113114
// CHECK: }
114115
func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
115116
%arga: tensor<64x32xf64, #SparseMatrix>,

0 commit comments

Comments
 (0)