Commit 08aa956
[mlir][bufferization]-Replace only one use in TensorEmptyElimination (#118958)
In many cases emptyTensorElimination cannot transform or eliminate the empty tensor being inserted into the `SubsetInsertionOpInterface`. There are two major reasons for this:

1. Failing to find a legal/suitable insertion point for the `subsetExtract` that is about to replace the empty tensor. We may be able to handle this by moving the values responsible for building the `subsetExtract` next to the empty tensor (which is about to be eliminated), thus increasing the probability of finding a legal insertion point.

2. The EmptyTensorElimination transform replaces all of the `tensor.empty`'s uses at once in a single application, rather than replacing only the specific use that was visited in the use-def chain (when traversing from the `tensor.insert_slice`'s source). Replacing all uses of the `tensor.empty` may introduce additional read effects after bufferization of the specific subset extract/subview, which should not be the case.

Both cases may result in many copies during the subsequent bufferization, and these copies cannot be canonicalized away. The first case can be observed when a `tensor.empty` is followed by a `SubsetInsertionOpInterface` (in simple words a `tensor.insert_slice`), as lowered from `tensor/tosa.concat`. The second case can be observed when a `tensor.empty` has many uses, so the transformation is applied only once, since all of the uses were replaced at once.

The first commit in this PR only adds lit tests for the cases shown above (NFC), to illustrate how the transform works; upcoming MRs will add slight changes to handle these cases. In the second commit, we replace only the specific use that was visited in the use-def chain (when traversing from the `tensor.insert_slice`'s source).
1 parent 6457aee commit 08aa956
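For illustration, a minimal sketch of the first pattern, in the same shape as the lit tests added below (the `%lhs` and `%rhs` values are assumed): a `tensor.empty` whose only writes come from `tensor.insert_slice` ops, as produced by lowering a concat:

// Hypothetical IR: %concat plays the role of the empty tensor that the
// SubsetInsertionOpInterface ops (tensor.insert_slice) write into.
%concat = tensor.empty() : tensor<5x6x128xf32>
%s1 = tensor.insert_slice %lhs into %concat[0, 0, 0][5, 6, 64][1, 1, 1]
    : tensor<5x6x64xf32> into tensor<5x6x128xf32>
%s2 = tensor.insert_slice %rhs into %s1[0, 0, 64][5, 6, 64][1, 1, 1]
    : tensor<5x6x64xf32> into tensor<5x6x128xf32>

Eliminating %concat requires injecting a subset extract whose needed values all dominate some insertion point; when no such point exists, the empty tensor survives and bufferization materializes an extra alloc and copy.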

File tree

5 files changed: +143 −28 lines changed


mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 2 deletions
@@ -459,7 +459,8 @@ class AnalysisState {
   /// Starting from `value`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
-  /// traversed any further.
+  /// traversed any further; the visited aliasing OpOperands are preserved in
+  /// `visitedOpOperands`.
   ///
   /// When reaching the end of a chain, also return the last Value of that
   /// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition,
-      TraversalConfig config = TraversalConfig()) const;
+      TraversalConfig config = TraversalConfig(),
+      llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;

   /// Find the values that may define the contents of the given value at
   /// runtime. A block argument is always a definition. An OpResult is a
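For orientation, a minimal sketch (not part of the commit) of how a caller can thread the new out-parameter through the traversal; it mirrors the call made in EmptyTensorElimination.cpp below, with `state`, `source`, and `config` assumed from the surrounding transform:

// Collect the aliasing OpOperands visited while walking backwards from the
// subset insertion's source towards a tensor.empty.
llvm::DenseSet<OpOperand *> visitedOpOperands;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
    source.get(), /*condition=*/
    [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
    &visitedOpOperands);
// Every OpOperand recorded here is one edge of the traversed chain, so a
// client can later single out the exact use of the tensor.empty to replace.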

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 2 deletions
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
-// further.
+// further; the visited aliasing OpOperands are preserved in
+// `visitedOpOperands`.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    TraversalConfig config) const {
+    TraversalConfig config,
+    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
       }

       workingSet.insert(a.opOperand->get());
+      if (visitedOpOperands)
+        visitedOpOperands->insert(a.opOperand);
     }
   }
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 32 additions & 24 deletions
@@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
   return true;
 }

-/// Return true if the given `insertionPoint` dominates all uses of
-/// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
-                                        Operation *emptyTensorOp) {
-  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
-    return domInfo.dominates(insertionPoint, user);
-  });
-}
-
-/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
-/// that the replacement may use any value from `neededValues`.
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use in the `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp,
+findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
+  Operation *candidateInsertionPoint = emptyTensorOp;

-  // Gather all possible insertion points: the location of `emptyTensorOp` and
-  // right after the definition of each value in `neededValues`.
+  // Gather all possible insertion points: the location of
+  // `candidateInsertionPoint` and right after the definition of each value in
+  // `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(emptyTensorOp);
+  insertionPointCandidates.push_back(candidateInsertionPoint);
   for (Value val : neededValues) {
     // Note: The anchor op is using all of `neededValues`, so:
     // * in case of a block argument: There must be at least one op in the block
@@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                             neededValues))
       continue;
-    // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
+    // Check if the insertion point is before the use to be replaced.
+    if (!domInfo.dominates(insertionPoint, user))
       continue;
     return insertionPoint;
   }
@@ -103,8 +96,9 @@
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
-
+  llvm::DenseSet<OpOperand *> visitedOpOperands;
   op->walk([&](SubsetInsertionOpInterface op) {
+    visitedOpOperands.clear();
     OpOperand &source = op.getSourceOperand();
     // Skip operands that do not bufferize inplace. "tensor.empty" could still
     // be replaced, but the transformation may not be beneficial.
@@ -131,16 +125,28 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         source.get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-        config);
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
+        &visitedOpOperands);

     for (Value v : emptyTensors) {
       Operation *emptyTensorOp = v.getDefiningOp();

+      // Find the use to be replaced from the use-def chain.
+      auto iter = llvm::find_if(
+          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
+            return llvm::count(emptyTensorOp->getUses(), *opOperand);
+          });
+      // This can happen when the `emptyTensorOp` is consumed directly by the
+      // `SubsetInsertionOpInterface`'s source operand.
+      if (iter == visitedOpOperands.end())
+        continue;
+      OpOperand *useToBeReplaced = *iter;
+      Operation *user = useToBeReplaced->getOwner();
+
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
       Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, neededValues);
+          findValidInsertionPoint(emptyTensorOp, user, neededValues);
       if (!insertionPoint)
         continue;

@@ -159,8 +165,10 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
       }
-      // Replace the tensor::EmptyOp.
-      rewriter.replaceOp(emptyTensorOp, replacement);
+      // Replace the specific use of the tensor::EmptyOp.
+      rewriter.modifyOpInPlace(user, [&]() {
+        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
+      });
       state.resetCache();
     }
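To make the effect of the last hunk concrete, a before/after sketch in hypothetical IR, mirroring the multi-use lit test added below (`%c1`, `%c2`, and `%dest` are assumed): only the tracked use of the empty tensor is redirected to the injected subset extract, while the untracked use keeps the `tensor.empty`:

// Before: %e has two uses; only %f1 feeds the insert_slice being eliminated.
%e = tensor.empty() : tensor<5x6x64xf32>
%f1 = linalg.fill ins(%c1 : f32) outs(%e : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%f2 = linalg.fill ins(%c2 : f32) outs(%e : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>

// After: the tracked use is replaced by a slice of the destination; the
// other use of %e is left untouched.
%slice = tensor.extract_slice %dest[0, 0, 0][5, 6, 64][1, 1, 1]
    : tensor<5x6x128xf32> to tensor<5x6x64xf32>
%f1 = linalg.fill ins(%c1 : f32) outs(%slice : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%f2 = linalg.fill ins(%c2 : f32) outs(%e : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>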

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
   // CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
   %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: bufferization.alloc_tensor(%arg1)
   %0 = tensor.empty(%arg1) : tensor<?xf32>

   // CHECK: bufferization.alloc_tensor(%arg1)

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 100 additions & 0 deletions
@@ -365,3 +365,103 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// `EmptyTensorElimination` fails to find a valid insertion
+// point for the new injected `SubsetExtraction`.
+// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
+func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
+func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+  // CHECK: memref.alloc
+  // CHECK-NOT: memref.alloc
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// `EmptyTensorElimination` will replace the specific use of the tensor
+// empty with the new injected `SubsetExtraction`, i.e. the specific use
+// which has been tracked.
+
+// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
+func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32>, %arg2: tensor<5x6x64xf32>)
+    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
+  %cst_1 = arith.constant 1.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
+  // CHECK-NOT: memref.alloc
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]
+    }
+    ins(%empty_1 : tensor<5x6x64xf32>)
+    outs(%arg2 : tensor<5x6x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %res = arith.addf %in, %in : f32
+    linalg.yield %res : f32
+  } -> tensor<5x6x64xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
+}
