[mlir][bufferization] Transfer `restrict` during empty tensor elimination #68729
…tion Empty tensor elimination looks for `bufferization.materialize_in_destination` ops with a `tensor.empty` source. It replaces the `tensor.empty` with a `bufferization.to_tensor restrict` of the memref destination. As part of this rewrite, the `restrict` keyword should be removed from the `materialize_in_destination` op, so that no second `to_tensor restrict` op will be inserted for the same destination. Such IR would be invalid. Also relax the verifier of `materialize_in_destination`. The `restrict` keyword is not generally needed because the op does not expose the buffer as a tensor.
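The transfer described above can be sketched without any MLIR dependency. The following is only an illustrative model, not the actual implementation: `MaterializeModel`, `ToTensorModel`, and the flag names are hypothetical stand-ins that keep just the two attributes involved, and only the memref-destination path is modeled.

```cpp
#include <optional>

// Hypothetical stand-in for a bufferization.to_tensor op (flags only).
struct ToTensorModel {
  bool hasRestrict;
  bool isWritable;
};

// Hypothetical stand-in for a materialize_in_destination op whose
// destination is a memref (flags only).
struct MaterializeModel {
  bool hasRestrict;
  bool isWritable;
};

// Models the memref path of the patched buildSubsetExtraction: without
// `restrict` nothing is built; with it, the flag moves from the
// materialize op to the newly created to_tensor op.
std::optional<ToTensorModel> buildSubsetExtraction(MaterializeModel &op) {
  if (!op.hasRestrict)
    return std::nullopt; // op is ignored by empty tensor elimination
  op.hasRestrict = false; // transfer: the op no longer carries `restrict`
  return ToTensorModel{/*hasRestrict=*/true, op.isWritable};
}
```

Under this model, a second call on the same op returns an empty optional, which is the property that prevents a second `to_tensor restrict` for the same destination.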
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Full diff: https://github.com/llvm/llvm-project/pull/68729.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 34a6f5d74b13956..c779d1f843d76a0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -246,28 +246,32 @@ def Bufferization_MaterializeInDestinationOp
rewrite IR such that a computation is performed directly in `dest` and no
memcpy is needed.
- If `dest` is a buffer, the `restrict` and `writable` attributes must be
- specified. These attributes have the same meaning as the respective
- attributes of `bufferization.to_tensor`. `writable` indicates that the
- `dest` buffer is considered writable. It does not make sense to materialize
- a computation in a read-only buffer, so `writable` is required. `restrict`
- indicates that this op is the only way for the tensor IR to access `dest`
- (or an alias thereof). E.g., there must be no other `to_tensor` ops with
- `dest` or with an alias of `dest`. Such IR is not supported by
- One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
- bufferize incorrectly.
-
- Note: `restrict` and `writable` could be removed from this op because they
- must always be set for memref destinations. This op has these attributes to
- make clear the requirements on the `dest` operand in the op assembly format.
- Moreover, these requirements may be relaxed at some point in the future.
+ If `dest` is a buffer, the `writable` attribute must be specified and the
+ `restrict` keyword can be specified. These attributes have the same meaning
+ as the respective attributes of `bufferization.to_tensor`.
+
+ `writable` indicates that the `dest` buffer is considered writable. It does
+ not make sense to materialize a computation in a read-only buffer, so
+ `writable` is required.
+
+ `restrict` indicates that there is no `bufferization.to_tensor` op and no
+ other `bufferization.materialize_in_destination` op with `dest` (or an alias
+ thereof) and "restrict". Only ops with this attribute are considered for
+ "empty tensor elimination". As part of empty tensor elimination, a new
+ `to_tensor` op with `dest` may be inserted and the `restrict` attribute is
+ transferred from this op to the new `to_tensor` op. Having "restrict" on
+ this op guarantees that performing empty tensor elimination would not create
+ invalid IR (i.e., having multiple `to_tensor restrict` with aliasing
+ buffers).
+
+ Note: `writable` could be removed from this op because it must always be set
+ for memref destinations. This op has that attribute to make clear the
+ requirements on the `dest` operand in the op assembly format.
Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
computed but not *where*, it could fold away, causing the computation to
- materialize in a different buffer. It is also possible that the
- `tensor.insert_slice` destination bufferizes out-of-place, which would also
- cause the computation to materialize in a buffer different buffer.
+ materialize in a different buffer.
}];
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 738c8374d7add03..5716dcc9d905016 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -613,11 +613,19 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
return getDest();
}
+ // The "restrict" attribute is transferred from this op to the newly created
+ // to_tensor op. If this op does not have the "restrict" attribute, the subset
+ // extraction cannot be built because there is no guarantee that there is no
+ // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
+ if (!getRestrict())
+ return {};
+
// Build a bufferization.to_tensor op.
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
- return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+ setRestrict(false);
+ return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
getWritable());
}
@@ -647,9 +655,8 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (isa<BaseMemRefType>(getDest().getType()) &&
getOperation()->getNumResults() != 0)
return emitOpError("memref 'dest' implies zero results");
- if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
- return emitOpError("'restrict' must be specified if and only if the "
- "destination is of memref type");
+ if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
+ return emitOpError("'restrict' is valid only for memref destinations");
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 99b974b9ef3c67e..9a3e14b6d391782 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM
// CHECK: func @buffer_forwarding_conflict(
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
@@ -341,3 +342,26 @@ func.func @linalg_copy_empty() -> tensor<26xi32> {
%1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
return %1 : tensor<26xi32>
}
+
+// -----
+
+// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
+// CHECK-ELIM-SAME: %[[m:.*]]: memref<5xf32>
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: bufferization.to_tensor %[[m]] restrict writable
+// CHECK-ELIM: bufferization.materialize_in_destination {{.*}} in writable %[[m]]
+func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
+ %0 = tensor.empty() : tensor<5xf32>
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+
+ %1 = tensor.empty() : tensor<5xf32>
+ %filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>
+
+ %selected = scf.if %c -> tensor<5xf32> {
+ scf.yield %filled : tensor<5xf32>
+ } else {
+ scf.yield %filled2 : tensor<5xf32>
+ }
+ bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index ce56f89c1f1bbe6..996d8430b84d48b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -80,13 +80,6 @@ func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %a
// -----
-func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
- // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
- bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
-}
-
-// -----
-
func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{memref 'dest' implies zero results}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
@@ -102,7 +95,7 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
// -----
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+ // expected-error @below{{'restrict' is valid only for memref destinations}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}
Empty tensor elimination is looking for `bufferization.materialize_in_destination` ops with a `tensor.empty` source. It replaces the `tensor.empty` with a `bufferization.to_tensor restrict` of the memref destination. As part of this rewrite, the `restrict` keyword should be removed, so that no second `to_tensor restrict` op will be inserted. Such IR would be invalid. `bufferization.materialize_in_destination` ops with a memref destination and without the `restrict` attribute are ignored by empty tensor elimination.

Also relax the verifier of `materialize_in_destination`. The `restrict` keyword is not generally needed because the op does not expose the buffer as a tensor.
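As a rough illustration of the relaxed verifier, here is a standalone model of the three checks visible in the `verify()` hunk of this patch. `OpSummary` and `verifyRelaxed` are hypothetical names introduced for this sketch. The real verifier works on the op itself and performs further checks that are not part of the hunk.

```cpp
#include <optional>
#include <string>

// Hypothetical summary of the op properties that the checks inspect.
struct OpSummary {
  bool destIsMemref;
  unsigned numResults;
  bool hasRestrict;
  bool isWritable;
};

// Returns an error message if a modeled check fails, or an empty optional
// if all of them pass. Message texts follow the diff.
std::optional<std::string> verifyRelaxed(const OpSummary &op) {
  if (op.destIsMemref && op.numResults != 0)
    return std::string("memref 'dest' implies zero results");
  // Relaxed rule: `restrict` is optional on memref destinations, but it is
  // still rejected on tensor destinations.
  if (op.hasRestrict && !op.destIsMemref)
    return std::string("'restrict' is valid only for memref destinations");
  // Unchanged rule: `writable` is set if and only if `dest` is a memref.
  if (op.isWritable != op.destIsMemref)
    return std::string("'writable' must be specified if and only if the "
                       "destination is of memref type");
  return std::nullopt;
}
```

In this model, a memref destination with `writable` but without `restrict` passes, which is the case the removed `invalid_materialize_in_destination_restrict_missing` test used to reject.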