Skip to content

Commit e9fb4dc

Browse files
[mlir][linalg][bufferize] Remove buffer equivalence from bufferize
Remove all function calls related to buffer equivalence from the `bufferize` implementations.

Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to the corresponding BBArgs. (This was previously checked in `bufferize`.) This check will be relaxed in a subsequent commit.

Note: This commit changes two test cases. These test cases were broken by design and should not have passed. With the new scf.for PostAnalysisStep in place, this bug in the test cases was fixed.

Differential Revision: https://reviews.llvm.org/D114927
1 parent a96d828 commit e9fb4dc

File tree

9 files changed

+49
-60
lines changed

9 files changed

+49
-60
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ namespace linalg {
1919
namespace comprehensive_bufferize {
2020
namespace scf_ext {
2121

22+
/// Equivalence analysis for scf.for, implemented as a PostAnalysisStep.
23+
/// Raises an error if iter_args are not equivalent to their corresponding
/// loop yield values.
24+
struct AssertDestinationPassingStyle : public PostAnalysisStep {
25+
/// Runs the check on `funcOp` using the analysis information in `state`.
/// Returns a LogicalResult describing whether the check passed.
LogicalResult run(FuncOp funcOp, BufferizationState &state,
26+
SmallVector<Operation *> &newOps) override;
27+
};
28+
2229
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
2330

2431
} // namespace scf_ext

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ struct ConstantOpInterface
3737
auto globalMemref = globalCreator.getGlobalFor(constantOp);
3838
Value memref = b.create<memref::GetGlobalOp>(
3939
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
40-
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
4140
state.mapBuffer(constantOp, memref);
4241

4342
return success();

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,7 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
141141

142142
/// Return `true` if a value was marked as in-place bufferized.
143143
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
144-
bool inplace = inplaceBufferized.contains(opResult);
145-
#ifndef NDEBUG
146-
if (inplace) {
147-
auto bufferizableOp =
148-
dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
149-
assert(bufferizableOp &&
150-
"expected that in-place bufferized op is bufferizable");
151-
SmallVector<OpOperand *> operands =
152-
bufferizableOp.getAliasingOpOperand(opResult);
153-
for (OpOperand *operand : operands)
154-
assert(areAliasingBufferizedValues(operand->get(), opResult) &&
155-
"expected that in-place bufferized OpResult aliases with "
156-
"aliasing OpOperand");
157-
}
158-
#endif // NDEBUG
159-
return inplace;
144+
return inplaceBufferized.contains(opResult);
160145
}
161146

162147
/// Set the inPlace bufferization spec to true.
@@ -593,7 +578,6 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
593578
Value casted = allocated.getValue();
594579
if (memRefType && memRefType != allocMemRefType) {
595580
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
596-
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
597581
}
598582

599583
// 2. Create memory deallocation.

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,6 @@ struct TiledLoopOpInterface
253253
return failure();
254254

255255
// Insert mapping and aliasing info.
256-
state.aliasInfo.createAliasInfoEntry(resultBuffer);
257-
state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
258256
state.mapBuffer(opResult, resultBuffer);
259257

260258
// Insert new operand and bbArg.
@@ -263,9 +261,6 @@ struct TiledLoopOpInterface
263261
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
264262
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
265263
// Insert mapping and aliasing info.
266-
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
267-
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
268-
newBufferBBArg);
269264
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
270265

271266
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
@@ -303,9 +298,6 @@ struct TiledLoopOpInterface
303298
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
304299

305300
// Insert mapping and aliasing info.
306-
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
307-
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
308-
newBufferBBArg);
309301
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
310302

311303
// Increment indices.

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
223223
BufferizationState &state) {
224224
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
225225
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
226-
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
227226

228227
// If nothing to do then we are done.
229228
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -321,15 +320,12 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
321320
auto castOp = b.create<memref::CastOp>(
322321
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
323322
toMemrefOp.memref().replaceAllUsesWith(castOp);
324-
aliasInfo.insertNewBufferEquivalence(castOp.dest(),
325-
toMemrefOp.memref());
326323
}
327324
}
328325
// Replace all remaining uses by a to_tensor.
329326
if (!bbArg.use_empty()) {
330327
auto toTensorOp =
331328
b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
332-
aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
333329
bbArg.replaceAllUsesWith(toTensorOp);
334330
}
335331
frontBlock.eraseArgument(0);
@@ -562,7 +558,6 @@ struct CallOpInterface
562558
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
563559
// Add CallOp operand/result equivalence: this is interprocedural
564560
// info.
565-
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
566561
state.mapBuffer(oldRes, buffer);
567562
// Add a ToTensorOp to kill all uses of the CallOp return.
568563
// Replace all uses of the CallOp results so we can erase the CallOp.
@@ -572,7 +567,6 @@ struct CallOpInterface
572567
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
573568
oldRes.replaceAllUsesWith(toTensorOp);
574569
// Add new op equivalence info.
575-
state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
576570
state.mapBuffer(toTensorOp, buffer);
577571
continue;
578572
}
@@ -615,7 +609,6 @@ struct CallOpInterface
615609
Value castBuffer =
616610
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
617611
// Add new op equivalence info.
618-
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
619612
state.mapBuffer(tensorOperand, castBuffer);
620613
buffer = castBuffer;
621614
}
@@ -663,7 +656,6 @@ struct ReturnOpInterface
663656
Value returnTensor = b.create<bufferization::ToTensorOp>(
664657
returnOp.getLoc(), v);
665658
operand.set(returnTensor);
666-
state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
667659
state.mapBuffer(returnTensor, v);
668660
}
669661
return success();
@@ -690,7 +682,6 @@ struct FuncOpInterface
690682
: getContiguousOrUnrankedMemRefType(tensorType);
691683
Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
692684
memRefType, bbArg);
693-
state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
694685
state.mapBuffer(bbArg, bufferCast);
695686
}
696687

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ struct IfOpInterface
147147
if (!resultBuffer)
148148
return failure();
149149

150-
state.aliasInfo.createAliasInfoEntry(resultBuffer);
151150
state.mapBuffer(opResult, resultBuffer);
152151
}
153152

@@ -237,8 +236,6 @@ struct ForOpInterface
237236

238237
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
239238
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
240-
state.aliasInfo.createAliasInfoEntry(resultBuffer);
241-
state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
242239
state.mapBuffer(bbArg, resultBuffer);
243240
state.mapBuffer(opResult, resultBuffer);
244241
}
@@ -257,15 +254,6 @@ struct ForOpInterface
257254
OpOperand &forOperand = forOp.getOpOperandForResult(
258255
forOp->getResult(operand.getOperandNumber()));
259256
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
260-
if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
261-
bbArg)) {
262-
// TODO: this could get resolved with copies but it can also turn into
263-
// swaps so we need to be careful about order of copies.
264-
return yieldOp->emitError()
265-
<< "Yield operand #" << operand.getOperandNumber()
266-
<< " does not bufferize to an equivalent buffer to the matching"
267-
<< " enclosing scf::for operand";
268-
}
269257

270258
// Buffers are equivalent so the work is already done and we just yield
271259
// the bbArg so that it later canonicalizes away.
@@ -275,6 +263,41 @@ struct ForOpInterface
275263
}
276264
};
277265

266+
/// Walks every scf.yield nested in `funcOp`. For each yield whose parent is an
/// scf.for, checks that every tensor-typed yield operand is reported by
/// `state.aliasInfo` as equivalent to the matching region iter_arg of the loop.
/// On the first mismatch an error is emitted on the yield op, the walk is
/// interrupted, and the emitted diagnostic's result is returned. If no
/// mismatch is found, `success()` is returned. `newOps` is not referenced in
/// this function.
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
267+
AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
268+
SmallVector<Operation *> &newOps) {
269+
// Stays `success()` unless a non-equivalent yield operand is found below.
LogicalResult status = success();
270+
funcOp->walk([&](scf::YieldOp yieldOp) {
271+
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
272+
// Only yields directly nested in an scf.for are checked; skip all others.
if (!forOp)
273+
return WalkResult::advance();
274+
275+
for (OpOperand &operand : yieldOp->getOpOperands()) {
276+
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
277+
// Non-tensor operands are not subject to the equivalence check.
if (!tensorType)
278+
continue;
279+
280+
// Go from the yield operand index to the loop result with the same index,
// then to the loop operand for that result and its region iter_arg.
OpOperand &forOperand = forOp.getOpOperandForResult(
281+
forOp->getResult(operand.getOperandNumber()));
282+
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
283+
if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
284+
bbArg)) {
285+
// TODO: this could get resolved with copies but it can also turn into
286+
// swaps so we need to be careful about order of copies.
287+
// Record the result of the emitted error and stop at the first mismatch.
status =
288+
yieldOp->emitError()
289+
<< "Yield operand #" << operand.getOperandNumber()
290+
<< " does not bufferize to an equivalent buffer to the matching"
291+
<< " enclosing scf::for operand";
292+
return WalkResult::interrupt();
293+
}
294+
}
295+
296+
// All tensor operands of this yield passed the check; keep walking.
return WalkResult::advance();
297+
});
298+
return status;
299+
}
300+
278301
struct YieldOpInterface
279302
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
280303
scf::YieldOp> {

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ struct CastOpInterface
8080
castOp.getResult().getType(), layout, memorySpace);
8181
Value res =
8282
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
83-
state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
8483
state.mapBuffer(castOp.getResult(), res);
8584
return success();
8685
}
@@ -233,7 +232,6 @@ struct InsertOpInterface
233232
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
234233
insertOp.indices());
235234
state.mapBuffer(insertOp, destMemref);
236-
state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
237235
return success();
238236
}
239237

@@ -421,8 +419,6 @@ struct InsertSliceOpInterface
421419
Value subView = b.create<memref::SubViewOp>(
422420
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
423421
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
424-
// Insert new alias.
425-
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
426422
// Copy tensor.
427423
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
428424
state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
9696
// TODO: Find a way to enable this step automatically when bufferizing tensor
9797
// dialect ops.
9898
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
99+
options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
99100

100101
ModuleOp moduleOp = getOperation();
101102
applyEnablingTransformations(moduleOp);

mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
11131113

11141114
// Read from %t1 via alias %e.
11151115
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
1116-
scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
1116+
scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
11171117
}
11181118
// CHECK: __inplace_results_attr__ = ["true", "false"]
11191119

@@ -1154,14 +1154,10 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
11541154
// This loop does not read from %t1. It only writes to it.
11551155
// CHECK: scf.for
11561156
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
1157-
// CHECK: tensor.extract_slice
1158-
// CHECK-SAME: __inplace_results_attr__ = ["true"]
1159-
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
1160-
1161-
// Write to %t1 via alias. (Overwrite %t3.)
1157+
// Write to %t1 via %t2. (Overwrite %t3.)
11621158
// CHECK: linalg.generic
11631159
// CHECK-SAME: __inplace_results_attr__ = ["true"]
1164-
%o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
1160+
%o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
11651161
^bb(%0: f32) :
11661162
linalg.yield %cst : f32
11671163
} -> (tensor<?xf32>)
@@ -1172,8 +1168,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
11721168
}
11731169

11741170
// Use %t3 in some way without reading it, so that it does not get DCE'd.
1175-
// CHECK: linalg.generic
1176-
// CHECK-SAME: __inplace_results_attr__ = ["true"]
1171+
// CHECK: linalg.generic
1172+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
11771173
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
11781174
^bb(%0: f32) :
11791175
linalg.yield %cst : f32

0 commit comments

Comments
 (0)