Commit c95fcd3

[mlir][bufferization] Remove resolveUsesInRepetitiveRegions (llvm#67927)
The bufferization analysis has improved over the past few months, so this workaround is no longer needed.
1 parent 6a01da4 · commit c95fcd3
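
For context, the removed preprocessing targeted patterns like the following (a minimal hypothetical example, not taken from the patch): a tensor that initializes an iter_args and is also read inside the loop body. One-Shot Bufferize's conflict analysis now handles such uses directly instead of copying the tensor unconditionally before the loop.

func.func @use_inside_loop(%tensor: tensor<?xf32>, %lb: index, %ub: index,
                           %step: index) -> tensor<?xf32> {
  %r = scf.for %iv = %lb to %ub step %step
      iter_args(%t = %tensor) -> (tensor<?xf32>) {
    // %tensor is read inside the repetitive region while the potentially
    // aliasing iter_arg %t is written to; the removed helper resolved this
    // by always copying %tensor before the loop.
    %e = tensor.extract %tensor[%iv] : tensor<?xf32>
    %w = tensor.insert %e into %t[%iv] : tensor<?xf32>
    scf.yield %w : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}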

File tree: 3 files changed (+16 −91 lines)

  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
  mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
  mlir/test/Dialect/SCF/one-shot-bufferize.mlir

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 0 additions & 73 deletions

@@ -26,82 +26,9 @@ namespace bufferization {
 using namespace mlir;
 using namespace mlir::bufferization;
 
-/// Resolve all operands that are also used inside of repetitive regions of the
-/// same op. Such cases are not fully supported by One-Shot Bufferize.
-///
-/// E.g.:
-/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> {
-///   "some_use"(%tensor)
-///   ...
-/// }
-///
-/// Is converted to:
-/// %tensor_copy = bufferization.alloc_tensor copy(%tensor)
-/// %r = scf.for ... iter_args(%t = %tensor) -> tensor<?xf32> {
-///   "some_use"(%tensor_copy)
-///   ...
-/// }
-static void
-resolveUsesInRepetitiveRegions(Operation *op,
-                               const BufferizationOptions &options) {
-  IRRewriter rewriter(op->getContext());
-  AnalysisState state(options);
-
-  // Look for repetitive ops (loops).
-  op->walk([&](BufferizableOpInterface bufferizableOp) {
-    // Skip filtered ops.
-    if (!options.isOpAllowed(bufferizableOp.getOperation()))
-      return WalkResult::advance();
-
-    // Find all operands that are also used inside of a repetitive region of
-    // this op.
-    for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
-      Value operand = opOperand.get();
-      // Skip non-tensor operands.
-      if (!isa<TensorType>(operand.getType()))
-        continue;
-      // Skip operands that do not bufferize to memory writes.
-      if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state))
-        continue;
-
-      // Gather all uses inside repetitive regions.
-      SmallVector<OpOperand *> usesInsideRegion;
-      for (OpOperand &use : operand.getUses()) {
-        Operation *owner = use.getOwner();
-        if (!bufferizableOp->isProperAncestor(owner))
-          continue;
-        for (Region &r : bufferizableOp->getRegions()) {
-          if (r.findAncestorOpInRegion(*owner) &&
-              bufferizableOp.isRepetitiveRegion(r.getRegionNumber())) {
-            usesInsideRegion.push_back(&use);
-            break;
-          }
-        }
-      }
-      // Nothing to do if the operand is not used inside a repetitive region.
-      if (usesInsideRegion.empty())
-        continue;
-
-      // Insert a tensor copy and replace all uses inside of repetitive regions.
-      rewriter.setInsertionPoint(bufferizableOp);
-      auto tensorCopy = rewriter.create<AllocTensorOp>(
-          bufferizableOp->getLoc(), cast<TensorType>(operand.getType()),
-          /*dynamicSizes=*/ValueRange(),
-          /*copy=*/operand, /*memory_space=*/IntegerAttr());
-      for (OpOperand *use : usesInsideRegion)
-        use->set(tensorCopy);
-    }
-
-    return WalkResult::advance();
-  });
-}
-
 LogicalResult mlir::bufferization::insertTensorCopies(
     Operation *op, const OneShotBufferizationOptions &options,
     BufferizationStatistics *statistics) {
-  // Preprocessing: Resolve currently unsupported bufferization cases.
-  resolveUsesInRepetitiveRegions(op, options);
-
   OneShotAnalysisState state(op, options);
   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
   // analysis depending on whether function boundary bufferization is enabled or
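
For reference, the AllocTensorOp created by the removed helper corresponds to IR of this shape (a sketch based on the function's own doc comment; %tensor stands for the affected operand):

// Inserted right before the repetitive op; all uses of %tensor inside
// repetitive regions were then redirected to %tensor_copy.
%tensor_copy = bufferization.alloc_tensor() copy(%tensor) : tensor<?xf32>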

mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Lines changed: 6 additions & 9 deletions

@@ -160,8 +160,7 @@ func.func @matmul(
   %c16 = arith.constant 16 : index
 
   // Hoisted alloc.
-  // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x192xf32>
-  // CHECK: memref.copy %[[C]], %[[ALLOC]]
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32>
 
   // CHECK: scf.for %[[I:.*]] =
   %0 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %C) -> (tensor<128x192xf32>) {
@@ -173,14 +172,12 @@ func.func @matmul(
     %3 = tensor.extract_slice %B[0, %arg5] [256, 16] [1, 1] :
         tensor<256x192xf32> to tensor<256x16xf32>
 
-    // C was already replaced with a copy by preprocessing, so no copy is
-    // needed here.
-    // CHECK: %[[C_SLICE:.*]] = memref.subview %[[ALLOC]]
+    // Insert an artificial out-of-place buffer by extracting from %C instead
+    // of %arg6.
     %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] :
         tensor<128x192xf32> to tensor<8x16xf32>
 
-    // linalg.fill is inplace.
-    // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[C_SLICE]]
+    // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[ALLOC]]
     %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8x16xf32>) -> tensor<8x16xf32>
 
     // CHECK: scf.for %[[K:.*]] =
@@ -191,7 +188,7 @@ func.func @matmul(
           tensor<256x16xf32> to tensor<32x16xf32>
 
       // linalg.matmul is inplace as well as the enclosing scf.for.
-      // CHECK: linalg.matmul ins({{.*}} outs(%[[C_SLICE]]
+      // CHECK: linalg.matmul ins({{.*}} outs(%[[ALLOC]]
       %10 = linalg.matmul ins(%8, %9 : tensor<8x32xf32>, tensor<32x16xf32>)
                           outs(%arg8 : tensor<8x16xf32>)
           -> tensor<8x16xf32>
@@ -202,7 +199,7 @@ func.func @matmul(
     // that is not in place. So we must insert a copy of the small buffer into
     // the bigger buffer.
     // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1]
-    // CHECK: memref.copy %[[C_SLICE]], %[[T]]
+    // CHECK: memref.copy %[[ALLOC]], %[[T]]
     %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] :
         tensor<8x16xf32> into tensor<128x192xf32>
 
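
Reconstructed from the CHECK lines above (a sketch of the bufferized inner tile, not the full test output; value names are illustrative): the loop body now works in one hoisted 8x16 buffer and copies only the computed tile back into %C, rather than copying all of %C upfront.

func.func @tile_sketch(%A: memref<8x32xf32>, %B: memref<32x16xf32>,
                       %C: memref<128x192xf32>, %i: index, %j: index) {
  %cst = arith.constant 0.0 : f32
  // Hoisted tile-sized alloc (previously memref<128x192xf32> plus a full copy).
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32>
  linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>)
  linalg.matmul ins(%A, %B : memref<8x32xf32>, memref<32x16xf32>)
                outs(%alloc : memref<8x16xf32>)
  // Only the 8x16 tile is copied back into the big buffer.
  %T = memref.subview %C[%i, %j] [8, 16] [1, 1]
      : memref<128x192xf32> to memref<8x16xf32, strided<[192, 1], offset: ?>>
  memref.copy %alloc, %T
      : memref<8x16xf32> to memref<8x16xf32, strided<[192, 1], offset: ?>>
  memref.dealloc %alloc : memref<8x16xf32>
  return
}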

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 10 additions & 9 deletions

@@ -253,11 +253,10 @@ func.func @scf_execute_region_yield_non_equivalent(%i: index, %j: index) -> f32
 // CHECK-SAME: %[[t:.*]]: memref<?xf32
 // CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}})
 // CHECK: memref.copy %[[t]], %[[alloc]]
-// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[t]])
+// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[alloc]])
 // CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}})
-// CHECK: memref.copy %[[alloc]], %[[alloc2]]
-// CHECK: %[[alloc2_casted:.*]] = memref.cast %[[alloc2]]
-// CHECK: scf.yield %[[alloc2_casted]]
+// CHECK: memref.copy %[[t]], %[[alloc2]]
+// CHECK: scf.yield %[[alloc2]]
 // CHECK: return %[[for]]
 func.func @scf_for_yield_non_equivalent(
     %t: tensor<?xf32>, %lb : index, %ub : index, %step : index) -> tensor<?xf32> {
@@ -606,9 +605,9 @@ func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
 
   // CHECK: scf.forall (%{{.*}}) in (2) {
 
-  // Load from the copy and store into the shared output.
-  // CHECK: %[[subview:.*]] = memref.subview %[[t]]
-  // CHECK: memref.load %[[t_copy]]
+  // Load from the original and store into the copy.
+  // CHECK: %[[subview:.*]] = memref.subview %[[t_copy]]
+  // CHECK: memref.load %[[t]]
   // CHECK: memref.store %{{.*}}, %[[subview]]
   %0 = scf.forall (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> {
     %offset = arith.muli %c5, %tid : index
@@ -752,14 +751,16 @@ func.func @scf_for_yield_alias_of_non_equivalent(%sz: index) -> tensor<?xf32> {
   // CHECK: scf.for
   %r = scf.for %iv = %c0 to %sz step %c1 iter_args(%t = %0) -> tensor<?xf32> {
     %iv_sub = arith.subi %iv, %c1 : index
-    // CHECK: memref.subview %[[generate_copy]]
+    // CHECK: memref.subview %[[generate]]
     %ll = tensor.extract_slice %0[%iv_sub][%sz][1] : tensor<?xf32> to tensor<?xf32>
     %l = tensor.extract %ll[%c0] : tensor<?xf32>
     %double = arith.mulf %cst, %l : f32
-    // CHECK: memref.store %{{.*}}, %[[generate]]
+    // CHECK: memref.store %{{.*}}, %[[generate_copy]]
     %s = tensor.insert %double into %t[%iv] : tensor<?xf32>
     scf.yield %s : tensor<?xf32>
   }
+
+  // CHECK: return %[[generate_copy]]
   return %r : tensor<?xf32>
 }
 
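
The first hunk above concerns a loop that yields a value not equivalent to its iter_arg, which forces a fresh buffer and a copy per iteration. A function of roughly this shape triggers that case (a sketch consistent with the signature and CHECK lines above, not necessarily the verbatim test body):

func.func @scf_for_yield_non_equivalent(
    %t: tensor<?xf32>, %lb: index, %ub: index, %step: index) -> tensor<?xf32> {
  %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %t) -> (tensor<?xf32>) {
    // Yielding %t rather than %a makes the yielded value non-equivalent
    // to the iter_arg, so bufferization materializes a copy per iteration.
    scf.yield %t : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}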
