Skip to content

Commit 33fc84c

Browse files
[mlir][bufferization] Fix OneShotBufferize when defaultMemorySpaceFn is used
As mentioned in the issue described in issue llvm#91518, a previous PR llvm#78484 introduced the `defaultMemorySpaceFn` into bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types. However, introducing this feature exposed a unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`. Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change: - updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since the type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be explicitly included. - makes minor updates to other bufferization functions to handle the changes in building the above ops - updates bufferization of `tensor.from_elements` to handle memory space
1 parent 45c8766 commit 33fc84c

File tree

70 files changed

+531
-363
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+531
-363
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,16 @@
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
1414
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1516
#include "mlir/Interfaces/CopyOpInterface.h"
1617
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1718
#include "mlir/Interfaces/InferTypeOpInterface.h"
1819
#include "mlir/Interfaces/SubsetOpInterface.h"
20+
#include "llvm/Support/Debug.h"
21+
22+
namespace mlir::bufferization::detail {
23+
bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
24+
} // namespace mlir::bufferization::detail
1925

2026
//===----------------------------------------------------------------------===//
2127
// Bufferization Dialect

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
387387
BufferizableOpInterface,
388388
SameOperandsAndResultShape,
389389
SameOperandsAndResultElementType,
390-
TypesMatchWith<"result type matches tensor equivalent of 'memref'",
391-
"memref", "result",
392-
"memref::getTensorTypeFromMemRefType($_self)">
390+
AllElementTypesMatch<["memref", "result"]>
393391
]> {
394392
let summary = "create a tensor from a `memref`";
395393
let description = [{
@@ -476,9 +474,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
476474

477475
let assemblyFormat = [{
478476
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
479-
`:` type($memref)
477+
`:` type($memref) `->` type($result)
480478
}];
481479

480+
let builders = [
481+
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
482+
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
483+
build($_builder, $_state, rtt, memref, restrict, writeable);
484+
}]>
485+
];
486+
482487
let hasCanonicalizer = 1;
483488
let hasFolder = 1;
484489
}
@@ -495,7 +500,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
495500
Pure,
496501
TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
497502
"memref", "tensor",
498-
"memref::getTensorTypeFromMemRefType($_self)">
503+
"memref::getTensorTypeFromMemRefType($_self)",
504+
"bufferization::detail::tensorTypesMatchUpToEncoding">
499505
]> {
500506
let summary = "cast a tensor to memref";
501507
let description = [{
@@ -550,7 +556,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
550556
}];
551557

552558
let assemblyFormat = [{
553-
$tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
559+
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
554560
}];
555561

556562
let hasFolder = 1;

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
533533
/*default=*/"false",
534534
"The memory space of an memref types must always be inferred. If "
535535
"unset, a default memory space of 0 is used otherwise.">,
536+
Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
537+
/*default=*/"false",
538+
"Use the Tensor encoding attribute for the memory space. Exclusive to"
539+
" the 'must-infer-memory-space option'">,
536540
Option<"testAnalysisOnly", "test-analysis-only", "bool",
537541
/*default=*/"false",
538542
"Test only: Only run inplaceability analysis and annotate IR">,

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
719719
// loose all of its users and eventually DCE away.
720720
rewriter.setInsertionPointAfter(op);
721721
replacement = rewriter.create<bufferization::ToTensorOp>(
722-
replacement.getLoc(), replacement);
722+
replacement.getLoc(), opResult.getType(), replacement);
723723
}
724724
replacements.push_back(replacement);
725725
}

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
2323
// Helper functions
2424
//===----------------------------------------------------------------------===//
2525

26+
bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
27+
auto lhsType = cast<ShapedType>(lhs);
28+
auto rhsType = cast<ShapedType>(rhs);
29+
if (lhsType.getElementType() != rhsType.getElementType())
30+
return false;
31+
if (lhsType.hasRank() && rhsType.hasRank())
32+
return lhsType.getShape() == rhsType.getShape();
33+
return true;
34+
}
35+
2636
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
2737
OpBuilder &b, Value value, MemRefType destType,
2838
const BufferizationOptions &options) {

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6767
ValueRange inputs, Location loc) -> Value {
6868
assert(inputs.size() == 1 && "expected exactly one input");
6969

70+
// Unranked to ranked casts must be explicit.
71+
if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
72+
return nullptr;
73+
7074
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
7175
// MemRef to MemRef cast.
7276
assert(inputType != type && "expected different types");
73-
// Unranked to ranked and ranked to unranked casts must be explicit.
77+
// Ranked to unranked casts must be explicit.
7478
auto rankedDestType = dyn_cast<MemRefType>(type);
7579
if (!rankedDestType)
7680
return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
222226
[](TensorType t) -> std::optional<Attribute> {
223227
return std::nullopt;
224228
};
229+
} else if (useEncodingForMemorySpace) {
230+
opt.defaultMemorySpaceFn =
231+
[](TensorType t) -> std::optional<Attribute> {
232+
if (auto rtt = dyn_cast<RankedTensorType>(t))
233+
return rtt.getEncoding();
234+
return std::nullopt;
235+
};
225236
}
226237
opt.printConflicts = printConflicts;
227238
opt.testAnalysisOnly = testAnalysisOnly;

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -479,10 +479,6 @@ struct FromElementsOpInterface
479479
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
480480
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
481481

482-
// TODO: Implement memory space for this op.
483-
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
484-
return op->emitError("memory space not implemented yet");
485-
486482
// Allocate a buffer for the result.
487483
Location loc = op->getLoc();
488484
auto shape = tensorType.getShape();
@@ -492,10 +488,12 @@ struct FromElementsOpInterface
492488
/*copy=*/false);
493489
if (failed(tensorAlloc))
494490
return failure();
495-
auto memrefType =
496-
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
491+
FailureOr<BaseMemRefType> memrefType =
492+
bufferization::getBufferType(*tensorAlloc, options);
493+
if (failed(memrefType))
494+
return failure();
497495
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
498-
op->getLoc(), memrefType, *tensorAlloc);
496+
op->getLoc(), *memrefType, *tensorAlloc);
499497

500498
// Case: tensor<0xelem_type>.
501499
if (fromElementsOp.getElements().empty()) {

mlir/test/Dialect/Arith/bufferize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
77
%index_scalar = arith.index_cast %scalar : i32 to index
88
return %index_tensor, %index_scalar : tensor<index>, index
99
}
10-
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
10+
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
1111
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
1212
// CHECK-SAME: memref<i32> to memref<index>
1313
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -83,8 +83,8 @@ func.func @non_tensor() {
8383
// CHECK-SAME: %[[PRED:.*]]: i1,
8484
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
8585
// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
86-
// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
87-
// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
86+
// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
87+
// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
8888
// CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
8989
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
9090
// CHECK: return %[[RET]] : tensor<f32>

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
// CHECK-NEXT: %[[clone:.*]] = bufferization.clone %[[m]]
1010
// CHECK-NEXT: return %[[clone]]
1111
func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
12-
%0 = bufferization.to_memref %t : memref<?x?x?xf16>
12+
%0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
1313
return %0 : memref<?x?x?xf16>
1414
}
1515

mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
55
// CHECK: return %[[ARG]] : memref<f32>
66
func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
7-
%0 = bufferization.to_tensor %arg0 : memref<f32>
8-
%1 = bufferization.to_memref %0 : memref<f32>
7+
%0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
8+
%1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
99
return %1 : memref<f32>
1010
}
1111

@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
1414
func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
1515
// expected-error @+1 {{failed to legalize operation 'test.source'}}
1616
%0 = "test.source"() : () -> tensor<f32>
17-
%1 = bufferization.to_memref %0 : memref<f32>
17+
%1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
1818
return %1 : memref<f32>
1919
}
2020

2121
// -----
2222

2323
func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
24-
%0 = bufferization.to_tensor %arg0 : memref<f32>
24+
%0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
2525
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
2626
"test.sink"(%0) : (tensor<f32>) -> ()
2727
return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
3737
// CHECK: memref.copy %[[arg]], %[[alloc]]
3838
// CHECK: return %[[alloc]]
3939
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
40-
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
41-
%1 = bufferization.to_memref %0 : memref<?xf32>
40+
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
41+
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
4242
return %1 : memref<?xf32>
4343
}
4444

@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
5252
// CHECK: memref.copy %[[arg]], %[[alloc]]
5353
// CHECK: return %[[alloc]]
5454
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
55-
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
56-
%1 = bufferization.to_memref %0 : memref<?xf32>
55+
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
56+
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
5757
return %1 : memref<?xf32>
5858
}
5959

@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
6767
// CHECK: memref.copy %[[arg]], %[[alloc]]
6868
// CHECK: return %[[alloc]]
6969
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
70-
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
71-
%1 = bufferization.to_memref %0 : memref<?xf32>
70+
%0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
71+
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
7272
return %1 : memref<?xf32>
7373
}
7474

@@ -77,19 +77,19 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
7777
// TODO: to_memref with layout maps not supported yet. This should fold to a
7878
// memref.cast.
7979
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
80-
%0 = bufferization.to_tensor %m : memref<?xf32>
80+
%0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
8181
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
82-
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
82+
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
8383
// expected-note @below{{see existing live user here}}
8484
return %1 : memref<?xf32, strided<[1], offset: ?>>
8585
}
8686

8787
// -----
8888

8989
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
90-
// expected-note @+1 {{prior use here}}
91-
%0 = bufferization.to_tensor %m : memref<*xf32>
92-
// expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
93-
%1 = bufferization.to_memref %0 : memref<?xf32>
90+
91+
%0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<*xf32>
92+
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<*xf32>') to 'memref<?xf32>' that remained live after conversion}}
93+
%1 = bufferization.to_memref %0 : tensor<*xf32> -> memref<?xf32>
9494
return %1 : memref<?xf32>
9595
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
9696
// Some op may write into the result of to_memref later.
9797
// CHECK: bufferization.to_memref
9898
// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
99-
%m = bufferization.to_memref %t : memref<5xf32>
99+
%m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
100100
%2 = tensor.extract %t[%idx] : tensor<5xf32>
101101
return %2 : f32
102102
}
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
112112
// Some op may write into the result of to_memref later.
113113
// CHECK: bufferization.to_memref
114114
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
115-
%m = bufferization.to_memref %t {read_only} : memref<5xf32>
115+
%m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
116116
%2 = tensor.extract %t[%idx] : tensor<5xf32>
117117
return %2 : f32
118118
}

0 commit comments

Comments
 (0)