Skip to content

Commit 8dca38d

Browse files
[mlir][bufferize] Support layout maps in bufferization.clone lowering
Differential Revision: https://reviews.llvm.org/D121278
1 parent 1eeb2bf commit 8dca38d

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,18 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
3939
return rewriter.notifyMatchFailure(
4040
op, "UnrankedMemRefType is not supported.");
4141
}
42+
MemRefType memrefType = type.cast<MemRefType>();
43+
MemRefLayoutAttrInterface layout;
44+
auto allocType =
45+
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
46+
layout, memrefType.getMemorySpace());
47+
// Since this implementation always allocates, certain result types of the
48+
// clone op cannot be lowered.
49+
if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
50+
return failure();
4251

4352
// Transform a clone operation into alloc + copy operation and pay
4453
// attention to the shape dimensions.
45-
MemRefType memrefType = type.cast<MemRefType>();
4654
Location loc = op->getLoc();
4755
SmallVector<Value, 4> dynamicOperands;
4856
for (int i = 0; i < memrefType.getRank(); ++i) {
@@ -52,8 +60,14 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
5260
Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.input(), size);
5361
dynamicOperands.push_back(dim);
5462
}
55-
Value alloc = rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
56-
dynamicOperands);
63+
64+
// Allocate a memref with identity layout.
65+
Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
66+
dynamicOperands);
67+
// Cast the allocation to the specified type if needed.
68+
if (memrefType != allocType)
69+
alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
70+
rewriter.replaceOp(op, alloc);
5771
rewriter.create<memref::CopyOp>(loc, op.input(), alloc);
5872
return success();
5973
}

mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
// CHECK-LABEL: @conversion_static
44
func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
5-
%0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
6-
memref.dealloc %arg0 : memref<2xf32>
7-
return %0 : memref<2xf32>
5+
%0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32>
6+
memref.dealloc %arg0 : memref<2xf32>
7+
return %0 : memref<2xf32>
88
}
99

1010
// CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -16,9 +16,9 @@ func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> {
1616

1717
// CHECK-LABEL: @conversion_dynamic
1818
func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
19-
%1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
20-
memref.dealloc %arg0 : memref<?xf32>
21-
return %1 : memref<?xf32>
19+
%1 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
20+
memref.dealloc %arg0 : memref<?xf32>
21+
return %1 : memref<?xf32>
2222
}
2323

2424
// CHECK: %[[CONST:.*]] = arith.constant
@@ -32,7 +32,40 @@ func @conversion_dynamic(%arg0 : memref<?xf32>) -> memref<?xf32> {
3232

3333
func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> {
3434
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
35-
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
36-
memref.dealloc %arg0 : memref<*xf32>
37-
return %1 : memref<*xf32>
35+
%1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32>
36+
memref.dealloc %arg0 : memref<*xf32>
37+
return %1 : memref<*xf32>
38+
}
39+
40+
// -----
41+
42+
// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
43+
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
44+
// CHECK-LABEL: func @conversion_with_layout_map(
45+
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32, #[[$MAP]]>
46+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
47+
// CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]]
48+
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32>
49+
// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #[[$MAP]]>
50+
// CHECK: memref.copy
51+
// CHECK: memref.dealloc
52+
// CHECK: return %[[CASTED]]
53+
func @conversion_with_layout_map(%arg0 : memref<?xf32, #map>) -> memref<?xf32, #map> {
54+
%1 = bufferization.clone %arg0 : memref<?xf32, #map> to memref<?xf32, #map>
55+
memref.dealloc %arg0 : memref<?xf32, #map>
56+
return %1 : memref<?xf32, #map>
57+
}
58+
59+
// -----
60+
61+
// This bufferization.clone cannot be lowered because a buffer with this layout
62+
// map cannot be allocated (or casted to).
63+
64+
#map2 = affine_map<(d0)[s0] -> (d0 * 10 + s0)>
65+
func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, #map2>)
66+
-> memref<?xf32, #map2> {
67+
// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}}
68+
%1 = bufferization.clone %arg0 : memref<?xf32, #map2> to memref<?xf32, #map2>
69+
memref.dealloc %arg0 : memref<?xf32, #map2>
70+
return %1 : memref<?xf32, #map2>
3871
}

0 commit comments

Comments
 (0)