Skip to content

Commit 6d88ac1

Browse files
[mlir][bufferization] Transfer restrict during empty tensor elimination (#68729)
Empty tensor elimination is looking for `bufferization.materialize_in_destination` ops with a `tensor.empty` source. It replaces the `tensor.empty` with a `bufferization.to_tensor restrict` of the memref destination. As part of this rewrite, the `restrict` keyword should be removed, so that no second `to_tensor restrict` op will be inserted. Such IR would be invalid. `bufferization.materialize_in_destination` with memref destination and without the `restrict` attribute are ignored by empty tensor elimination. Also relax the verifier of `materialize_in_destination`. The `restrict` keyword is not generally needed because the op does not expose the buffer as a tensor.
1 parent a12e747 commit 6d88ac1

File tree

4 files changed

+58
-30
lines changed

4 files changed

+58
-30
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -246,28 +246,32 @@ def Bufferization_MaterializeInDestinationOp
246246
rewrite IR such that a computation is performed directly in `dest` and no
247247
memcpy is needed.
248248

249-
If `dest` is a buffer, the `restrict` and `writable` attributes must be
250-
specified. These attributes have the same meaning as the respective
251-
attributes of `bufferization.to_tensor`. `writable` indicates that the
252-
`dest` buffer is considered writable. It does not make sense to materialize
253-
a computation in a read-only buffer, so `writable` is required. `restrict`
254-
indicates that this op is the only way for the tensor IR to access `dest`
255-
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
256-
`dest` or with an alias of `dest`. Such IR is not supported by
257-
One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
258-
bufferize incorrectly.
259-
260-
Note: `restrict` and `writable` could be removed from this op because they
261-
must always be set for memref destinations. This op has these attributes to
262-
make clear the requirements on the `dest` operand in the op assembly format.
263-
Moreover, these requirements may be relaxed at some point in the future.
249+
If `dest` is a buffer, the `writable` attribute must be specified and the
250+
`restrict` keyword can be specified. These attributes have the same meaning
251+
as the respective attributes of `bufferization.to_tensor`.
252+
253+
`writable` indicates that the `dest` buffer is considered writable. It does
254+
not make sense to materialize a computation in a read-only buffer, so
255+
`writable` is required.
256+
257+
`restrict` indicates that there is no `bufferization.to_tensor` op and no
258+
other `bufferization.materialize_in_destination` op with `dest` (or an alias
259+
thereof) and "restrict". Only ops with this attribute are considered for
260+
"empty tensor elimination". As part of empty tensor elimination, a new
261+
`to_tensor` op with `dest` may be inserted and the `restrict` attribute is
262+
transferred from this op to the new `to_tensor` op. Having "restrict" on
263+
this op guarantees that performing empty tensor elimination would not create
264+
invalid IR (i.e., having multiple `to_tensor restrict` with aliasing
265+
buffers).
266+
267+
Note: `writable` could be removed from this op because it must always be set
268+
for memref destinations. This op has that attribute to make clear the
269+
requirements on the `dest` operand in the op assembly format.
264270

265271
Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
266272
same purpose, but since tensor dialect ops only indicate *what* should be
267273
computed but not *where*, it could fold away, causing the computation to
268-
materialize in a different buffer. It is also possible that the
269-
`tensor.insert_slice` destination bufferizes out-of-place, which would also
270-
cause the computation to materialize in a buffer different buffer.
274+
materialize in a different buffer.
271275
}];
272276

273277
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -613,11 +613,19 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
613613
return getDest();
614614
}
615615

616+
// The "restrict" attribute is transferred from this op to the newly created
617+
// to_tensor op. If this op does not the "restrict" attribute, the subset
618+
// extraction cannot be built because there is no guarantee that there is no
619+
// pre-existing "restrict" to_tensor op with the same/an aliasing destination.
620+
if (!getRestrict())
621+
return {};
622+
616623
// Build a bufferization.to_tensor op.
617624
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
618625
assert(getRestrict() &&
619626
"expected that ops with memrefs dest have 'restrict'");
620-
return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
627+
setRestrict(false);
628+
return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
621629
getWritable());
622630
}
623631

@@ -647,9 +655,8 @@ LogicalResult MaterializeInDestinationOp::verify() {
647655
if (isa<BaseMemRefType>(getDest().getType()) &&
648656
getOperation()->getNumResults() != 0)
649657
return emitOpError("memref 'dest' implies zero results");
650-
if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
651-
return emitOpError("'restrict' must be specified if and only if the "
652-
"destination is of memref type");
658+
if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
659+
return emitOpError("'restrict' is valid only for memref destinations");
653660
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
654661
return emitOpError("'writable' must be specified if and only if the "
655662
"destination is of memref type");

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM
23

34
// CHECK: func @buffer_forwarding_conflict(
45
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
@@ -341,3 +342,26 @@ func.func @linalg_copy_empty() -> tensor<26xi32> {
341342
%1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
342343
return %1 : tensor<26xi32>
343344
}
345+
346+
// -----
347+
348+
// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
349+
// CHECK-ELIM-SAME: %[[m:.*]]: memref<5xf32>
350+
// CHECK-ELIM: tensor.empty
351+
// CHECK-ELIM: bufferization.to_tensor %[[m]] restrict writable
352+
// CHECK-ELIM: bufferization.materialize_in_destination {{.*}} in writable %[[m]]
353+
func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
354+
%0 = tensor.empty() : tensor<5xf32>
355+
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
356+
357+
%1 = tensor.empty() : tensor<5xf32>
358+
%filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>
359+
360+
%selected = scf.if %c -> tensor<5xf32> {
361+
scf.yield %filled : tensor<5xf32>
362+
} else {
363+
scf.yield %filled2 : tensor<5xf32>
364+
}
365+
bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
366+
return
367+
}

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,6 @@ func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %a
8080

8181
// -----
8282

83-
func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
84-
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
85-
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
86-
}
87-
88-
// -----
89-
9083
func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
9184
// expected-error @below{{memref 'dest' implies zero results}}
9285
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
@@ -102,7 +95,7 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
10295
// -----
10396

10497
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
105-
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
98+
// expected-error @below{{'restrict' is valid only for memref destinations}}
10699
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
107100
}
108101

0 commit comments

Comments
 (0)