Skip to content

Commit a89e55c

Browse files
author
Stephan Herhut
committed
[mlir][std] Canonicalize a dim(memref_reshape) into a load from the shape operand
This canonicalization helps propagate shape information through the program. Differential Revision: https://reviews.llvm.org/D91854
1 parent dfd2858 commit a89e55c

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,6 +1753,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
17531753
Optional<int64_t> getConstantIndex();
17541754
}];
17551755

1756+
let hasCanonicalizer = 1;
17561757
let hasFolder = 1;
17571758
}
17581759

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,34 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
15551555
return {};
15561556
}
15571557

1558+
namespace {
1559+
/// Fold dim of a memref reshape operation to a load into the reshape's shape
1560+
/// operand.
1561+
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1562+
using OpRewritePattern<DimOp>::OpRewritePattern;
1563+
1564+
LogicalResult matchAndRewrite(DimOp dim,
1565+
PatternRewriter &rewriter) const override {
1566+
auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
1567+
1568+
if (!reshape)
1569+
return failure();
1570+
1571+
// Place the load directly after the reshape to ensure that the shape memref
1572+
// was not mutated.
1573+
rewriter.setInsertionPointAfter(reshape);
1574+
rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
1575+
llvm::makeArrayRef({dim.index()}));
1576+
return success();
1577+
}
1578+
};
1579+
} // end anonymous namespace.
1580+
1581+
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1582+
MLIRContext *context) {
1583+
results.insert<DimOfMemRefReshape>(context);
1584+
}
1585+
15581586
// ---------------------------------------------------------------------------
15591587
// DmaStartOp
15601588
// ---------------------------------------------------------------------------

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,23 @@ func @cmpi_equal_operands(%arg0: i64)
9595
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
9696
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
9797
}
98+
99+
// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
100+
// CHECK-LABEL: func @dim_of_memref_reshape(
101+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
102+
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
103+
// CHECK-NEXT: %[[IDX:.*]] = constant 3
104+
// CHECK-NEXT: %[[DIM:.*]] = load %[[SHP]][%[[IDX]]]
105+
// CHECK-NEXT: store
106+
// CHECK-NOT: dim
107+
// CHECK: return %[[DIM]] : index
108+
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
109+
-> index {
110+
%c3 = constant 3 : index
111+
%0 = memref_reshape %arg0(%arg1)
112+
: (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
113+
// Update the shape to test that he load ends up in the right place.
114+
store %c3, %arg1[%c3] : memref<?xindex>
115+
%1 = dim %0, %c3 : memref<*xf32>
116+
return %1 : index
117+
}

0 commit comments

Comments
 (0)