Skip to content

Commit 1e25109

Browse files
committed
Canonicalize static alloc followed by memref_cast and std.view
Summary: Rewrite alloc, memref_cast, std.view into allo, std.view by droping memref_cast. Reviewers: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72379
1 parent 31992a6 commit 1e25109

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

mlir/lib/Dialect/StandardOps/Ops.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2527,11 +2527,31 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
25272527
}
25282528
};
25292529

2530+
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2531+
using OpRewritePattern<ViewOp>::OpRewritePattern;
2532+
2533+
PatternMatchResult matchAndRewrite(ViewOp viewOp,
2534+
PatternRewriter &rewriter) const override {
2535+
Value memrefOperand = viewOp.getOperand(0);
2536+
MemRefCastOp memrefCastOp =
2537+
dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
2538+
if (!memrefCastOp)
2539+
return matchFailure();
2540+
Value allocOperand = memrefCastOp.getOperand();
2541+
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
2542+
if (!allocOp)
2543+
return matchFailure();
2544+
rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
2545+
allocOperand, viewOp.operands());
2546+
return matchSuccess();
2547+
}
2548+
};
2549+
25302550
} // end anonymous namespace
25312551

25322552
void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
25332553
MLIRContext *context) {
2534-
results.insert<ViewOpShapeFolder>(context);
2554+
results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
25352555
}
25362556

25372557
//===----------------------------------------------------------------------===//

mlir/test/Transforms/canonicalize.mlir

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,7 @@ func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>,
695695

696696
// CHECK-LABEL: func @view
697697
func @view(%arg0 : index) {
698+
// CHECK: %[[ALLOC_MEM:.*]] = alloc() : memref<2048xi8>
698699
%0 = alloc() : memref<2048xi8>
699700
%c0 = constant 0 : index
700701
%c7 = constant 7 : index
@@ -730,11 +731,15 @@ func @view(%arg0 : index) {
730731

731732
// Test: preserve an existing static dim size while folding a dynamic
732733
// dimension and offset.
733-
// CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
734-
%5 = view %0[%c15][%c7]
735-
: memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
734+
// CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
735+
%5 = view %0[%c15][%c7] : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
736736
load %5[%c0, %c0] : memref<?x4xf32, #TEST_VIEW_MAP2>
737737

738+
// Test: folding static alloc and memref_cast into a view.
739+
// CHECK: std.view %0[][%c15, %c7] : memref<2048xi8> to memref<?x?xf32>
740+
%6 = memref_cast %0 : memref<2048xi8> to memref<?xi8>
741+
%7 = view %6[%c15][%c7] : memref<?xi8> to memref<?x?xf32>
742+
load %7[%c0, %c0] : memref<?x?xf32>
738743
return
739744
}
740745

0 commit comments

Comments
 (0)