Skip to content

Commit e761c49

Browse files
[mlir][linalg][bufferize][NFC] Utilize isWritable for FuncOps
This is a cleanup of ModuleBufferization. Instead of storing information about writable function arguments in BufferizationAliasInfo, we can use isWritable and make the decision there, based on dialect-specifc bufferization state. Differential Revision: https://reviews.llvm.org/D114930
1 parent 9873ef4 commit e761c49

File tree

9 files changed

+89
-88
lines changed

9 files changed

+89
-88
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,6 @@ class BufferizationAliasInfo {
172172
/// Apply `fun` to all aliases of `v`.
173173
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
174174

175-
// TODO: Move these out of BufferizationAliasInfo.
176-
/// Return true if the value is known to bufferize to writable memory.
177-
bool bufferizesToWritableMemory(Value v) const;
178-
179-
/// Specify that the value is known to bufferize to writable memory.
180-
void setBufferizesToWritableMemory(Value v);
181-
182175
/// Mark a value as in-place bufferized.
183176
void markInPlace(OpResult v) { inplaceBufferized.insert(v); }
184177

@@ -200,9 +193,6 @@ class BufferizationAliasInfo {
200193
/// Check that aliasInfo for `v` exists and return a reference to it.
201194
EquivalenceClassRangeType getAliases(Value v) const;
202195

203-
/// Set of tensors that are known to bufferize to writable memory.
204-
llvm::DenseSet<Value> bufferizeToWritableMemory;
205-
206196
/// Set of all OpResults that were decided to bufferize in-place.
207197
llvm::DenseSet<OpResult> inplaceBufferized;
208198

@@ -429,7 +419,9 @@ struct AllocationHoistingBarrierOnly
429419
return BufferRelation::None;
430420
}
431421

432-
bool isWritable(Operation *op, Value value) const { return false; }
422+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
423+
return false;
424+
}
433425

434426
LogicalResult bufferize(Operation *op, OpBuilder &b,
435427
BufferizationState &state) const {

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
226226
}],
227227
/*retType=*/"bool",
228228
/*methodName=*/"isWritable",
229-
/*args=*/(ins "Value":$value),
229+
/*args=*/(ins "Value":$value,
230+
"BufferizationState &":$state),
230231
/*methodBody=*/"",
231232
/*defaultImplementation=*/[{
232233
return value.isa<OpResult>();

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct ConstantOpInterface
4242
return success();
4343
}
4444

45-
bool isWritable(Operation *op, Value value) const {
45+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
4646
// Memory locations returned by memref::GetGlobalOp may not be written to.
4747
assert(value.isa<OpResult>());
4848
return false;

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,6 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
130130
equivalentInfo.unionSets(newValue, alias);
131131
}
132132

133-
bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const {
134-
return bufferizeToWritableMemory.count(v) > 0;
135-
}
136-
137-
/// Specify that the value is known to bufferize to writable memory.
138-
void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
139-
bufferizeToWritableMemory.insert(v);
140-
}
141-
142133
/// Return `true` if a value was marked as in-place bufferized.
143134
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
144135
return inplaceBufferized.contains(opResult);

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct ToTensorOpInterface
7171
return success();
7272
}
7373

74-
bool isWritable(Operation *op, Value value) const {
74+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
7575
// It is unknown whether the MemRef operand is writable or not.
7676
return false;
7777
}

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,6 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
171171
OpBuilder(op).getStrArrayAttr(inPlaceVector));
172172
}
173173

174-
/// Set the attribute that triggers inplace bufferization on a FuncOp argument
175-
/// `bbArg`.
176-
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
177-
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
178-
funcOp.setArgAttr(bbArg.getArgNumber(),
179-
BufferizableOpInterface::kInplaceableAttrName,
180-
BoolAttr::get(bbArg.getContext(), inPlace));
181-
}
182-
183174
//===----------------------------------------------------------------------===//
184175
// Printing helpers.
185176
//===----------------------------------------------------------------------===//
@@ -258,25 +249,22 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
258249
/// Return true if, under current bufferization decisions, the buffer of `value`
259250
/// is not writable.
260251
static bool aliasesNonWritableBuffer(Value value,
261-
const BufferizationAliasInfo &aliasInfo) {
252+
const BufferizationAliasInfo &aliasInfo,
253+
BufferizationState &state) {
262254
LDBG("WRITABILITY ANALYSIS FOR " << printValueInfo(value) << "\n");
263255
bool foundNonWritableBuffer = false;
264256
aliasInfo.applyOnAliases(value, [&](Value v) {
265-
// Some values are known to be writable.
266-
if (aliasInfo.bufferizesToWritableMemory(v))
267-
return;
268-
269257
// Query BufferizableOpInterface to see if the OpResult is writable.
270258
// TODO: Out-of-place bufferized OpResult could be considered writable.
271259
if (auto bufferizableOp = v.getDefiningOp<BufferizableOpInterface>())
272-
if (bufferizableOp && bufferizableOp.isWritable(v))
260+
if (bufferizableOp && bufferizableOp.isWritable(v, state))
273261
return;
274262

275263
// Query BufferizableOpInterface to see if the BlockArgument is writable.
276264
if (auto bbArg = v.dyn_cast<BlockArgument>())
277265
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(
278266
bbArg.getOwner()->getParentOp()))
279-
if (bufferizableOp.isWritable(bbArg))
267+
if (bufferizableOp.isWritable(bbArg, state))
280268
return;
281269

282270
foundNonWritableBuffer = true;
@@ -515,7 +503,8 @@ bool wouldCreateReadAfterWriteInterference(
515503
/// a write to a non-writable buffer.
516504
static bool
517505
wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
518-
const BufferizationAliasInfo &aliasInfo) {
506+
const BufferizationAliasInfo &aliasInfo,
507+
BufferizationState &state) {
519508
#ifndef NDEBUG
520509
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
521510
assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
@@ -525,9 +514,10 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
525514
// Certain buffers are not writeable:
526515
// 1. A function bbArg that is not inplaceable or
527516
// 2. A constant op.
528-
assert(!aliasesNonWritableBuffer(opResult, aliasInfo) &&
517+
assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
529518
"expected that opResult does not alias non-writable buffer");
530-
bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo);
519+
bool nonWritable =
520+
aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
531521
if (!nonWritable)
532522
return false;
533523

@@ -547,10 +537,9 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
547537
//===----------------------------------------------------------------------===//
548538

549539
/// Determine if `operand` can be bufferized in-place with `result`.
550-
static LogicalResult
551-
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
552-
BufferizationAliasInfo &aliasInfo,
553-
const DominanceInfo &domInfo) {
540+
static LogicalResult bufferizableInPlaceAnalysisImpl(
541+
OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo,
542+
BufferizationState &state, const DominanceInfo &domInfo) {
554543
#ifndef NDEBUG
555544
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
556545
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@@ -565,7 +554,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
565554
<< printValueInfo(result) << '\n');
566555

567556
bool foundInterference =
568-
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) ||
557+
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
569558
wouldCreateReadAfterWriteInterference(operand, result, domInfo,
570559
aliasInfo);
571560

@@ -599,6 +588,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
599588
/// RaW dependence violations.
600589
static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
601590
BufferizationAliasInfo &aliasInfo,
591+
BufferizationState &state,
602592
const DominanceInfo &domInfo,
603593
unsigned analysisFuzzerSeed = 0) {
604594
if (analysisFuzzerSeed) {
@@ -615,8 +605,8 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
615605
if (opOperand.get().getType().isa<TensorType>())
616606
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
617607
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
618-
if (failed(bufferizableInPlaceAnalysisImpl(opOperand, opResult,
619-
aliasInfo, domInfo)))
608+
if (failed(bufferizableInPlaceAnalysisImpl(
609+
opOperand, opResult, aliasInfo, state, domInfo)))
620610
return failure();
621611

622612
return success();
@@ -625,6 +615,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
625615
/// Analyze all ops that are contained in `op`.
626616
static LogicalResult inPlaceAnalysis(Operation *op,
627617
BufferizationAliasInfo &aliasInfo,
618+
BufferizationState &state,
628619
const DominanceInfo &domInfo,
629620
unsigned analysisFuzzerSeed = 0) {
630621
// Collect ops so we can build our own reverse traversal.
@@ -637,7 +628,7 @@ static LogicalResult inPlaceAnalysis(Operation *op,
637628
ops.push_back(op);
638629
});
639630

640-
return inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed);
631+
return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
641632
}
642633

643634
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
@@ -712,15 +703,9 @@ static void
712703
annotateOpsWithBufferizationMarkers(Operation *op,
713704
const BufferizationAliasInfo &aliasInfo) {
714705
op->walk([&](Operation *op) {
715-
for (OpResult opResult : op->getResults()) {
706+
for (OpResult opResult : op->getResults())
716707
if (opResult.getType().isa<TensorType>())
717708
setInPlaceOpResult(opResult, aliasInfo.isInPlace(opResult));
718-
if (auto funcOp = dyn_cast<FuncOp>(op))
719-
for (BlockArgument bbArg : funcOp.getArguments())
720-
if (bbArg.getType().isa<TensorType>())
721-
setInPlaceFuncArgument(bbArg,
722-
aliasInfo.bufferizesToWritableMemory(bbArg));
723-
}
724709
});
725710
}
726711

@@ -739,8 +724,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
739724

740725
// If the analysis fails, just return.
741726
Operation *op = funcOp.getOperation();
742-
if (failed(
743-
inPlaceAnalysis(op, aliasInfo, domInfo, options.analysisFuzzerSeed)))
727+
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
728+
options.analysisFuzzerSeed)))
744729
return failure();
745730
equivalenceAnalysis(op, aliasInfo);
746731

@@ -750,7 +735,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
750735
if (failed(step->run(funcOp, state, newOps)))
751736
return failure();
752737
// Analyze ops that were created by the PostAnalysisStep.
753-
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
738+
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
754739
return failure();
755740
equivalenceAnalysis(newOps, aliasInfo);
756741
}

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ struct TiledLoopOpInterface
193193
return BufferRelation::Equivalent;
194194
}
195195

196-
bool isWritable(Operation *op, Value value) const {
196+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
197197
// Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
198198
// inplace from the perspective of ops nested under:
199199
// 1. Either the matching iter operand is not bufferized inplace and an

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ struct ModuleBufferizationState : public DialectBufferizationState {
3636
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
3737
/// indices.
3838
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
39+
40+
SmallVector<FuncOp> orderedFuncOps;
41+
42+
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
3943
};
4044
} // namespace
4145

@@ -689,6 +693,32 @@ struct FuncOpInterface
689693
return comprehensive_bufferize::bufferize(&funcOp.body(), state);
690694
}
691695

696+
/// Return `true` if the given function argument is writable.
697+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
698+
auto funcOp = cast<FuncOp>(op);
699+
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
700+
assert(bbArg && "expected BlockArgument");
701+
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
702+
703+
// In a first approximation:
704+
// =========================
705+
// If the function is called, we can allocate on the caller side which lets
706+
// us force inplace arguments at function boundaries.
707+
// TODO: do not rely on this behavior.
708+
if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
709+
return true;
710+
711+
// Set the function arguments marked with inplaceable to be known as
712+
// bufferizing to a writeable memory.
713+
BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
714+
bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
715+
if (inplaceAttr && inplaceAttr.getValue())
716+
return true;
717+
718+
// All other function arguments are not writable.
719+
return false;
720+
}
721+
692722
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
693723
};
694724

@@ -704,46 +734,44 @@ void mlir::linalg::comprehensive_bufferize::std_ext::
704734
registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
705735
}
706736

737+
/// Set the attribute that triggers inplace bufferization on a FuncOp argument
738+
/// `bbArg`.
739+
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
740+
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
741+
funcOp.setArgAttr(bbArg.getArgNumber(),
742+
BufferizableOpInterface::kInplaceableAttrName,
743+
BoolAttr::get(bbArg.getContext(), inPlace));
744+
}
745+
746+
/// Annotate the IR with the result of the analysis. For testing/debugging only.
747+
static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
748+
BufferizationState &state) {
749+
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
750+
for (BlockArgument bbArg : funcOp.getArguments())
751+
if (bbArg.getType().isa<TensorType>())
752+
setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
753+
}
754+
707755
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
708756
ModuleOp moduleOp, const BufferizationOptions &options) {
709-
SmallVector<FuncOp> orderedFuncOps;
710-
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
711-
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
712-
return failure();
713-
714757
BufferizationState state(moduleOp, options);
715758
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
716759
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
717760

761+
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
762+
moduleState.callerMap)))
763+
return failure();
764+
718765
// Interestingly, all function args that are not visible outside of a module
719766
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
720767
// inplace. Therefore, we just bufferize funcOp as if none of its results were
721768
// inplaceable, detect which operands are cloned internally and decide what to
722769
// do at call sites.
723-
for (FuncOp funcOp : orderedFuncOps) {
770+
for (FuncOp funcOp : moduleState.orderedFuncOps) {
724771
// No body => no analysis.
725772
if (funcOp.body().empty())
726773
continue;
727774

728-
// In a first approximation:
729-
// =========================
730-
// If the function is called, we can allocate on the caller side which lets
731-
// us force inplace arguments at function boundaries.
732-
// TODO: do not rely on this behavior.
733-
if (callerMap.find(funcOp) != callerMap.end())
734-
for (BlockArgument bbArg : funcOp.getArguments())
735-
if (bbArg.getType().isa<TensorType>())
736-
aliasInfo.setBufferizesToWritableMemory(bbArg);
737-
738-
// Set the function arguments marked with inplaceable to be known as
739-
// bufferizing to a writeable memory.
740-
for (BlockArgument bbArg : funcOp.getArguments()) {
741-
BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
742-
bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
743-
if (inplaceAttr && inplaceAttr.getValue())
744-
aliasInfo.setBufferizesToWritableMemory(bbArg);
745-
}
746-
747775
// Register extra post analysis steps. These cannot be stored in `options`
748776
// because `options` is immutable.
749777
PostAnalysisStepList extraSteps;
@@ -755,12 +783,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
755783
// Analyze and bufferize funcOp.
756784
if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
757785
return failure();
786+
787+
// Add annotations to function arguments.
788+
if (options.testAnalysisOnly)
789+
annotateOpsWithBufferizationMarkers(funcOp, state);
758790
}
759791

760792
if (options.testAnalysisOnly)
761793
return success();
762794

763-
for (FuncOp funcOp : orderedFuncOps) {
795+
for (FuncOp funcOp : moduleState.orderedFuncOps) {
764796
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
765797
// would be invalidated.
766798
if (failed(bufferizeFuncOpBoundary(funcOp, state)))

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ struct ForOpInterface
204204
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
205205
}
206206

207-
bool isWritable(Operation *op, Value value) const {
207+
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
208208
// Interestingly, scf::ForOp's bbArg can **always** be viewed
209209
// inplace from the perspective of ops nested under:
210210
// 1. Either the matching iter operand is not bufferized inplace and an

0 commit comments

Comments
 (0)