Skip to content

Commit dc7ad19

Browse files
[mlir][bufferize][NFC] Optimize read-only tensor detection
Check alias sets instead of traversing the IR. Differential Revision: https://reviews.llvm.org/D143500
1 parent cedfd27 commit dc7ad19

File tree

1 file changed

+27
-47
lines changed

1 file changed

+27
-47
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
724724
}
725725
}
726726

727-
/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
728-
/// non-writable tensor values. Stop searching when an out-of-place bufferized
729-
/// OpOperand was found (or when the OpOperand was not bufferized yet).
730-
/// `currentOpOperand` is assumed to be in-place, even if that decision was not
731-
/// materialized in `aliasInfo` yet.
732-
static bool
733-
hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
734-
const OneShotAnalysisState &state) {
735-
SmallVector<Value> worklist;
736-
worklist.push_back(value);
737-
while (!worklist.empty()) {
738-
Value nextVal = worklist.pop_back_val();
739-
if (!state.isWritable(nextVal)) {
740-
if (state.getOptions().printConflicts)
741-
annotateNonWritableTensor(nextVal);
742-
return true;
743-
}
744-
745-
// If `nextVal` is not a BlockArgument: End of use-def chain reached.
746-
auto opResult = nextVal.dyn_cast<OpResult>();
747-
if (!opResult)
748-
continue;
749-
750-
// Follow reverse SSA use-def chain.
751-
AliasingOpOperandList aliasingOpOperands =
752-
state.getAliasingOpOperands(opResult);
753-
for (OpOperand *opOperand : aliasingOpOperands)
754-
if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
755-
worklist.push_back(opOperand->get());
756-
}
757-
return false;
758-
}
759-
760727
/// Return true if bufferizing `operand` inplace would create a write to a
761728
/// non-writable buffer.
762729
static bool
763730
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
764731
OneShotAnalysisState &state,
765732
bool checkConsistencyOnly = false) {
766-
// Collect writes of all aliases of OpOperand and OpResult.
767-
DenseSet<OpOperand *> usesWrite;
768-
getAliasingInplaceWrites(usesWrite, operand.get(), state);
769-
for (OpResult result : state.getAliasingOpResults(operand)) {
770-
getAliasingInplaceWrites(usesWrite, result, state);
733+
bool foundWrite =
734+
!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
735+
736+
if (!foundWrite) {
737+
// Collect writes of all aliases of OpOperand and OpResult.
738+
DenseSet<OpOperand *> usesWrite;
739+
getAliasingInplaceWrites(usesWrite, operand.get(), state);
740+
for (OpResult result : state.getAliasingOpResults(operand))
741+
getAliasingInplaceWrites(usesWrite, result, state);
742+
foundWrite = !usesWrite.empty();
771743
}
772-
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
773-
usesWrite.insert(&operand);
774744

775-
// Assuming that `operand` bufferizes in-place: For each write (to each
776-
// alias), check if there is a non-writable tensor in the reverse SSA use-def
777-
// chain.
778-
for (OpOperand *uWrite : usesWrite) {
779-
if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
780-
LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
781-
return true;
745+
if (!foundWrite)
746+
return false;
747+
748+
// Look for a read-only tensor among all aliases.
749+
bool foundReadOnly = false;
750+
auto checkReadOnly = [&](Value v) {
751+
if (!state.isWritable(v)) {
752+
foundReadOnly = true;
753+
if (state.getOptions().printConflicts)
754+
annotateNonWritableTensor(v);
782755
}
756+
};
757+
state.applyOnAliases(operand.get(), checkReadOnly);
758+
for (OpResult result : state.getAliasingOpResults(operand))
759+
state.applyOnAliases(result, checkReadOnly);
760+
if (foundReadOnly) {
761+
LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
762+
return true;
783763
}
784764

785765
return false;

0 commit comments

Comments
 (0)