
Commit 645ed9d

[mlir][bufferization] Fix OneShotBufferize when defaultMemorySpaceFn is used
As described in issue #91518, a previous PR (#78484) introduced `defaultMemorySpaceFn` into the bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types. However, introducing this feature exposed unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change:

- updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be explicitly included.
- makes minor updates to other bufferization functions to handle the changes in building the above ops.
- updates bufferization of `tensor.from_elements` to handle memory space.
1 parent 2ed8c5d commit 645ed9d
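Concretely, both cast ops now spell out the tensor type next to the memref type. A before/after sketch assembled from the updated tests (the %t/%m value names are illustrative):

    // Before: the tensor type was inferred from the memref type.
    %m = bufferization.to_memref %t : memref<5xf32>

    // After: operand and result types are explicit, so encodings survive.
    %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
    %t2 = bufferization.to_tensor %m : memref<5xf32> -> tensor<5xf32>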

68 files changed: +514 −346 lines


mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 6 additions & 0 deletions
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail

 //===----------------------------------------------------------------------===//
 // Bufferization Dialect

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 12 additions & 6 deletions
@@ -387,9 +387,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -476,9 +474,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];

+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -495,7 +500,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -550,7 +556,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];

   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];

   let hasFolder = 1;
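The added builder keeps old type-inferring call sites compiling: when only a memref value is given, the result tensor type is derived with memref::getTensorTypeFromMemRefType. A usage sketch (rewriter, loc, and buffer are assumed to be in scope; the explicit-type form mirrors the call added in BufferizableOpInterface.cpp below):

    // Result tensor type inferred from the memref type (encoding is lost).
    Value t = rewriter.create<bufferization::ToTensorOp>(loc, buffer);
    // Result tensor type supplied explicitly (encoding/memory space preserved).
    Value t2 = rewriter.create<bufferization::ToTensorOp>(loc, resultTensorType, buffer);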

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
@@ -510,6 +510,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+           /*default=*/"false",
+           "Use the Tensor encoding attribute for the memory space. Exclusive to"
+           " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
            /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 1 deletion
@@ -718,7 +718,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 10 additions & 0 deletions
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//

+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
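With this helper installed as the comparator in the TypesMatchWith trait above, the to_memref verifier accepts a tensor operand whose encoding the memref-to-tensor inference cannot reconstruct, as long as shape and element type agree. One accepted example, taken from the new test (%m is an illustrative name):

    %m = bufferization.to_memref %arg0 : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>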

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 12 additions & 1 deletion
@@ -66,10 +66,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
          ValueRange inputs, Location loc) -> Value {
         assert(inputs.size() == 1 && "expected exactly one input");

+        // Unranked to ranked casts must be explicit.
+        if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+          return nullptr;
+
         if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
           // MemRef to MemRef cast.
           assert(inputType != type && "expected different types");
-          // Unranked to ranked and ranked to unranked casts must be explicit.
+          // Ranked to unranked casts must be explicit.
           auto rankedDestType = dyn_cast<MemRefType>(type);
           if (!rankedDestType)
             return nullptr;
@@ -152,6 +156,13 @@ struct OneShotBufferizePass
           [](TensorType t) -> std::optional<Attribute> {
             return std::nullopt;
           };
+    } else if (useEncodingForMemorySpace) {
+      opt.defaultMemorySpaceFn =
+          [](TensorType t) -> std::optional<Attribute> {
+            if (auto rtt = dyn_cast<RankedTensorType>(t))
+              return rtt.getEncoding();
+            return std::nullopt;
+          };
     }
     opt.printConflicts = printConflicts;
     opt.bufferAlignment = bufferAlignment;
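For clients driving bufferization programmatically instead of through the pass option, the same function can be installed on the options struct directly. A minimal sketch mirroring the lambda above (the surrounding setup is assumed, not part of this commit):

    OneShotBufferizationOptions options;
    // Use the tensor's encoding attribute as the memory space, if ranked.
    options.defaultMemorySpaceFn = [](TensorType t) -> std::optional<Attribute> {
      if (auto rtt = dyn_cast<RankedTensorType>(t))
        return rtt.getEncoding();
      return std::nullopt;
    };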

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 7 deletions
@@ -480,10 +480,6 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
-
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
     auto shape = tensorType.getShape();
@@ -493,10 +489,12 @@ struct FromElementsOpInterface
                                  /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);

     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
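The net effect is that tensor.from_elements no longer hard-codes memory space 0 or rejects non-default spaces: the buffer type now comes from bufferization::getBufferType on the allocated tensor, so an encoding-derived space propagates to the allocation. Condensed from the new @from_elements test below:

    // Input: an encoding of 1 on the result tensor type...
    %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
    // ...now bufferizes to stores into an allocation in memory space 1:
    %alloc = memref.alloc() : memref<3xf32, 1>

Previously this case failed with "memory space not implemented yet".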

mlir/test/Dialect/Arith/bufferize.mlir

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME: memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -83,8 +83,8 @@ func.func @non_tensor() {
 // CHECK-SAME: %[[PRED:.*]]: i1,
 // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK: return %[[RET]] : tensor<f32>

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 // CHECK-NEXT: %[[clone:.*]] = bufferization.clone %[[m]]
 // CHECK-NEXT: return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : memref<3xf32, 1>
+// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<3xf32, 1> -> tensor<3xf32, 1 : i64>
+// CHECK: return %[[v0]] : tensor<3xf32, 1 : i64>
+
+// -----
+
+func.func @alloc_tesor_with_space_no_encoding() -> tensor<128xf32> {
+  %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<128xf32>
+  return %0 : tensor<128xf32>
+}
+
+// CHECK-LABEL: @alloc_tesor_with_space_no_encoding
+// CHECK-SAME: () -> tensor<128xf32> {
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
+// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 1> -> tensor<128xf32>
+// CHECK: return %[[v0]] : tensor<128xf32>
+
+// -----
+
+func.func @alloc_tesor_with_space_and_cast() -> tensor<128xf32, 1> {
+  %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<128xf32>
+  %1 = tensor.cast %0 : tensor<128xf32> to tensor<128xf32, 1>
+  return %1 : tensor<128xf32, 1>
+}
+
+// CHECK-LABEL: @alloc_tesor_with_space_and_cast
+// CHECK-SAME: () -> tensor<128xf32, 1 : i64> {
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
+// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 1> -> tensor<128xf32, 1 : i64>
+// CHECK: return %[[v0]] : tensor<128xf32, 1 : i64>
+
+// -----
+
+func.func @alloc_tesor_with_space_with_encoding() -> tensor<128xf32, 1 : i64> {
+  %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<128xf32, 1 : i64>
+  return %0 : tensor<128xf32, 1 : i64>
+}
+
+// CHECK-LABEL: @alloc_tesor_with_space_with_encoding
+// CHECK-SAME: () -> tensor<128xf32, 1 : i64> {
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
+// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 1> -> tensor<128xf32, 1 : i64>
+// CHECK: return %[[v0]] : tensor<128xf32, 1 : i64>
+
+// -----
+
+func.func @alloc_tesor_copy_from_default_space(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %0 = bufferization.alloc_tensor() copy(%arg0) {memory_space = 1 : i64} : tensor<128xf32>
+  return %0 : tensor<128xf32>
+}
+
+// CHECK-LABEL: @alloc_tesor_copy_from_default_space
+// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32>) -> tensor<128xf32> {
+// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32> -> memref<128xf32, strided<[?], offset: ?>>
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
+// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>> to memref<128xf32, 1>
+// CHECK: %[[v1:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 1> -> tensor<128xf32>
+// CHECK: return %[[v1]] : tensor<128xf32>
+
+// -----
+
+func.func @alloc_tesor_copy_from_non_default_space(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 2> {
+  %0 = bufferization.alloc_tensor() copy(%arg0) {memory_space = 2 : i64} : tensor<128xf32, 1>
+  %1 = tensor.cast %0 : tensor<128xf32, 1> to tensor<128xf32, 2>
+  return %1 : tensor<128xf32, 2>
+}
+
+// CHECK-LABEL: @alloc_tesor_copy_from_non_default_space
+// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>) -> tensor<128xf32, 2 : i64> {
+// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 2>
+// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 2>
+// CHECK: %[[v1:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> -> tensor<128xf32, 2 : i64>
+// CHECK: return %[[v1]] : tensor<128xf32, 2 : i64>
+
+// -----
+
+// TODO: this should be illegal since ultimately we can not eliminate the `bufferization.to_tensor` when we
+// bufferize function boundaries.
+func.func @alloc_tesor_copy_from_non_default_space_no_cast(%arg0: tensor<128xf32, 1>,
+    %arg1: tensor<4xf32, 1>) -> tensor<128xf32, 1> {
+  %0 = bufferization.alloc_tensor() copy(%arg0) {memory_space = 2 : i64} : tensor<128xf32, 1>
+  %1 = tensor.insert_slice %arg1 into %arg0 [0][4][1] : tensor<4xf32, 1> into tensor<128xf32, 1>
+  return %0 : tensor<128xf32, 1>
+}
+
+// CHECK-LABEL: @alloc_tesor_copy_from_non_default_space_no_cast
+// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>, %[[arg1:.+]]: tensor<4xf32, 1 : i64>) -> tensor<128xf32, 1 : i64> {
+// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg1]] : tensor<4xf32, 1 : i64> -> memref<4xf32, strided<[?], offset: ?>, 1>
+// CHECK: %[[v1:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
+// CHECK: %[[v2:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 2>
+// CHECK: memref.copy %[[v2]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 2>
+// CHECK: %[[v3:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> -> tensor<128xf32, 1 : i64>
+// CHECK: %[[alloc_0:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
+// CHECK: memref.copy %[[v1]], %[[alloc_0]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 1>
+// CHECK: %[[subview:.+]] = memref.subview %[[alloc_0]][0] [4] [1] : memref<128xf32, 1> to memref<4xf32, strided<[1]>, 1>
+// CHECK: memref.copy %[[v0]], %[[subview]] : memref<4xf32, strided<[?], offset: ?>, 1> to memref<4xf32, strided<[1]>, 1>
+// CHECK: return %[[v3]] : tensor<128xf32, 1 : i64>
+
+// -----
+
+func.func @materialize_in_destination(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 2> {
+  %0 = bufferization.alloc_tensor () {memory_space = 2 : i64} : tensor<128xf32, 2>
+  %1 = bufferization.materialize_in_destination %arg0 in %0 : (tensor<128xf32, 1>, tensor<128xf32, 2>) -> tensor<128xf32, 2>
+  return %1 : tensor<128xf32, 2>
+}
+
+// CHECK-LABEL: @materialize_in_destination
+// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>) -> tensor<128xf32, 2 : i64> {
+// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
+// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 2>
+// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 2>
+// CHECK: %[[v1:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> -> tensor<128xf32, 2 : i64>
+// CHECK: return %[[v1]] : tensor<128xf32, 2 : i64>
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir

Lines changed: 3 additions & 3 deletions
@@ -25,9 +25,9 @@ func.func @use_of_unknown_op_1(%t1: tensor<?xf32>)

   %idx = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, strided<[?], offset: ?>>
+  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : tensor<?xf32> -> memref<?xf32, strided<[?], offset: ?>>
   // CHECK: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref<?xf32, strided<[?], offset: ?>>
-  // CHECK-NO-LAYOUT-MAP: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32>
+  // CHECK-NO-LAYOUT-MAP: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : tensor<?xf32> -> memref<?xf32>
   // CHECK-NO-LAYOUT-MAP: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref<?xf32>
   %1 = vector.transfer_read %0[%idx], %cst : tensor<?xf32>, vector<5xf32>
   return %1 : vector<5xf32>
@@ -61,7 +61,7 @@ func.func @use_of_unknown_op_3(%t1: tensor<?xf32>)

   // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[t1]])
   %0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref<?xf32, strided<[?], offset: ?>>
+  // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : tensor<?xf32> -> memref<?xf32, strided<[?], offset: ?>>
   // CHECK: %[[v2:.*]] = vector.transfer_read %[[dummy_memref]]
   %2 = vector.transfer_read %0[%idx], %cst : tensor<?xf32>, vector<5xf32>