Skip to content

Commit b20e150

Browse files
committed
[mlir] Use static shape knowledge when lowering memref.reshape
This is actually necessary for correctness, as memref.reinterpret_cast doesn't verify if the output shape doesn't match the static sizes. Differential Revision: https://reviews.llvm.org/D102232
1 parent ec28e43 commit b20e150

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,19 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
9191
Location loc = op.getLoc();
9292
Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
9393
for (int i = rank - 1; i >= 0; --i) {
94-
Value index = rewriter.create<ConstantIndexOp>(loc, i);
95-
Value size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
96-
if (!size.getType().isa<IndexType>())
97-
size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
98-
sizes[i] = size;
94+
Value size;
95+
// Load dynamic sizes from the shape input, use constants for static dims.
96+
if (op.getType().isDynamicDim(i)) {
97+
Value index = rewriter.create<ConstantIndexOp>(loc, i);
98+
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
99+
if (!size.getType().isa<IndexType>())
100+
size =
101+
rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
102+
sizes[i] = size;
103+
} else {
104+
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
105+
size = rewriter.create<ConstantOp>(loc, sizes[i].get<Attribute>());
106+
}
99107
strides[i] = stride;
100108
if (i > 0)
101109
stride = rewriter.create<MulIOp>(loc, stride, size);

mlir/test/Dialect/Standard/expand-ops.mlir

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,17 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
8484

8585
// CHECK-LABEL: func @memref_reshape(
8686
func @memref_reshape(%input: memref<*xf32>,
87-
%shape: memref<3xi32>) -> memref<?x?x?xf32> {
87+
%shape: memref<3xi32>) -> memref<?x?x8xf32> {
8888
%result = memref.reshape %input(%shape)
89-
: (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
90-
return %result : memref<?x?x?xf32>
89+
: (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
90+
return %result : memref<?x?x8xf32>
9191
}
9292
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
93-
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
93+
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
9494

9595
// CHECK: [[C1:%.*]] = constant 1 : index
96-
// CHECK: [[C2:%.*]] = constant 2 : index
97-
// CHECK: [[DIM_2:%.*]] = memref.load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
98-
// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
99-
// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[SIZE_2]] : index
96+
// CHECK: [[C8:%.*]] = constant 8 : index
97+
// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[C8]] : index
10098

10199
// CHECK: [[C1_:%.*]] = constant 1 : index
102100
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
@@ -108,6 +106,6 @@ func @memref_reshape(%input: memref<*xf32>,
108106
// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
109107

110108
// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
111-
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
109+
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
112110
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
113-
// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
111+
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>

0 commit comments

Comments
 (0)