Skip to content

Commit 5158b17

Browse files
committed
[mlir][Bufferization] castOrReallocMemRefValue: Use BufferizationOptions
1 parent f82d018 commit 5158b17

File tree

7 files changed

+31
-24
lines changed

7 files changed

+31
-24
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue,
5353
/// This function returns `failure()` in case of unsupported casts. E.g., casts
5454
/// with differing element types or memory spaces.
5555
FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
56-
MemRefType type);
56+
MemRefType type,
57+
const BufferizationOptions &options);
5758

5859
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
5960
/// to_memref op are different, a memref.cast is needed.
6061
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
61-
ToMemrefOp toMemref);
62+
ToMemrefOp toMemref,
63+
const BufferizationOptions &options);
6264

6365
/// Add the canonicalization patterns for bufferization.dealloc to the given
6466
/// pattern set to make them available to other passes (such as

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
2323
// Helper functions
2424
//===----------------------------------------------------------------------===//
2525

26-
FailureOr<Value>
27-
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
28-
MemRefType destType) {
26+
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
27+
OpBuilder &b, Value value, MemRefType destType,
28+
const BufferizationOptions &options) {
2929
auto srcType = llvm::cast<MemRefType>(value.getType());
3030

3131
// Element type, rank and memory space must match.
@@ -73,18 +73,23 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
7373
Value size = b.create<memref::DimOp>(loc, value, i);
7474
dynamicOperands.push_back(size);
7575
}
76-
// TODO: Use alloc/memcpy callback from BufferizationOptions if called via
77-
// BufferizableOpInterface impl of ToMemrefOp.
78-
Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
79-
b.create<memref::CopyOp>(loc, value, copy);
76+
77+
FailureOr<Value> copy =
78+
options.createAlloc(b, loc, destType, dynamicOperands);
79+
if (failed(copy)) {
80+
return failure();
81+
}
82+
if (failed(options.createMemCpy(b, loc, value, *copy))) {
83+
return failure();
84+
}
8085
return copy;
8186
}
8287

8388
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
8489
/// to_memref op are different, a memref.cast is needed.
85-
LogicalResult
86-
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
87-
ToMemrefOp toMemref) {
90+
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
91+
RewriterBase &rewriter, ToMemrefOp toMemref,
92+
const BufferizationOptions &options) {
8893
auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
8994
if (!memrefToTensor)
9095
return failure();
@@ -105,7 +110,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
105110
// Ranked memref -> Ranked memref cast.
106111
if (rankedSrcType && rankedDestType) {
107112
FailureOr<Value> replacement = castOrReallocMemRefValue(
108-
rewriter, memrefToTensor.getMemref(), rankedDestType);
113+
rewriter, memrefToTensor.getMemref(), rankedDestType, options);
109114
if (failed(replacement))
110115
return failure();
111116

@@ -792,7 +797,7 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
792797

793798
LogicalResult matchAndRewrite(ToMemrefOp toMemref,
794799
PatternRewriter &rewriter) const final {
795-
return foldToMemrefToTensorPair(rewriter, toMemref);
800+
return foldToMemrefToTensorPair(rewriter, toMemref, {});
796801
}
797802
};
798803

@@ -840,7 +845,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
840845
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
841846
const BufferizationOptions &options) {
842847
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
843-
(void)foldToMemrefToTensorPair(rewriter, *this);
848+
(void)foldToMemrefToTensorPair(rewriter, *this, options);
844849
// Note: The return value of `bufferize` indicates whether there was an error
845850
// or not. (And not whether the pattern matched or not.)
846851
return success();

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
7575
if (!rankedDestType)
7676
return nullptr;
7777
FailureOr<Value> replacement =
78-
castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
78+
castOrReallocMemRefValue(builder, inputs[0], rankedDestType, {});
7979
if (failed(replacement))
8080
return nullptr;
8181
return *replacement;
@@ -512,8 +512,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
512512
// Fold all to_memref(to_tensor(x)) pairs.
513513
for (Operation *op : toMemrefOps) {
514514
rewriter.setInsertionPoint(op);
515-
(void)bufferization::foldToMemrefToTensorPair(rewriter,
516-
cast<ToMemrefOp>(op));
515+
(void)bufferization::foldToMemrefToTensorPair(
516+
rewriter, cast<ToMemrefOp>(op), options);
517517
}
518518

519519
// Remove all dead to_tensor ops.

mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
3333
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[1], offset: ?>>)
3434
// CHECK: %[[c0:.*]] = arith.constant 0 : index
3535
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
36-
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
36+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
3737
// CHECK: memref.copy %[[arg]], %[[alloc]]
3838
// CHECK: return %[[alloc]]
3939
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
@@ -48,7 +48,7 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
4848
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[100], offset: ?>>)
4949
// CHECK: %[[c0:.*]] = arith.constant 0 : index
5050
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
51-
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
51+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
5252
// CHECK: memref.copy %[[arg]], %[[alloc]]
5353
// CHECK: return %[[alloc]]
5454
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
@@ -63,7 +63,7 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
6363
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[1], offset: 25>>)
6464
// CHECK: %[[c0:.*]] = arith.constant 0 : index
6565
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
66-
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
66+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
6767
// CHECK: memref.copy %[[arg]], %[[alloc]]
6868
// CHECK: return %[[alloc]]
6969
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
8484
// Note: This alloc is not needed, but it is inserted before the returned buffer
8585
// is promoted to an out param to reconcile mismatching layout maps on return
8686
// value and function signature.
87-
// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32>
87+
// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<2x5xf32>
8888
// CHECK-NO-LAYOUT: memref.copy %[[subview]], %[[alloc2]]
8989
// CHECK-NO-LAYOUT: memref.copy %[[alloc2]], %[[r]]
9090

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
5252
// CHECK-NO-LAYOUT-MAP-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32>
5353
// CHECK-NO-LAYOUT-MAP: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<20x10xf32>
5454
// CHECK-NO-LAYOUT-MAP: %[[subview:.*]] = memref.subview {{.*}} : memref<20x10xf32> to memref<2x?xf32, strided<[10, 1], offset: ?>>
55-
// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) : memref<2x?xf32>
55+
// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<2x?xf32>
5656
// CHECK-NO-LAYOUT-MAP: memref.copy %[[subview]], %[[alloc_no_layout]]
5757
// TODO: %alloc should be deallocated here, but we currently do not dealloc
5858
// buffers that are inserted due to to_tensor/to_memref canonicalization (when

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
8484
// CHECK-NOT: bufferization.to_memref
8585
// CHECK: %[[C0:.*]] = arith.constant 0 : index
8686
// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
87-
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
87+
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref<?xf32, strided<[1], offset: 3>>
8888
// CHECK: memref.copy %[[M]], %[[ALLOC]]
8989
// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
9090
// CHECK: return %[[ALLOC]]

0 commit comments

Comments (0)