Skip to content

[DSE] Fix non-determinism due to address reuse #84943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,9 @@ struct DSEState {

/// Delete dead memory defs and recursively add their operands to ToRemove if
/// they became dead.
void deleteDeadInstruction(Instruction *SI) {
void
deleteDeadInstruction(Instruction *SI,
SmallPtrSetImpl<MemoryAccess *> *Deleted = nullptr) {
MemorySSAUpdater Updater(&MSSA);
SmallVector<Instruction *, 32> NowDeadInsts;
NowDeadInsts.push_back(SI);
Expand All @@ -1720,6 +1722,8 @@ struct DSEState {
if (IsMemDef) {
auto *MD = cast<MemoryDef>(MA);
SkipStores.insert(MD);
if (Deleted)
Deleted->insert(MD);
if (auto *SI = dyn_cast<StoreInst>(MD->getMemoryInst())) {
if (SI->getValueOperand()->getType()->isPointerTy()) {
const Value *UO = getUnderlyingObject(SI->getValueOperand());
Expand Down Expand Up @@ -2168,14 +2172,19 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
unsigned PartialLimit = MemorySSAPartialStoreLimit;
// Worklist of MemoryAccesses that may be killed by KillingDef.
SmallSetVector<MemoryAccess *, 8> ToCheck;
// Track MemoryAccesses that have been deleted in the loop below, so we can
// skip them. Don't use SkipStores for this, which may contain reused
// MemoryAccess addresses.
SmallPtrSet<MemoryAccess *, 8> Deleted;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment here explaining why this extra set is needed here?

[[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size();
ToCheck.insert(KillingDef->getDefiningAccess());

bool Shortend = false;
bool IsMemTerm = State.isMemTerminatorInst(KillingI);
// Check if MemoryAccesses in the worklist are killed by KillingDef.
for (unsigned I = 0; I < ToCheck.size(); I++) {
MemoryAccess *Current = ToCheck[I];
if (State.SkipStores.count(Current))
if (Deleted.contains(Current))
continue;

std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef(
Expand Down Expand Up @@ -2222,7 +2231,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
continue;
LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI
<< "\n KILLER: " << *KillingI << '\n');
State.deleteDeadInstruction(DeadI);
State.deleteDeadInstruction(DeadI, &Deleted);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to check if all uses in the loop pass Deleted, e.g. add an assert that checks if we add as many elements to SkipStore as we do to Deleted after the loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added an assertion -- does that match what you had in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

++NumFastStores;
MadeChange = true;
} else {
Expand Down Expand Up @@ -2259,7 +2268,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
Shortend = true;
// Remove killing store and remove any outstanding overlap
// intervals for the updated store.
State.deleteDeadInstruction(KillingSI);
State.deleteDeadInstruction(KillingSI, &Deleted);
auto I = State.IOLs.find(DeadSI->getParent());
if (I != State.IOLs.end())
I->second.erase(DeadSI);
Expand All @@ -2271,13 +2280,16 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
if (OR == OW_Complete) {
LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI
<< "\n KILLER: " << *KillingI << '\n');
State.deleteDeadInstruction(DeadI);
State.deleteDeadInstruction(DeadI, &Deleted);
++NumFastStores;
MadeChange = true;
}
}
}

assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() &&
"SkipStores and Deleted out of sync?");

// Check if the store is a no-op.
if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) {
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *KillingI
Expand Down