@@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
   }
 }
 
-/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
-/// non-writable tensor values. Stop searching when an out-of-place bufferized
-/// OpOperand was found (or when the OpOperand was not bufferized yet).
-/// `currentOpOperand` is assumed to be in-place, even if that decision was not
-/// materialized in `aliasInfo` yet.
-static bool
-hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
-                                      const OneShotAnalysisState &state) {
-  SmallVector<Value> worklist;
-  worklist.push_back(value);
-  while (!worklist.empty()) {
-    Value nextVal = worklist.pop_back_val();
-    if (!state.isWritable(nextVal)) {
-      if (state.getOptions().printConflicts)
-        annotateNonWritableTensor(nextVal);
-      return true;
-    }
-
-    // If `nextVal` is not a BlockArgument: End of use-def chain reached.
-    auto opResult = nextVal.dyn_cast<OpResult>();
-    if (!opResult)
-      continue;
-
-    // Follow reverse SSA use-def chain.
-    AliasingOpOperandList aliasingOpOperands =
-        state.getAliasingOpOperands(opResult);
-    for (OpOperand *opOperand : aliasingOpOperands)
-      if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
-        worklist.push_back(opOperand->get());
-  }
-  return false;
-}
-
 /// Return true if bufferizing `operand` inplace would create a write to a
 /// non-writable buffer.
 static bool
 wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                     OneShotAnalysisState &state,
                                     bool checkConsistencyOnly = false) {
-  // Collect writes of all aliases of OpOperand and OpResult.
-  DenseSet<OpOperand *> usesWrite;
-  getAliasingInplaceWrites(usesWrite, operand.get(), state);
-  for (OpResult result : state.getAliasingOpResults(operand)) {
-    getAliasingInplaceWrites(usesWrite, result, state);
+  bool foundWrite =
+      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
+
+  if (!foundWrite) {
+    // Collect writes of all aliases of OpOperand and OpResult.
+    DenseSet<OpOperand *> usesWrite;
+    getAliasingInplaceWrites(usesWrite, operand.get(), state);
+    for (OpResult result : state.getAliasingOpResults(operand))
+      getAliasingInplaceWrites(usesWrite, result, state);
+    foundWrite = !usesWrite.empty();
   }
-  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
-    usesWrite.insert(&operand);
 
-  // Assuming that `operand` bufferizes in-place: For each write (to each
-  // alias), check if there is a non-writable tensor in the reverse SSA use-def
-  // chain.
-  for (OpOperand *uWrite : usesWrite) {
-    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
-      LLVM_DEBUG(llvm::dbgs() << " => NOT WRITABLE\n");
-      return true;
+  if (!foundWrite)
+    return false;
+
+  // Look for a read-only tensor among all aliases.
+  bool foundReadOnly = false;
+  auto checkReadOnly = [&](Value v) {
+    if (!state.isWritable(v)) {
+      foundReadOnly = true;
+      if (state.getOptions().printConflicts)
+        annotateNonWritableTensor(v);
     }
+  };
+  state.applyOnAliases(operand.get(), checkReadOnly);
+  for (OpResult result : state.getAliasingOpResults(operand))
+    state.applyOnAliases(result, checkReadOnly);
+  if (foundReadOnly) {
+    LLVM_DEBUG(llvm::dbgs() << " => NOT WRITABLE\n");
+    return true;
   }
 
   return false;
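
Note (not part of the commit): the refactor above replaces a per-write reverse SSA use-def walk (the deleted hasPrecedingAliasingNonWritableTensor) with a single pass over the pre-computed alias sets via applyOnAliases, flagging a conflict as soon as any alias is non-writable. The standalone C++ sketch below models that pattern only; ToyValue, AliasSet, and the helper names are hypothetical stand-ins for illustration, not MLIR API.

#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an SSA value with a writability flag.
struct ToyValue {
  int id;
  bool writable;
};

// Pre-computed equivalence class of aliasing values, standing in for the
// alias sets maintained by the one-shot analysis state.
using AliasSet = std::vector<ToyValue>;

// Models applyOnAliases: run a callback on every value in the alias set.
static void applyOnAliases(const AliasSet &aliases,
                           const std::function<void(const ToyValue &)> &fn) {
  for (const ToyValue &v : aliases)
    fn(v);
}

// Models the new check: only if something is written in-place do we scan the
// alias set; finding a read-only member makes the in-place write illegal.
static bool wouldWriteToReadOnlyAlias(const AliasSet &aliases,
                                      bool foundWrite) {
  if (!foundWrite)
    return false;
  bool foundReadOnly = false;
  applyOnAliases(aliases, [&](const ToyValue &v) {
    if (!v.writable)
      foundReadOnly = true;
  });
  return foundReadOnly;
}

int main() {
  // Value 2 is read-only (e.g. a tensor not marked writable).
  AliasSet aliases = {{0, true}, {1, true}, {2, false}};
  std::cout << std::boolalpha
            << wouldWriteToReadOnlyAlias(aliases, /*foundWrite=*/true) << "\n";
  // Prints "true": writing in-place would clobber the read-only alias.
  return 0;
}

Relying on the pre-computed alias sets avoids re-walking the use-def chain for every write and makes the early exit for "no write at all" explicit, which appears to be the point of the simplification in the diff.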