
Commit ced2fc7

[mlir][bufferization] Fix OneShotBufferize when defaultMemorySpaceFn is used (#91524)
As described in issue #91518, a previous PR #78484 introduced the `defaultMemorySpaceFn` into bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change:

- Updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be explicitly included. This is a small assembly format change, but it touches a large number of test files.
- Makes minor updates to other bufferization functions to handle the changes in building the above ops.
- Updates bufferization of `tensor.from_elements` to handle memory space.

Integration/upgrade guide: in downstream projects, if you have tests or MLIR files that explicitly use `bufferization.to_tensor` or `bufferization.to_memref`, update them to the new assembly format as follows:

```
%1 = bufferization.to_memref %0 : memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32>
```

becomes

```
%1 = bufferization.to_memref %0 : tensor<10xf32> to memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32> to tensor<10xf32>
```
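
As a minimal sketch of why the explicit types matter: once a tensor type carries an encoding that a `defaultMemorySpaceFn` maps to a memory space, that information cannot be inferred from the memref alone, so both sides of the cast are now spelled out. The `1 : i64` encoding below is a hypothetical stand-in for whatever attribute a downstream function interprets as a memory space; it is not taken from this patch.

```mlir
// Hypothetical types: the tensor encoding (1 : i64) is assumed to map to
// memref memory space 1 via a user-provided defaultMemorySpaceFn.
%m = bufferization.to_memref %t : tensor<10xf32, 1 : i64> to memref<10xf32, 1>
%u = bufferization.to_tensor %m : memref<10xf32, 1> to tensor<10xf32, 1 : i64>
```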
1 parent 88cff86 commit ced2fc7

73 files changed: +608 −373 lines changed


mlir/docs/Bufferization.md

Lines changed: 2 additions & 2 deletions
@@ -223,8 +223,8 @@ func.func @test_matmul(%A: memref<1x17x19xf32>,
                        %B: memref<1x19x29xf32>,
                        %C: memref<1x17x29xf32>) {
 
-  %A_tensor = bufferization.to_tensor %A restrict : memref<1x17x19xf32>
-  %B_tensor = bufferization.to_tensor %B restrict : memref<1x19x29xf32>
+  %A_tensor = bufferization.to_tensor %A restrict : memref<1x17x19xf32> to tensor<1x17x19xf32>
+  %B_tensor = bufferization.to_tensor %B restrict : memref<1x19x29xf32> to tensor<1x19x29xf32>
 
   %0 = tosa.matmul %A_tensor, %B_tensor
     : (tensor<1x17x19xf32>, tensor<1x19x29xf32>) ->

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 15 additions & 11 deletions
@@ -387,9 +387,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -404,7 +402,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     ```mlir
     // Produces a value of tensor<4x?xf32> type.
-    %t = bufferization.to_tensor %m : memref<4x?xf32, #layout, 0>
+    %t = bufferization.to_tensor %m : memref<4x?xf32, #layout, 0> to tensor<4x?xf32>
     ```
 
     If the `writable` unit attribute is set, the produced tensor is considered
@@ -427,7 +425,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     Example:
 
     ```
-    %t = bufferization.to_tensor %m restrict writable : memref<4xf32>
+    %t = bufferization.to_tensor %m restrict writable : memref<4xf32> to tensor<4xf32>
 
     // %t is writable, so the tensor.insert may bufferize in-place in the
     // absence of other conflicts.
@@ -476,9 +474,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `to` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -493,17 +498,16 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
-                   "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllShapesMatch<["memref", "tensor"]>,
+    AllElementTypesMatch<["memref", "tensor"]>
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
     ```mlir
     // Result type is memref<4x?xf32, #layout, 0>
-    %m = bufferization.to_memref %t : memref<4x?xf32, #layout, 0>
+    %m = bufferization.to_memref %t : tensor<4x?xf32> to memref<4x?xf32, #layout, 0>
     ```
 
     This operation is a specialized variant of the built-in
@@ -550,7 +554,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
   }];
 
   let hasFolder = 1;
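
Since the verifiers now only require matching shapes and element types (rather than an exact inferred tensor type), the memref side of these casts may carry a layout and memory space that the tensor side does not. An illustrative example with hypothetical types, not taken from the patch:

```mlir
// The memref operand may carry a layout and memory space; only the shape and
// element type must match the produced tensor.
%t = bufferization.to_tensor %m restrict writable
    : memref<4x?xf32, strided<[?, 1], offset: ?>, 1> to tensor<4x?xf32>
```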

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 6 additions & 2 deletions
@@ -250,8 +250,8 @@ def OptimizeAllocationLiveness
   let summary = "This pass optimizes the liveness of temp allocations in the "
                 "input function";
   let description =
-      [{This pass will find all operations that have a memory allocation effect.
-        It will search for the corresponding deallocation and move it right after
+      [{This pass will find all operations that have a memory allocation effect.
+        It will search for the corresponding deallocation and move it right after
         the last user of the allocation.
         This will optimize the liveness of the allocations.
 
@@ -510,6 +510,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+           /*default=*/"false",
+           "Use the Tensor encoding attribute for the memory space. Exclusive to"
+           " the 'must-infer-memory-space' option">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
            /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 1 deletion
@@ -718,7 +718,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 20 additions & 1 deletion
@@ -69,7 +69,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -147,12 +147,31 @@ struct OneShotBufferizePass
     opt.dumpAliasSets = dumpAliasSets;
     opt.setFunctionBoundaryTypeConversion(
         parseLayoutMapOption(functionBoundaryTypeConversion));
+
+    if (mustInferMemorySpace && useEncodingForMemorySpace) {
+      emitError(getOperation()->getLoc())
+          << "only one of 'must-infer-memory-space' and "
+             "'use-encoding-for-memory-space' are allowed in "
+          << getArgument();
+      return signalPassFailure();
+    }
+
     if (mustInferMemorySpace) {
       opt.defaultMemorySpaceFn =
           [](TensorType t) -> std::optional<Attribute> {
        return std::nullopt;
      };
     }
+
+    if (useEncodingForMemorySpace) {
+      opt.defaultMemorySpaceFn =
+          [](TensorType t) -> std::optional<Attribute> {
+        if (auto rtt = dyn_cast<RankedTensorType>(t))
+          return rtt.getEncoding();
+        return std::nullopt;
+      };
+    }
+
     opt.printConflicts = printConflicts;
     opt.bufferAlignment = bufferAlignment;
     opt.testAnalysisOnly = testAnalysisOnly;
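
The new check makes the two memory-space options mutually exclusive: a pipeline that enables both fails before any rewriting happens. A hypothetical lit-style check (not a test from this patch) might look like:

```mlir
// RUN: not mlir-opt %s \
// RUN:   -one-shot-bufferize="must-infer-memory-space use-encoding-for-memory-space" 2>&1 \
// RUN:   | FileCheck %s
// CHECK: only one of 'must-infer-memory-space' and 'use-encoding-for-memory-space'
```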

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 17 additions & 11 deletions
@@ -203,7 +203,8 @@ struct ExecuteRegionOpInterface
     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
       if (isa<TensorType>(it.value())) {
         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
-            executeRegionOp.getLoc(), newOp->getResult(it.index())));
+            executeRegionOp.getLoc(), it.value(),
+            newOp->getResult(it.index())));
       } else {
         newResults.push_back(newOp->getResult(it.index()));
       }
@@ -485,15 +486,17 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
 /// ToTensorOps, so that the block body can be moved over to the new op.
 static SmallVector<Value>
 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
+                     Block::BlockArgListType oldBbArgs,
                      const DenseSet<int64_t> &tensorIndices) {
   SmallVector<Value> result;
   for (const auto &it : llvm::enumerate(bbArgs)) {
     size_t idx = it.index();
     Value val = it.value();
     if (tensorIndices.contains(idx)) {
-      result.push_back(
-          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
-              .getResult());
+      result.push_back(rewriter
+                           .create<bufferization::ToTensorOp>(
+                               val.getLoc(), oldBbArgs[idx].getType(), val)
+                           .getResult());
     } else {
       result.push_back(val);
     }
@@ -763,7 +766,8 @@ struct ForOpInterface
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
     SmallVector<Value> iterArgs =
-        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
+        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
+                             forOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
     // Move loop body to new loop.
@@ -1000,16 +1004,18 @@ struct WhileOpInterface
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
     rewriter.setInsertionPointToStart(newBeforeBody);
-    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
-        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
+    SmallVector<Value> newBeforeArgs =
+        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
+                             whileOp.getBeforeArguments(), indicesBefore);
     rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
     rewriter.setInsertionPointToStart(newAfterBody);
-    SmallVector<Value> newAfterArgs = getBbArgReplacements(
-        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
+    SmallVector<Value> newAfterArgs =
+        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
+                             whileOp.getAfterArguments(), indicesAfter);
     rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
 
     // Replace loop results.
@@ -1255,8 +1261,8 @@ struct ForallOpInterface
              forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
       BlockArgument bbArg = std::get<0>(it);
       Value buffer = std::get<1>(it);
-      Value bufferAsTensor =
-          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
+      Value bufferAsTensor = rewriter.create<ToTensorOp>(
+          forallOp.getLoc(), bbArg.getType(), buffer);
       bbArg.replaceAllUsesWith(bufferAsTensor);
     }
 

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 7 deletions
@@ -480,10 +480,6 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
-
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
     auto shape = tensorType.getShape();
@@ -493,10 +489,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
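
With this change, a `tensor.from_elements` whose result type carries an encoding can pick up a memory space during bufferization instead of being rejected. A hypothetical example, with `1 : i64` again standing in for a memory-space-bearing encoding (the exact bufferized output is not reproduced from the patch):

```mlir
func.func @from_elements(%f: f32) -> tensor<3xf32, 1 : i64> {
  // Expected to allocate in memory space 1 when the pass maps the encoding
  // to a memory space, e.g. memref.alloc() : memref<3xf32, 1>.
  %0 = tensor.from_elements %f, %f, %f : tensor<3xf32, 1 : i64>
  return %0 : tensor<3xf32, 1 : i64>
}
```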

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 2 additions & 2 deletions
@@ -242,7 +242,7 @@ module {
     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
       tensor.yield %cst_f32 : f32
     } : tensor<1x32x32x8xf32> to tensor<1x40x8229x8xf32>
-    %1 = bufferization.to_memref %padded : memref<1x40x8229x8xf32>
+    %1 = bufferization.to_memref %padded : tensor<1x40x8229x8xf32> to memref<1x40x8229x8xf32>
     %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x8xf32>
     affine.for %arg1 = 0 to 1 {
       affine.for %arg2 = 0 to 32 {
@@ -280,7 +280,7 @@ module {
     // SPIRV-NOT: affine.for %{{.*}}
 
     // SPIRV: ReturnValue
-    %2 = bufferization.to_tensor %alloc_1 : memref<1x32x32x8xf32>
+    %2 = bufferization.to_tensor %alloc_1 : memref<1x32x32x8xf32> to tensor<1x32x32x8xf32>
     %3 = builtin.unrealized_conversion_cast %2 : tensor<1x32x32x8xf32> to !spirv.array<8192 x f32>
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }

mlir/test/Dialect/Arith/bufferize.mlir

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME: memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -83,8 +83,8 @@ func.func @non_tensor() {
 // CHECK-SAME: %[[PRED:.*]]: i1,
 // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK: return %[[RET]] : tensor<f32>

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 // CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 // CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> to memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> to memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> to memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
