[mlir][bufferization] Fix OneShotBufferize when defaultMemorySpaceFn is used #91524
Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir-linalg

Author: Christopher Bate (christopherbate)

Changes: As mentioned in issue llvm/llvm-project#91518, a previous […] However, introducing this feature exposed unhandled edge cases, examples of which […] Fixing the inconsistencies introduced by […]
Patch is 226.54 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/91524.diff. 68 files affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
//===----------------------------------------------------------------------===//
// Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
- TypesMatchWith<"result type matches tensor equivalent of 'memref'",
- "memref", "result",
- "memref::getTensorTypeFromMemRefType($_self)">
+ AllElementTypesMatch<["memref", "result"]>
]> {
let summary = "create a tensor from a `memref`";
let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
let assemblyFormat = [{
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
- `:` type($memref)
+ `:` type($memref) `->` type($result)
}];
+ let builders = [
+ OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+ auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+ build($_builder, $_state, rtt, memref, restrict, writeable);
+ }]>
+ ];
+
let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
Pure,
TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
"memref", "tensor",
- "memref::getTensorTypeFromMemRefType($_self)">
+ "memref::getTensorTypeFromMemRefType($_self)",
+ "bufferization::detail::tensorTypesMatchUpToEncoding">
]> {
let summary = "cast a tensor to memref";
let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
}];
let assemblyFormat = [{
- $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
}];
let hasFolder = 1;
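With the assembly format changes above, both ops now spell out the source and result types explicitly. An illustrative round-trip in the patch's new syntax (drawn from the updated tests in this PR; `%m` and `%a` are assumed function arguments):

```mlir
%0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
// With a memory-space encoding carried on the tensor type:
%2 = bufferization.to_tensor %a : memref<3xf32, 1> -> tensor<3xf32, 1>
```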
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
/*default=*/"false",
"The memory space of an memref types must always be inferred. If "
"unset, a default memory space of 0 is used otherwise.">,
+ Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+ /*default=*/"false",
+ "Use the Tensor encoding attribute for the memory space. Exclusive to"
+ " the 'must-infer-memory-space option'">,
Option<"testAnalysisOnly", "test-analysis-only", "bool",
/*default=*/"false",
"Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
// loose all of its users and eventually DCE away.
rewriter.setInsertionPointAfter(op);
replacement = rewriter.create<bufferization::ToTensorOp>(
- replacement.getLoc(), replacement);
+ replacement.getLoc(), opResult.getType(), replacement);
}
replacements.push_back(replacement);
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
// Helper functions
//===----------------------------------------------------------------------===//
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+ auto lhsType = cast<ShapedType>(lhs);
+ auto rhsType = cast<ShapedType>(rhs);
+ if (lhsType.getElementType() != rhsType.getElementType())
+ return false;
+ if (lhsType.hasRank() && rhsType.hasRank())
+ return lhsType.getShape() == rhsType.getShape();
+ return true;
+}
+
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
OpBuilder &b, Value value, MemRefType destType,
const BufferizationOptions &options) {
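The comparison performed by the new `tensorTypesMatchUpToEncoding` helper can be sketched in standalone C++. This is a minimal model, not the MLIR implementation: `Shape` is an illustrative stand-in for `ShapedType`, with `dims == std::nullopt` modeling an unranked type and `encoding` deliberately ignored by the match, just as the helper ignores the tensor encoding attribute.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Stand-in for a shaped type: element type name, optional shape
// (nullopt = unranked), and an encoding the comparison ignores.
struct Shape {
  std::string elementType;
  std::optional<std::vector<int>> dims;
  std::string encoding; // deliberately not compared
};

// Mirrors tensorTypesMatchUpToEncoding: element types must match;
// shapes are compared only when both sides are ranked.
bool matchUpToEncoding(const Shape &lhs, const Shape &rhs) {
  if (lhs.elementType != rhs.elementType)
    return false;
  if (lhs.dims && rhs.dims)
    return *lhs.dims == *rhs.dims;
  return true;
}
```

The key consequence is that a ranked and an unranked type with the same element type still "match", which is what lets `to_memref` relate, e.g., a `tensor<?xf32>` operand to a differently-encoded result type.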
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
+ // Unranked to ranked casts must be explicit.
+ if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+ return nullptr;
+
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
- // Unranked to ranked and ranked to unranked casts must be explicit.
+ // Ranked to unranked casts must be explicit.
auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
[](TensorType t) -> std::optional<Attribute> {
return std::nullopt;
};
+ } else if (useEncodingForMemorySpace) {
+ opt.defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> {
+ if (auto rtt = dyn_cast<RankedTensorType>(t))
+ return rtt.getEncoding();
+ return std::nullopt;
+ };
}
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
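The behavior of the new `use-encoding-for-memory-space` branch can be sketched as follows. This is a simplified model, not the MLIR API: `TensorDesc` and `encodingMemorySpaceFn` are illustrative stand-ins, and the encoding is modeled as a plain `int` rather than an `Attribute` (so the null-attribute case of `getEncoding()` is elided).

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a tensor type: a ranked tensor may carry
// an encoding value; unranked tensors have none to consult.
struct TensorDesc {
  bool ranked;
  int encoding; // models RankedTensorType::getEncoding()
};

// Mirrors the defaultMemorySpaceFn installed by the new option: a
// ranked tensor contributes its encoding as the memory space; an
// unranked tensor yields no answer at all (std::nullopt).
std::optional<int> encodingMemorySpaceFn(const TensorDesc &t) {
  if (t.ranked)
    return t.encoding;
  return std::nullopt;
}
```

Returning `std::nullopt` (rather than a default space of 0) is what forces callers such as the `FromElementsOp` bufferization above to query `bufferization::getBufferType` instead of assuming a default memory space.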
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
- mixedSizes, mixedStrides);
+ loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+ mixedOffsets, mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
- mixedOffsets, mixedSizes, mixedStrides));
+ extractSliceOp.getType().getShape(),
+ llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+ mixedStrides));
}
};
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
- // TODO: Implement memory space for this op.
- if (options.defaultMemorySpaceFn(tensorType) != Attribute())
- return op->emitError("memory space not implemented yet");
+ std::optional<Attribute> memorySpace =
+ options.defaultMemorySpaceFn(tensorType);
// Allocate a buffer for the result.
Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- auto memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ FailureOr<BaseMemRefType> memrefType =
+ bufferization::getBufferType(*tensorAlloc, options);
+ if (failed(memrefType))
+ return failure();
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, *tensorAlloc);
+ op->getLoc(), *memrefType, *tensorAlloc);
// Case: tensor<0xelem_type>.
if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32> to memref<index>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
// CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
// CHECK: return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
// CHECK-NEXT: %[[clone:.*]] = bufferization.clone %[[m]]
// CHECK-NEXT: return %[[clone]]
func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
- %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+ %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
return %0 : memref<?x?x?xf16>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
// CHECK: return %[[ARG]] : memref<f32>
func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> tensor<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
// -----
func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (tensor<f32>) -> ()
return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// TODO: to_memref with layout maps not supported yet. This should fold to a
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
- %0 = bufferization.to_tensor %m : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
- %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
// -----
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
- // expected-note @+1 {{prior use here}}
- %0 = bufferization.to_tensor %m : memref<*xf32>
- // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+ // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
- %m = bufferization.to_memref %t : memref<5xf32>
+ %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
- %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+ %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+ %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+ %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+ return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]
@llvm/pr-subscribers-mlir-gpu
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// TODO: to_memref with layout maps not supported yet. This should fold to a
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
- %0 = bufferization.to_tensor %m : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
- %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
// -----
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
- // expected-note @+1 {{prior use here}}
- %0 = bufferization.to_tensor %m : memref<*xf32>
- // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+ // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
- %m = bufferization.to_memref %t : memref<5xf32>
+ %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
- %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+ %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+ %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+ %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+ return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]
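To make the assembly-format change above concrete, here is an illustrative sketch (not part of the patch) of how `to_tensor`/`to_memref` round-trip under the new syntax, modeled on the `one-shot-bufferize-encodings.mlir` test this PR adds. The function name `@roundtrip` is hypothetical; the memory-space `1` mirrors the encoding attribute used in that test:

```mlir
// With -one-shot-bufferize="use-encoding-for-memory-space", the encoding `1`
// on the tensor type is carried through as the memref memory space.
func.func @roundtrip(%arg0: memref<3xf32, 1>) -> memref<3xf32, 1> {
  // New format spells out both sides: `memref type -> tensor type` ...
  %t = bufferization.to_tensor %arg0 : memref<3xf32, 1> -> tensor<3xf32, 1>
  // ... and `tensor type -> memref type` in the other direction.
  %m = bufferization.to_memref %t : tensor<3xf32, 1> -> memref<3xf32, 1>
  return %m : memref<3xf32, 1>
}
```

Previously the tensor side of each cast was inferred from the memref type, which could not represent an encoding; making both types explicit is what allows the verifier relaxation (`tensorTypesMatchUpToEncoding`) in this patch.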
+ // Ranked to unranked casts must be explicit.
auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
[](TensorType t) -> std::optional<Attribute> {
return std::nullopt;
};
+ } else if (useEncodingForMemorySpace) {
+ opt.defaultMemorySpaceFn =
+ [](TensorType t) -> std::optional<Attribute> {
+ if (auto rtt = dyn_cast<RankedTensorType>(t))
+ return rtt.getEncoding();
+ return std::nullopt;
+ };
}
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
- mixedSizes, mixedStrides);
+ loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+ mixedOffsets, mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
- mixedOffsets, mixedSizes, mixedStrides));
+ extractSliceOp.getType().getShape(),
+ llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+ mixedStrides));
}
};
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
- // TODO: Implement memory space for this op.
- if (options.defaultMemorySpaceFn(tensorType) != Attribute())
- return op->emitError("memory space not implemented yet");
+ std::optional<Attribute> memorySpace =
+ options.defaultMemorySpaceFn(tensorType);
// Allocate a buffer for the result.
Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- auto memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ FailureOr<BaseMemRefType> memrefType =
+ bufferization::getBufferType(*tensorAlloc, options);
+ if (failed(memrefType))
+ return failure();
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
- op->getLoc(), memrefType, *tensorAlloc);
+ op->getLoc(), *memrefType, *tensorAlloc);
// Case: tensor<0xelem_type>.
if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
-// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32> to memref<index>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
// CHECK: %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
// CHECK: return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
// CHECK-NEXT: %[[clone:.*]] = bufferization.clone %[[m]]
// CHECK-NEXT: return %[[clone]]
func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
- %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+ %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
return %0 : memref<?x?x?xf16>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
// CHECK: return %[[ARG]] : memref<f32>
func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> tensor<f32>
- %1 = bufferization.to_memref %0 : memref<f32>
+ %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
return %1 : memref<f32>
}
// -----
func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
- %0 = bufferization.to_tensor %arg0 : memref<f32>
+ %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (tensor<f32>) -> ()
return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
- %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// TODO: to_memref with layout maps not supported yet. This should fold to a
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
- %0 = bufferization.to_tensor %m : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
// expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
- %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
// expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
// -----
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
- // expected-note @+1 {{prior use here}}
- %0 = bufferization.to_tensor %m : memref<*xf32>
- // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
- %1 = bufferization.to_memref %0 : memref<?xf32>
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+ // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
return %1 : memref<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
- %m = bufferization.to_memref %t : memref<5xf32>
+ %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
// Some op may write into the result of to_memref later.
// CHECK: bufferization.to_memref
// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
- %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+ %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
%2 = tensor.extract %t[%idx] : tensor<5xf32>
return %2 : f32
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+ %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+ %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+ return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]
@matthias-springer Can you review? The vast majority of changes here are just test updates due to the change in assembly format for `bufferization.to_tensor` and `bufferization.to_memref`.
Can you add extra unit tests for `scf.for` (with a tensor iter_arg that has an encoding), `scf.if`, and `scf.execute_region` (with a tensor result)? I think these bufferization patterns must be updated because they create new `to_tensor` ops.
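A minimal sketch of what such an `scf.for` test case might look like (hypothetical; the function name and the `1` encoding are illustrative, chosen to match the style of the new encodings test file):

```mlir
// Hypothetical test: the iter_arg tensor carries encoding 1, which the
// use-encoding-for-memory-space option maps to memref memory space 1.
func.func @scf_for_iter_arg(%arg0: tensor<128xf32, 1>,
                            %lb: index, %ub: index, %step: index) -> tensor<128xf32, 1> {
  %0 = scf.for %i = %lb to %ub step %step
      iter_args(%iter = %arg0) -> (tensor<128xf32, 1>) {
    scf.yield %iter : tensor<128xf32, 1>
  }
  return %0 : tensor<128xf32, 1>
}
```

Bufferizing this should produce memrefs in memory space 1 throughout the loop, with no unencoded `to_tensor`/`to_memref` round-trips.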
// -----

func.func @alloc_tesor_copy_from_non_default_space(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 2> {
  %0 = bufferization.alloc_tensor() copy(%arg0) {memory_space = 2 : i64} : tensor<128xf32, 1>
What is the meaning of this op? The two memory spaces here are inconsistent.
Since having a `copy` generates a `memref.copy`, the intent is to copy data between memory spaces. However, `bufferization.alloc_tensor` only prints one type for both the `copy` operand and the result. Therefore we will need to update the assembly format for `bufferization.alloc_tensor` as well. That will enable dropping the `tensor.cast` from this test.
Ideally, to make this PR not even larger, if an update to these patterns is indeed needed, you could update those patterns in a first, separate PR. I think not much has to change; you just have to take the tensor types from the input IR instead of inferring them in the op builder.

lgtm for all sparse test signature changes
is this PR dead? wondering why it's been sitting here approved for months

It's not dead; I have just been pressed for time. Let me rebase and address Matthias' last comments this week.
Hi, this PR is useful for me. Do you have time to rebase it? Thanks.
[mlir][bufferization] Fix OneShotBufferize when `defaultMemorySpaceFn` is used

As described in issue llvm#91518, a previous PR llvm#78484 introduced the `defaultMemorySpaceFn` into bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change:

- Updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be explicitly included. This is a small assembly format change, but it touches a large number of test files.
- Makes minor updates to other bufferization functions to handle the changes in building the above ops.
- Updates bufferization of `tensor.from_elements` to handle memory space.
As described in issue #91518, a previous PR #78484 introduced the `defaultMemorySpaceFn` into bufferization options, allowing one to inform OneShotBufferize that it should use a specified function to derive the memory space attribute from the encoding attribute attached to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which are introduced by this change in the new test under `test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir`.

Fixing the inconsistencies introduced by `defaultMemorySpaceFn` is pretty simple. This change:

- Updates the `bufferization.to_memref` and `bufferization.to_tensor` operations to explicitly include operand and destination types, whereas previously they relied on type inference to deduce the tensor types. Since type inference cannot recover the correct tensor encoding/memory space, the operand and result types must be explicitly included. This is a small assembly format change, but it touches a large number of test files.
- Makes minor updates to other bufferization functions to handle the changes in building the above ops.
- Updates bufferization of `tensor.from_elements` to handle memory space.
Integration/upgrade guide:

In downstream projects, if you have tests or MLIR files that explicitly use `bufferization.to_tensor` or `bufferization.to_memref`, then update them to the new assembly format as follows:
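Based on the test updates in this diff, the old single-type form:

```mlir
%0 = bufferization.to_tensor %m : memref<?xf32>
%1 = bufferization.to_memref %0 : memref<?xf32>
```

becomes the explicit two-type form:

```mlir
%0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
%1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
```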