Skip to content

[mlir][bufferization] Transfer restrict during empty tensor elimination #68729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,28 +246,32 @@ def Bufferization_MaterializeInDestinationOp
rewrite IR such that a computation is performed directly in `dest` and no
memcpy is needed.

If `dest` is a buffer, the `restrict` and `writable` attributes must be
specified. These attributes have the same meaning as the respective
attributes of `bufferization.to_tensor`. `writable` indicates that the
`dest` buffer is considered writable. It does not make sense to materialize
a computation in a read-only buffer, so `writable` is required. `restrict`
indicates that this op is the only way for the tensor IR to access `dest`
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
`dest` or with an alias of `dest`. Such IR is not supported by
One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
bufferize incorrectly.

Note: `restrict` and `writable` could be removed from this op because they
must always be set for memref destinations. This op has these attributes to
make clear the requirements on the `dest` operand in the op assembly format.
Moreover, these requirements may be relaxed at some point in the future.
If `dest` is a buffer, the `writable` attribute must be specified and the
`restrict` keyword can be specified. These attributes have the same meaning
as the respective attributes of `bufferization.to_tensor`.

`writable` indicates that the `dest` buffer is considered writable. It does
not make sense to materialize a computation in a read-only buffer, so
`writable` is required.

`restrict` indicates that there is no `bufferization.to_tensor` op and no
other `bufferization.materialize_in_destination` op with `dest` (or an alias
thereof) and "restrict". Only ops with this attribute are considered for
"empty tensor elimination". As part of empty tensor elimination, a new
`to_tensor` op with `dest` may be inserted and the `restrict` attribute is
transferred from this op to the new `to_tensor` op. Having "restrict" on
this op guarantees that performing empty tensor elimination would not create
invalid IR (i.e., having multiple `to_tensor restrict` with aliasing
buffers).

Note: `writable` could be removed from this op because it must always be set
for memref destinations. This op has that attribute to make clear the
requirements on the `dest` operand in the op assembly format.

Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
computed but not *where*, it could fold away, causing the computation to
materialize in a different buffer. It is also possible that the
`tensor.insert_slice` destination bufferizes out-of-place, which would also
cause the computation to materialize in a buffer different buffer.
materialize in a different buffer.
}];

let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
Expand Down
15 changes: 11 additions & 4 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,19 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
return getDest();
}

// The "restrict" attribute is transferred from this op to the newly created
// to_tensor op. If this op does not the "restrict" attribute, the subset
// extraction cannot be built because there is no guarantee that there is no
// pre-existing "restrict" to_tensor op with the same/an aliasing destination.
if (!getRestrict())
return {};

// Build a bufferization.to_tensor op.
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
setRestrict(false);
return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
getWritable());
}

Expand Down Expand Up @@ -647,9 +655,8 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (isa<BaseMemRefType>(getDest().getType()) &&
getOperation()->getNumResults() != 0)
return emitOpError("memref 'dest' implies zero results");
if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'restrict' must be specified if and only if the "
"destination is of memref type");
if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'restrict' is valid only for memref destinations");
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM

// CHECK: func @buffer_forwarding_conflict(
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
Expand Down Expand Up @@ -341,3 +342,26 @@ func.func @linalg_copy_empty() -> tensor<26xi32> {
%1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
return %1 : tensor<26xi32>
}

// -----

// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
// CHECK-ELIM-SAME: %[[m:.*]]: memref<5xf32>
// CHECK-ELIM: tensor.empty
// CHECK-ELIM: bufferization.to_tensor %[[m]] restrict writable
// CHECK-ELIM: bufferization.materialize_in_destination {{.*}} in writable %[[m]]
func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
%0 = tensor.empty() : tensor<5xf32>
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>

%1 = tensor.empty() : tensor<5xf32>
%filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>

%selected = scf.if %c -> tensor<5xf32> {
scf.yield %filled : tensor<5xf32>
} else {
scf.yield %filled2 : tensor<5xf32>
}
bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
return
}
9 changes: 1 addition & 8 deletions mlir/test/Dialect/Bufferization/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,6 @@ func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %a

// -----

func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
}

// -----

func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{memref 'dest' implies zero results}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
Expand All @@ -102,7 +95,7 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
// -----

func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
// expected-error @below{{'restrict' is valid only for memref destinations}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}

Expand Down