Skip to content

Commit 28d6aa9

Browse files
authored
[mlir][bufferization] Unranked memref support for clone (#94757)
bufferization.clone does not currently support lowering to memref for unranked memrefs. This interferes with bufferizing unranked tensors at boundaries where a clone operation is needed.

```
func.func @foo(%input: memref<*xf32>, %shape: memref<?xindex>) -> memref<*xf32> {
  %reshape = memref.reshape %input(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %copy = bufferization.clone %reshape : memref<*xf32> to memref<*xf32>
  return %copy : memref<*xf32>
}
```

Patterns such as this are possible when bufferizing functions with input and output unranked tensors. The clone operation currently fails to legalize during the bufferization-to-memref conversion with unranked memrefs. This change modifies the conversion of bufferization.clone to memref to generate the runtime calculations and allocation to allow for cloning an unranked memref.
1 parent 0aeaa2d commit 28d6aa9

File tree

2 files changed

+82
-32
lines changed

2 files changed

+82
-32
lines changed

mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,76 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
4242
LogicalResult
4343
matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
4444
ConversionPatternRewriter &rewriter) const override {
45-
// Check for unranked memref types which are currently not supported.
46-
Type type = op.getType();
47-
if (isa<UnrankedMemRefType>(type)) {
48-
return rewriter.notifyMatchFailure(
49-
op, "UnrankedMemRefType is not supported.");
50-
}
51-
MemRefType memrefType = cast<MemRefType>(type);
52-
MemRefLayoutAttrInterface layout;
53-
auto allocType =
54-
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
55-
layout, memrefType.getMemorySpace());
56-
// Since this implementation always allocates, certain result types of the
57-
// clone op cannot be lowered.
58-
if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
59-
return failure();
60-
61-
// Transform a clone operation into alloc + copy operation and pay
62-
// attention to the shape dimensions.
6345
Location loc = op->getLoc();
64-
SmallVector<Value, 4> dynamicOperands;
65-
for (int i = 0; i < memrefType.getRank(); ++i) {
66-
if (!memrefType.isDynamicDim(i))
67-
continue;
68-
Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
69-
dynamicOperands.push_back(dim);
46+
47+
Type type = op.getType();
48+
Value alloc;
49+
50+
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
51+
// Constants
52+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
53+
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
54+
55+
// Dynamically evaluate the size and shape of the unranked memref
56+
Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
57+
MemRefType allocType =
58+
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
59+
Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
60+
61+
// Create a loop to query dimension sizes, store them as a shape, and
62+
// compute the total size of the memref
63+
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
64+
ValueRange args) {
65+
auto acc = args.front();
66+
auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
67+
68+
rewriter.create<memref::StoreOp>(loc, dim, shape, i);
69+
acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
70+
71+
rewriter.create<scf::YieldOp>(loc, acc);
72+
};
73+
auto size = rewriter
74+
.create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
75+
loopBody)
76+
.getResult(0);
77+
78+
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
79+
unrankedType.getElementType());
80+
81+
// Allocate new memref with 1D dynamic shape, then reshape into the
82+
// shape of the original unranked memref
83+
alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
84+
alloc =
85+
rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
86+
} else {
87+
MemRefType memrefType = cast<MemRefType>(type);
88+
MemRefLayoutAttrInterface layout;
89+
auto allocType =
90+
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
91+
layout, memrefType.getMemorySpace());
92+
// Since this implementation always allocates, certain result types of
93+
// the clone op cannot be lowered.
94+
if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
95+
return failure();
96+
97+
// Transform a clone operation into alloc + copy operation and pay
98+
// attention to the shape dimensions.
99+
SmallVector<Value, 4> dynamicOperands;
100+
for (int i = 0; i < memrefType.getRank(); ++i) {
101+
if (!memrefType.isDynamicDim(i))
102+
continue;
103+
Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
104+
dynamicOperands.push_back(dim);
105+
}
106+
107+
// Allocate a memref with identity layout.
108+
alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
109+
// Cast the allocation to the specified type if needed.
110+
if (memrefType != allocType)
111+
alloc =
112+
rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
70113
}
71114

72-
// Allocate a memref with identity layout.
73-
Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
74-
dynamicOperands);
75-
// Cast the allocation to the specified type if needed.
76-
if (memrefType != allocType)
77-
alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
78115
rewriter.replaceOp(op, alloc);
79116
rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
80117
return success();

mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,34 @@ func.func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
2222
}
2323

2424
// CHECK: %[[CONST:.*]] = arith.constant
25-
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
25+
// CHECK: %[[DIM:.*]] = memref.dim %[[ARG:.*]], %[[CONST]]
2626
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
2727
// CHECK-NEXT: memref.copy %[[ARG]], %[[ALLOC]]
2828
// CHECK-NEXT: memref.dealloc %[[ARG]]
2929
// CHECK-NEXT: return %[[ALLOC]]
3030

3131
// -----
3232

33+
// CHECK-LABEL: @conversion_unknown
3334
func.func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
34-
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
3535
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
3636
memref.dealloc %arg0 : memref<*xf32>
3737
return %1 : memref<*xf32>
3838
}
3939

40+
// CHECK: %[[RANK:.*]] = memref.rank %[[ARG:.*]]
41+
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
42+
// CHECK-NEXT: %[[FOR:.*]] = scf.for
43+
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG:.*]] %[[ARG:.*]]
44+
// CHECK-NEXT: memref.store %[[DIM:.*]], %[[ALLOCA:.*]][%[[ARG:.*]]]
45+
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG:.*]], %[[DIM:.*]]
46+
// CHECK-NEXT: scf.yield %[[MUL:.*]]
47+
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[FOR:.*]])
48+
// CHECK-NEXT: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC:.*]]
49+
// CHECK-NEXT: memref.copy
50+
// CHECK-NEXT: memref.dealloc
51+
// CHECK-NEXT: return %[[RESHAPE:.*]]
52+
4053
// -----
4154

4255
// CHECK-LABEL: func @conversion_with_layout_map(

0 commit comments

Comments
 (0)